#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
将本地图片上传到 OSS 的 Lance Table 中
"""

import os
import shutil
import tempfile
from pathlib import Path

import lancedb
import oss2
import pyarrow as pa

# ============ Configuration ============
# Every setting can be overridden through an environment variable of the same
# name; the literals below are only fallbacks for local experimentation.

# Directory scanned (non-recursively) for images to upload.
LOCAL_IMAGE_DIR = os.environ.get('LOCAL_IMAGE_DIR', r"D:\Downloads\image_demo")
# OSS credentials; the fallbacks are placeholders, not real credentials.
OSS_ACCESS_KEY_ID = os.environ.get('OSS_ACCESS_KEY_ID', 'your_ak_id')
OSS_ACCESS_KEY_SECRET = os.environ.get('OSS_ACCESS_KEY_SECRET', 'your_ak_secret')
OSS_ENDPOINT = os.environ.get('OSS_ENDPOINT', 'oss-cn-beijing.aliyuncs.com')
OSS_BUCKET_NAME = os.environ.get('OSS_BUCKET_NAME', 'lancetable-demo')
# Object-key prefix the LanceDB directory is uploaded under. Keys are built by
# string concatenation, so normalise to "no leading '/', exactly one trailing
# '/'" (or the empty string, meaning the bucket root).
_lance_dir = os.environ.get('OSS_LANCE_DIR', 'lance_table/image/').strip('/')
OSS_LANCE_DIR = f"{_lance_dir}/" if _lance_dir else ''
TABLE_NAME = 'image'

# Supported image extensions (lower-case, with the leading dot).
SUPPORTED_FORMATS = ('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp', '.tiff', '.tif')


def get_image_files(directory: str, formats=None):
    """Return the image files directly inside *directory*, sorted by path.

    A directory entry counts as an image when it is a regular file and its
    extension, compared case-insensitively, is one of *formats*.
    Sub-directories are not descended into. A directory whose name merely
    ends in an image extension is skipped.

    Args:
        directory: Directory to scan (``str`` or ``Path``).
        formats: Iterable of extensions including the leading dot,
            e.g. ``('.jpg', '.png')``. Defaults to ``SUPPORTED_FORMATS``.

    Returns:
        A sorted ``list`` of ``Path`` objects.

    Raises:
        FileNotFoundError: *directory* does not exist.
        NotADirectoryError: *directory* exists but is not a directory.
    """
    image_dir = Path(directory)
    if not image_dir.exists():
        raise FileNotFoundError(f"目录不存在: {directory}")
    if not image_dir.is_dir():
        raise NotADirectoryError(f"不是目录: {directory}")

    if formats is None:
        formats = SUPPORTED_FORMATS
    # Normalise once: O(1) membership and case-insensitive matching.
    wanted = {ext.lower() for ext in formats}

    # Lower-casing the suffix also catches mixed-case names such as
    # "photo.Jpg", which separate "*.jpg" / "*.JPG" globs would miss.
    return sorted(
        path for path in image_dir.iterdir()
        if path.is_file() and path.suffix.lower() in wanted
    )


def read_image_binary(image_path: Path) -> bytes:
    """Load the whole file at *image_path* and return its raw bytes.

    No decoding or validation happens here; any I/O error (missing file,
    permission denied, ...) propagates to the caller.
    """
    return Path(image_path).read_bytes()


def upload_lance_db_to_oss(local_db_path: str, bucket, oss_prefix: str):
    """Upload every file under *local_db_path* to OSS, mirroring the tree.

    Each file's path relative to *local_db_path* is appended to *oss_prefix*,
    always using '/' separators, to form its object key.

    Args:
        local_db_path: Root of the local LanceDB directory.
        bucket: An ``oss2.Bucket``; only ``put_object_from_file`` is used.
        oss_prefix: Object-key prefix. A trailing '/' is added when missing,
            so ``"a/b"`` and ``"a/b/"`` produce the same keys. ``""`` uploads
            to the bucket root.

    Returns:
        The number of files uploaded.

    Raises:
        FileNotFoundError: *local_db_path* is not an existing directory.
            Without this check ``os.walk`` yields nothing and the upload
            would falsely report success.
    """
    if not os.path.isdir(local_db_path):
        raise FileNotFoundError(f"目录不存在: {local_db_path}")

    # OSS keys always use '/'. Also guarantee a separator between the prefix
    # and the relative path, otherwise "pfx" + "a.txt" becomes "pfxa.txt".
    prefix = oss_prefix.replace("\\", "/")
    if prefix and not prefix.endswith("/"):
        prefix += "/"

    print(f"\n开始上传到 OSS: {oss_prefix}")

    uploaded = 0
    for root, dirs, files in os.walk(local_db_path):
        dirs.sort()  # in-place sort makes os.walk's traversal order deterministic
        for file in sorted(files):
            local_file = os.path.join(root, file)
            relative_path = os.path.relpath(local_file, local_db_path)
            # relpath uses os.sep ('\\' on Windows); convert to '/' for OSS.
            oss_key = prefix + relative_path.replace(os.sep, "/")

            bucket.put_object_from_file(oss_key, local_file)
            print(f"  上传: {oss_key}")
            uploaded += 1

    print("✓ OSS 上传完成")
    return uploaded


def _read_images(image_files):
    """Read each path in *image_files* into memory, skipping unreadable ones.

    Returns three parallel lists ``(ids, names, blobs)``. IDs are assigned
    1, 2, 3, ... to successfully read files only, so a failed read does not
    leave a gap in the auto-increment ``image_id`` column.
    """
    ids, names, blobs = [], [], []
    for image_path in image_files:
        try:
            binary_data = read_image_binary(image_path)
        except OSError as e:
            print(f"  错误: 无法读取 {image_path.name} - {e}")
            continue
        idx = len(ids) + 1
        ids.append(idx)
        names.append(image_path.name)
        blobs.append(binary_data)
        print(f"  [{idx}] {image_path.name} ({len(binary_data)} bytes)")
    return ids, names, blobs


def _build_arrow_table(ids, names, blobs):
    """Assemble the image columns into a ``pa.Table`` with a fixed schema.

    The types match what pyarrow would infer from these Python values
    (int64 / string / binary). Stating them explicitly keeps the schema
    stable and independent of inference.
    """
    schema = pa.schema([
        ('image_id', pa.int64()),
        ('image_name', pa.string()),
        ('image', pa.binary()),
    ])
    return pa.Table.from_pydict(
        {'image_id': ids, 'image_name': names, 'image': blobs},
        schema=schema,
    )


def _connect_bucket():
    """Return an ``oss2.Bucket`` for the configured bucket, or None on failure.

    ``get_bucket_info`` is called up front so that a missing bucket or a
    connection/auth problem is reported before any upload starts.
    """
    auth = oss2.Auth(OSS_ACCESS_KEY_ID, OSS_ACCESS_KEY_SECRET)
    bucket = oss2.Bucket(auth, OSS_ENDPOINT, OSS_BUCKET_NAME)
    try:
        bucket.get_bucket_info()
    except oss2.exceptions.NoSuchBucket:
        print(f"错误: Bucket '{OSS_BUCKET_NAME}' 不存在")
        return None
    except Exception as e:  # top-level boundary: report any other OSS error
        print(f"OSS 连接错误: {e}")
        return None
    print(f"✓ 连接到 Bucket: {OSS_BUCKET_NAME}")
    return bucket


def _report_remote_objects(bucket, prefix, preview=5):
    """List all objects under *prefix* and print the first *preview* of them.

    NOTE: objects left by earlier runs under the same prefix are counted too;
    nothing here removes stale files.
    """
    objects = list(oss2.ObjectIterator(bucket, prefix=prefix))
    print(f"OSS 上共有 {len(objects)} 个对象")
    for obj in objects[:preview]:
        print(f"  - {obj.key} ({obj.size} bytes)")
    if len(objects) > preview:
        print(f"  ... 还有 {len(objects) - preview} 个对象")


def main():
    """Scan LOCAL_IMAGE_DIR, build a LanceDB table locally and upload it to OSS.

    Steps:
      1. collect the image files,
      2. read their bytes,
      3. build an Arrow table,
      4. write it to a throw-away local LanceDB,
      5. upload that directory under OSS_LANCE_DIR,
      6. list what is now on OSS.

    The local temp directory is always removed, including on early return or
    on an exception.
    """
    print("=" * 60)
    print("图片上传到 OSS Lance Table")
    print("=" * 60)

    # 1. Collect local images.
    print(f"\n扫描目录: {LOCAL_IMAGE_DIR}")
    image_files = get_image_files(LOCAL_IMAGE_DIR)
    print(f"找到 {len(image_files)} 张图片")

    if not image_files:
        print("没有找到图片文件，退出")
        return

    # 2. Read the image bytes.
    print("\n读取图片数据...")
    ids, names, blobs = _read_images(image_files)
    print(f"\n成功读取 {len(ids)} 张图片")

    if not ids:
        # Every read failed; an empty table is not worth creating or uploading.
        print("没有可用的图片数据，退出")
        return

    # 3. Build the Arrow table.
    print("\n创建数据表...")
    arrow_table = _build_arrow_table(ids, names, blobs)
    print(f"表结构:\n{arrow_table.schema}")
    print(f"记录数: {len(arrow_table)}")

    # 4. Write the local LanceDB into a fresh private temp directory. A fixed
    # path in the CWD would have to be rmtree'd first and could clobber an
    # unrelated directory of the same name.
    local_db_path = tempfile.mkdtemp(prefix='temp_lance_db_')
    try:
        print(f"\n创建本地 LanceDB: {local_db_path}")
        db = lancedb.connect(local_db_path)
        db.create_table(TABLE_NAME, data=arrow_table, mode='overwrite')
        print(f"✓ 表 '{TABLE_NAME}' 创建成功")

        # 5. Upload to OSS.
        print("\n连接 OSS...")
        bucket = _connect_bucket()
        if bucket is None:
            return
        upload_lance_db_to_oss(local_db_path, bucket, OSS_LANCE_DIR)

        # 6. Verify the upload.
        print("\n验证 OSS 上的数据...")
        _report_remote_objects(bucket, OSS_LANCE_DIR)
    finally:
        # 7. Clean up the temp directory on every exit path.
        print("\n清理临时文件...")
        shutil.rmtree(local_db_path, ignore_errors=True)

    print("\n" + "=" * 60)
    print("✓ 完成!")
    print(f"  图片数量: {len(ids)}")
    # LanceDB stores table <name> as the directory "<name>.lance" under the DB
    # URI, so that is where the table now lives on OSS.
    print(f"  OSS 路径: oss://{OSS_BUCKET_NAME}/{OSS_LANCE_DIR}{TABLE_NAME}.lance")
    print("=" * 60)


if __name__ == "__main__":
    # Run the upload pipeline only when executed as a script, not on import.
    main()
