# -*- coding: utf-8 -*-
"""
AgentLoop Dataset 快速入门（端到端自闭环）

完整演示 Dataset 全生命周期：
  创建 Dataset → 查看详情 → 列举 Dataset → 写入数据 → 查询数据
  → 更新数据 → 删除数据 → 更新 Dataset 描述 → 删除 Dataset

可直接运行，无外部依赖数据。

使用前请设置环境变量：
- SLSDEMO_ALIYUN_ACCESS_KEY_ID
- SLSDEMO_ALIYUN_ACCESS_KEY_SECRET
- ALIBABA_CLOUD_CMS_ENDPOINT   (如: cms.cn-shanghai.aliyuncs.com)
- ALIBABA_CLOUD_CMS_WORKSPACE
"""

import json
import os
import time

from alibabacloud_cms20240330.client import Client
from alibabacloud_cms20240330.models import (
    CreateDatasetRequest,
    ExecuteQueryRequest,
    IndexJsonKey,
    IndexKey,
    ListDatasetsRequest,
    UpdateDatasetRequest,
)
from alibabacloud_tea_openapi.models import Config
from dotenv import load_dotenv

# Load variables from a local .env file (if one is found) into os.environ so
# the os.getenv() calls below can see them.
load_dotenv()

# Name of the Dataset this demo creates, writes to, queries and can delete;
# it is also used directly as the table name in the SQL statements below.
DATASET_NAME = "quickstart_demo"


# ---------------------------------------------------------------------------
# 客户端初始化
# ---------------------------------------------------------------------------

def get_client() -> Client:
    """Build a CMS API client from environment variables.

    Reads ``SLSDEMO_ALIYUN_ACCESS_KEY_ID``, ``SLSDEMO_ALIYUN_ACCESS_KEY_SECRET``
    and ``ALIBABA_CLOUD_CMS_ENDPOINT`` (e.g. ``cms.cn-shanghai.aliyuncs.com``).

    Returns:
        A ``Client`` configured with that AccessKey pair and endpoint.

    Raises:
        RuntimeError: if any of the variables is unset or empty. Checking up
            front gives a clear message instead of passing ``None`` into
            ``Config`` and failing later inside the SDK.
    """
    required = (
        "SLSDEMO_ALIYUN_ACCESS_KEY_ID",
        "SLSDEMO_ALIYUN_ACCESS_KEY_SECRET",
        "ALIBABA_CLOUD_CMS_ENDPOINT",
    )
    env = {name: os.getenv(name) for name in required}
    missing = [name for name, value in env.items() if not value]
    if missing:
        raise RuntimeError(f"缺少环境变量: {', '.join(missing)}")

    config = Config(
        access_key_id=env["SLSDEMO_ALIYUN_ACCESS_KEY_ID"],
        access_key_secret=env["SLSDEMO_ALIYUN_ACCESS_KEY_SECRET"],
        endpoint=env["ALIBABA_CLOUD_CMS_ENDPOINT"],
    )
    return Client(config)


def execute_query(client: Client, workspace: str, query: str) -> dict:
    """Run *query* against DATASET_NAME and return the response body as a dict.

    Single entry point for both reads (SELECT / search) and writes
    (INSERT / UPDATE / DELETE); the request type is always "SQL".
    """
    body = client.execute_query(
        workspace,
        DATASET_NAME,
        ExecuteQueryRequest(query=query, type="SQL"),
    ).body
    return body.to_map()


# ---------------------------------------------------------------------------
# Step 1: 创建 Dataset
# ---------------------------------------------------------------------------

def step_create_dataset(client: Client, workspace: str):
    """Step 1: create the demo Dataset with its index schema.

    ``input`` and ``output`` are text fields (``chn=True``) with
    ``text-embedding-v4`` embeddings; ``input`` is what the semantic queries in
    Step 5 search. ``model`` is text, ``score`` is a double, and ``metadata``
    is JSON whose ``input_tokens`` / ``output_tokens`` keys are indexed as long.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("Step 1: 创建 Dataset")
    print(banner)

    embedded_text = dict(type="text", chn=True, embedding="text-embedding-v4")
    token_keys = {
        key: IndexJsonKey(type="long")
        for key in ("input_tokens", "output_tokens")
    }
    schema = {
        "input": IndexKey(**embedded_text),
        "output": IndexKey(**embedded_text),
        "model": IndexKey(type="text"),
        "score": IndexKey(type="double"),
        "metadata": IndexKey(type="json", json_keys=token_keys),
    }

    result = client.create_dataset(
        workspace,
        CreateDatasetRequest(
            dataset_name=DATASET_NAME,
            description="快速入门示例数据集",
            schema=schema,
        ),
    )
    print(f"✓ Dataset '{DATASET_NAME}' 创建成功 (requestId: {result.body.request_id})")


# ---------------------------------------------------------------------------
# Step 2: 查看 Dataset 详情
# ---------------------------------------------------------------------------

def step_get_dataset(client: Client, workspace: str):
    """Step 2: fetch the Dataset's details and pretty-print them as JSON."""
    separator = "=" * 60
    for header_line in ("\n" + separator, "Step 2: 查看 Dataset 详情", separator):
        print(header_line)

    details = client.get_dataset(workspace, DATASET_NAME).body.to_map()
    print(json.dumps(details, ensure_ascii=False, indent=2))


# ---------------------------------------------------------------------------
# Step 3: 列举所有 Dataset
# ---------------------------------------------------------------------------

def step_list_datasets(client: Client, workspace: str):
    """Step 3: list the workspace's Datasets and mark the one this demo created.

    Only one request with ``max_results=100`` is made.
    NOTE(review): if the workspace holds more Datasets than one page returns,
    the count printed here covers only that page. Confirm whether the API
    paginates (e.g. via a next token).
    """
    rule = "=" * 60
    print("\n" + rule)
    print("Step 3: 列举所有 Dataset")
    print(rule)

    listing = client.list_datasets(workspace, ListDatasetsRequest(max_results=100)).body.to_map()
    datasets = listing.get("datasets", [])

    print(f"当前 workspace 共 {len(datasets)} 个 Dataset:")
    for entry in datasets:
        name = entry["datasetName"]
        suffix = " ← 刚创建" if name == DATASET_NAME else ""
        print(f"  - {name}: {entry.get('description', '')}{suffix}")


# ---------------------------------------------------------------------------
# Step 4: 写入数据（批量 INSERT）
# ---------------------------------------------------------------------------

# Sample rows: (input, output, model, score, metadata as JSON text).
# Values are stored raw (no SQL escaping); step_write_data quotes them.
SAMPLE_DATA = [
    ("如何查看最近一小时的错误日志？",
     "使用查询语句: level:ERROR | SELECT * FROM log WHERE __time__ > now() - interval '1 hour'",
     "qwen-plus", 0.95, '{"input_tokens": 15, "output_tokens": 42}'),
    ("统计今天各接口的调用次数",
     "SELECT api_path, count(*) as cnt FROM log GROUP BY api_path ORDER BY cnt DESC",
     "gpt-4o", 0.88, '{"input_tokens": 12, "output_tokens": 35}'),
    ("查找响应时间超过5秒的请求",
     "latency > 5000 | SELECT request_id, latency, api_path FROM log",
     "qwen-plus", 0.92, '{"input_tokens": 10, "output_tokens": 28}'),
    ("分析用户登录失败的原因分布",
     "action:login AND status:failed | SELECT error_reason, count(*) as cnt FROM log GROUP BY error_reason",
     "claude-3.5-sonnet", 0.85, '{"input_tokens": 14, "output_tokens": 45}'),
    ("查询日志中的异常错误堆栈信息",
     "level:ERROR | SELECT message, stack_trace FROM log ORDER BY __time__ DESC LIMIT 100",
     "qwen-plus", 0.90, '{"input_tokens": 11, "output_tokens": 38}'),
    ("数据库慢查询时间分布统计",
     "SELECT date_trunc('hour', __time__) as t, count(*) as cnt FROM slow_query_log GROUP BY t ORDER BY t",
     "gpt-4o", 0.78, '{"input_tokens": 16, "output_tokens": 50}'),
]


def _sql_string(value) -> str:
    """Return ``str(value)`` as a single-quoted SQL string literal.

    Embedded single quotes are doubled (``'`` -> ``''``), the standard SQL
    escape, so values such as ``interval '1 hour'`` can be written verbatim.
    """
    return "'" + str(value).replace("'", "''") + "'"


def step_write_data(client: Client, workspace: str):
    """Step 4: insert every row of SAMPLE_DATA with a single INSERT statement.

    Text columns are quoted with ``_sql_string``; ``score`` is emitted as a
    bare numeric literal. After the write, sleeps 1 s before returning so that
    Step 5 queries the freshly written rows.
    """
    print("\n" + "=" * 60)
    print(f"Step 4: 写入数据（{len(SAMPLE_DATA)} 条样本）")
    print("=" * 60)

    values_parts = [
        f"    ({_sql_string(inp)}, {_sql_string(out)}, {_sql_string(model)}, {score}, {_sql_string(meta)})"
        for inp, out, model, score, meta in SAMPLE_DATA
    ]
    # Join outside the f-string: a backslash inside an f-string expression
    # is a SyntaxError before Python 3.12 (PEP 701).
    values_sql = ",\n".join(values_parts)

    sql = f"""
    INSERT INTO {DATASET_NAME} (input, output, model, score, metadata)
    VALUES
{values_sql}
    """

    result = execute_query(client, workspace, sql)
    meta = result["meta"]
    print(f"✓ 写入成功, affected_rows: {meta['affectedRows']}, elapsed: {meta['elapsedMillisecond']}ms")

    # Short pause so the new rows are visible to the queries in Step 5.
    print("  等待数据索引生效...")
    time.sleep(1)


# ---------------------------------------------------------------------------
# Step 5: 查询数据（全文检索 / 语义搜索 / SQL / 组合查询）
# ---------------------------------------------------------------------------

def step_query_data(client: Client, workspace: str):
    """Step 5: run one example of each query mode and summarize the results.

    Covers:
      - 5a: full-text search.
      - 5b: semantic search, using the retrieval-syntax ``similarity()`` and
        the SQL function ``semantic_distance()``.
      - 5c: SQL aggregation.
      - 5d: a combined search-then-SQL query.
      - 5e: paging with ``LIMIT offset, count``.

    Every query is echoed before it is executed.
    """
    print("\n" + "=" * 60)
    print("Step 5: 查询数据")
    print("=" * 60)
    print()
    print("Dataset 查询支持四种模式：")
    print("  1. 全文检索    field:keyword                              关键词匹配")
    print("  2. 语义搜索    similarity() / semantic_distance()         向量语义检索")
    print("  3. SQL 分析    SELECT ... FROM dataset_name ...           标准 SQL")
    print("  4. 组合查询    <检索条件> | <SQL 语句>                     自由组合")
    print()
    print("语义搜索提供两种形式：")
    print("  ▸ 检索语法: similarity(field, 'text') < threshold   → 用于 | 左侧检索条件")
    print("  ▸ SQL 函数: semantic_distance(field, 'text')        → 用于 | 右侧 SQL 语句中")

    def show(label, q):
        """Print *label* and query *q* (non-blank lines, stripped), then run *q*."""
        print(f"\n--- {label} ---")
        lines = [ln.strip() for ln in q.strip().splitlines() if ln.strip()]
        if len(lines) == 1:
            print(f"  Query: {lines[0]}")
        else:
            print("  Query:")
            for line in lines:
                print(f"    {line}")
        return execute_query(client, workspace, q)

    def clip(value, width):
        """Return the first *width* characters of ``str(value)``; None becomes ''."""
        return ("" if value is None else str(value))[:width]

    def print_hits(rows):
        """Print the hit count, then each hit's score and truncated input."""
        print(f"  命中 {len(rows)} 条:")
        for item in rows:
            print(f"    [{item.get('score', '')}] {clip(item.get('input'), 50)}")

    # ── 5a. Full-text search ──
    result = show("5a. 全文检索", "input:错误")
    print_hits(result.get("data", []))

    # ── 5b. Semantic search (two forms) ──

    # Form 1: retrieval-syntax similarity(), used on the left of "|".
    result = show(
        "5b-1. 语义搜索 — 检索语法 similarity()",
        "similarity(input, '日志查询统计') < 0.3",
    )
    print_hits(result.get("data", []))

    # Form 2: SQL function semantic_distance(), used inside the SQL on the
    # right of "|". It is a distance (note the "< 0.4" filter): smaller means
    # more similar, so ORDER BY ... ASC lists the closest matches first.
    # The column alias happens to be named "similarity".
    result = show(
        "5b-2. 语义搜索 — SQL 函数 semantic_distance()",
        f"SELECT input, semantic_distance(input, '日志查询统计') AS similarity\n"
        f"FROM {DATASET_NAME}\n"
        f"WHERE semantic_distance(input, '日志查询统计') < 0.4\n"
        f"ORDER BY semantic_distance(input, '日志查询统计') ASC",
    )
    data = result.get("data", [])
    print(f"  返回 {len(data)} 条 (按语义距离升序, 越相似越靠前):")
    for item in data:
        sim = item.get('similarity', '')
        try:
            sim = f"{float(sim):.4f}"
        except (ValueError, TypeError):
            pass  # non-numeric value: print it unformatted
        print(f"    [similarity={sim}] {clip(item.get('input'), 40)}")

    # ── 5c. SQL analysis ──
    result = show(
        "5c. SQL 分析 (GROUP BY 聚合)",
        f"SELECT model, count(*) AS total, avg(score) AS avg_score\n"
        f"FROM {DATASET_NAME}\n"
        f"GROUP BY model\n"
        f"ORDER BY total DESC",
    )
    data = result.get("data", [])
    print(f"  分析结果 ({len(data)} 组):")
    print(f"    {'模型':<20s} {'数量':>4s}  {'平均分':>6s}")
    print(f"    {'─' * 20} {'─' * 4}  {'─' * 6}")
    for row in data:
        model = str(row.get('model', ''))
        total = str(row.get('total', ''))
        try:
            avg_s = f"{float(row.get('avg_score', 0)):.2f}"
        except (ValueError, TypeError):
            avg_s = str(row.get('avg_score', ''))
        print(f"    {model:<20s} {total:>4s}  {avg_s:>6s}")

    # ── 5d. Combined query (full-text + SQL + semantic_distance) ──
    result = show(
        "5d. 组合查询 (全文 + SQL + semantic_distance)",
        f"model:qwen-plus\n"
        f"| SELECT input, score FROM {DATASET_NAME}\n"
        f"  WHERE semantic_distance(input, '日志分析') < 0.5\n"
        f"  ORDER BY score DESC LIMIT 3",
    )
    data = result.get("data", [])
    print(f"  返回 {len(data)} 条:")
    for i, row in enumerate(data, 1):
        print(f"    {i}. [score={row.get('score', '')}] {clip(row.get('input'), 50)}")

    # ── 5e. Paging (LIMIT offset, count) ──
    # NOTE(review): there is no ORDER BY, so separate page queries are not
    # guaranteed to be disjoint or stable. Add an ORDER BY on a sortable column
    # if the page contents matter (only row counts are printed here).
    print("\n--- 5e. 分页查询 (LIMIT offset, count) ---")
    page_size = 2
    for page in range(1, 4):
        skip = (page - 1) * page_size
        sql = f"SELECT id, input FROM {DATASET_NAME} LIMIT {skip}, {page_size}"
        print(f"  Query: {sql}")
        result = execute_query(client, workspace, sql)
        count = len(result.get("data", []))
        print(f"    第 {page} 页: {count} 条")


# ---------------------------------------------------------------------------
# Step 6: 更新数据
# ---------------------------------------------------------------------------

def step_update_data(client: Client, workspace: str) -> str:
    """Step 6: set ``score`` to 0.99 on one row and return that row's id.

    The row is whichever one ``SELECT ... LIMIT 1`` returns first. Its id is
    returned so Step 7 can delete the same row. Returns "" when no row comes
    back.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("Step 6: 更新数据")
    print(rule)

    # Look up one row to obtain an id to target.
    rows = execute_query(
        client, workspace, f"SELECT id, input, score FROM {DATASET_NAME} LIMIT 1"
    ).get("data", [])
    if not rows:
        print("  无数据可更新")
        return ""

    target = rows[0]
    doc_id, old_score = target["id"], target.get("score", "?")
    print(f"  目标: id={doc_id}, 原始 score={old_score}")

    update_sql = f"""
    UPDATE {DATASET_NAME}
    SET score = 0.99
    WHERE id = '{doc_id}'
    """
    affected = execute_query(client, workspace, update_sql)["meta"]["affectedRows"]
    print(f"✓ 更新成功, affected_rows: {affected}, score: {old_score} → 0.99")
    return doc_id


# ---------------------------------------------------------------------------
# Step 7: 删除数据
# ---------------------------------------------------------------------------

def step_delete_data(client: Client, workspace: str, doc_id: str):
    """Step 7: delete the row whose id is *doc_id*; skipped when *doc_id* is falsy."""
    rule = "=" * 60
    print("\n" + rule)
    print("Step 7: 删除数据")
    print(rule)

    if doc_id:
        delete_sql = f"DELETE FROM {DATASET_NAME} WHERE id = '{doc_id}'"
        affected = execute_query(client, workspace, delete_sql)["meta"]["affectedRows"]
        print(f"✓ 删除成功, id={doc_id}, affected_rows: {affected}")
    else:
        print("  无可删除的 id，跳过")


# ---------------------------------------------------------------------------
# Step 8: 更新 Dataset 描述
# ---------------------------------------------------------------------------

def step_update_dataset(client: Client, workspace: str):
    """Step 8: replace the Dataset's description and print the request id."""
    bar = "=" * 60
    print(f"\n{bar}")
    print("Step 8: 更新 Dataset 描述")
    print(bar)

    new_description = "快速入门示例数据集（已完成测试）"
    request_id = client.update_dataset(
        workspace, DATASET_NAME, UpdateDatasetRequest(description=new_description)
    ).body.request_id
    print(f"✓ 描述已更新 (requestId: {request_id})")


# ---------------------------------------------------------------------------
# Step 9: 删除 Dataset（清理）
# ---------------------------------------------------------------------------

def step_delete_dataset(client: Client, workspace: str):
    """Step 9: delete the demo Dataset (cleanup) and print the request id."""
    bar = "=" * 60
    print(f"\n{bar}")
    print("Step 9: 删除 Dataset（清理资源）")
    print(bar)

    request_id = client.delete_dataset(workspace, DATASET_NAME).body.request_id
    print(f"✓ Dataset '{DATASET_NAME}' 已删除 (requestId: {request_id})")


# ---------------------------------------------------------------------------
# 主流程
# ---------------------------------------------------------------------------

def main():
    """Run the full Dataset lifecycle demo (Steps 1-9) in one workspace.

    Once Step 1 has created the Dataset, Steps 2-8 run inside ``try`` and
    Step 9 (delete the Dataset) runs in ``finally``. The demo therefore cleans
    up after itself even if an intermediate step raises. Otherwise the leftover
    Dataset would presumably make Step 1 fail on the next run.

    Raises:
        SystemExit: if ALIBABA_CLOUD_CMS_WORKSPACE is unset or empty.
    """
    client = get_client()
    workspace = os.getenv("ALIBABA_CLOUD_CMS_WORKSPACE")
    if not workspace:
        raise SystemExit("请设置环境变量 ALIBABA_CLOUD_CMS_WORKSPACE")

    print("AgentLoop Dataset 快速入门")
    print(f"Workspace: {workspace}")
    print(f"Dataset:   {DATASET_NAME}")

    # Create outside the try: if creation fails (for example because a
    # Dataset with this name already exists), the cleanup must not delete it.
    step_create_dataset(client, workspace)              # 1. create Dataset
    try:
        step_get_dataset(client, workspace)             # 2. show details
        step_list_datasets(client, workspace)           # 3. list Datasets
        step_write_data(client, workspace)              # 4. write rows
        step_query_data(client, workspace)              # 5. query rows
        doc_id = step_update_data(client, workspace)    # 6. update one row
        step_delete_data(client, workspace, doc_id)     # 7. delete that row
        step_update_dataset(client, workspace)          # 8. update description
    finally:
        step_delete_dataset(client, workspace)          # 9. delete Dataset

    print("\n" + "=" * 60)
    print("全生命周期演示完成")
    print("=" * 60)


if __name__ == "__main__":
    main()
