#!/usr/bin/env python
# -*- coding: utf-8 -*-
import json
import os
import time
import uuid

from alibabacloud_tea_openapi import models as open_api_models
from alibabacloud_tea_openapi import utils_models as open_api_util_models
from alibabacloud_tea_openapi.client import Client as OpenApiClient
from alibabacloud_tea_util.models import RuntimeOptions

# Fence info strings (the word right after ```) that are treated as ordinary
# Markdown code blocks and streamed as text. RdsCopilot.chat matches a fence
# against these with str.startswith, so every entry also acts as a prefix.
markdown_code_languages = (
    # Programming languages
    "python javascript js typescript ts java c cpp csharp cs go golang "
    "rust ruby php swift kotlin scala r perl haskell lua matlab fortran "
    "objective-c objc dart elixir erlang clojure fsharp vbnet assembly asm "
    # Shells and scripting
    "bash shell zsh powershell ps1 batch bat cmd "
    # Markup and data formats
    "html xml svg mathml xhtml markdown md json yaml yml toml ini "
    "properties dotenv env "
    # Stylesheets
    "css scss sass less stylus "
    # Templates and DSLs
    "jinja2 django liquid handlebars hbs mustache twig pug jade "
    # Databases and query languages
    "sql mysql pgsql plsql sqlite cql "
    # Other common tags
    "diff patch makefile dockerfile docker nginx apache http graphql "
    "protobuf terraform hcl log plaintext text ascii"
).split()


class ChatMessageParams(open_api_util_models.Params):
    """Call descriptor for the ``ChatMessages`` action (API version 2025-05-07, HTTPS POST)."""

    def __init__(self):
        super().__init__()
        self.action, self.version = 'ChatMessages', '2025-05-07'
        self.protocol, self.method = 'HTTPS', 'POST'


class ChatMessagesStopParams(open_api_util_models.Params):
    """Call descriptor for the ``ChatMessagesTaskStop`` action (API version 2025-05-07, HTTPS POST)."""

    def __init__(self):
        super().__init__()
        self.action, self.version = 'ChatMessagesTaskStop', '2025-05-07'
        self.protocol, self.method = 'HTTPS', 'POST'


class BaseEvent:
    """Fields shared by every event produced from the chat stream.

    :param task_id: task id reported by the service (``TaskId`` in the stream).
    :param conversion_id: conversation id the event belongs to; the attribute
        keeps this spelling because callers read it by that name.
    """

    def __init__(self, task_id, conversion_id):
        self.task_id, self.conversion_id = task_id, conversion_id


class MessageEvent(BaseEvent):
    """A piece of answer text to show as-is (it may contain Markdown code blocks).

    :param text: the text chunk.
    """

    def __init__(self, task_id, conversion_id, text):
        super().__init__(task_id=task_id, conversion_id=conversion_id)
        self.text = text


class ToolCallStart(BaseEvent):
    """Signals that a tool call has started.

    :param title: display title (the tool name for ``tool_call`` blocks).
    :param text: payload of the fenced block that produced the event.
    :param tool_call_id: id of the call. Only the segment after the last '-'
        is kept, prefixed with 't'. When omitted, a random uuid4 is used so
        callers that have no server-side id still get a well-formed id.
    """

    def __init__(self, task_id, conversion_id, title, text, tool_call_id=None):
        super().__init__(task_id, conversion_id)
        self.title = title
        self.text = text
        if tool_call_id is None:
            # Some call sites have no server id (e.g. generic fenced blocks).
            tool_call_id = str(uuid.uuid4())
        self.tool_call_id = f"t{tool_call_id.split('-')[-1]}"


class ToolCallPending(BaseEvent):
    """Signals a tool call reported with status ``pending``.

    :param title: display title (the tool name).
    :param text: payload of the ``tool_call`` block.
    :param tool_call_id: server id; only the segment after the last '-' is
        kept, prefixed with 't'.
    """

    def __init__(self, task_id, conversion_id, title, text, tool_call_id):
        super().__init__(task_id, conversion_id)
        self.title, self.text = title, text
        last_segment = tool_call_id.rpartition('-')[2]
        self.tool_call_id = 't' + last_segment


class ToolCallEnd(BaseEvent):
    """Signals that a tool call finished (any status other than start/pending).

    :param title: display title (the tool name).
    :param text: payload of the ``tool_call`` block.
    :param tool_call_id: server id; only the segment after the last '-' is
        kept, prefixed with 't'.
    """

    def __init__(self, task_id, conversion_id, title, text, tool_call_id):
        super().__init__(task_id, conversion_id)
        self.title, self.text = title, text
        last_segment = tool_call_id.rpartition('-')[2]
        self.tool_call_id = 't' + last_segment


class DocumentEvent(BaseEvent):
    """A ``doc`` block attached to the answer.

    Every instance gets a fresh ``document_id``: 'd' followed by the last
    '-'-separated segment of a random uuid4.
    """

    def __init__(self, task_id, conversion_id, title, text):
        super().__init__(task_id, conversion_id)
        random_tail = str(uuid.uuid4()).rpartition('-')[2]
        self.document_id = 'd' + random_tail
        self.title = title
        self.text = text


class SubTaskStartEvent(BaseEvent):
    """Marks the start of a named sub-task.

    ``subtask_id`` is derived only from the title: underscores removed,
    prefixed with 's', lower-cased and cut to 20 characters. SubTaskEndEvent
    uses the same formula, so start/end events for one title share an id.
    """

    def __init__(self, task_id, conversion_id, title, text):
        super().__init__(task_id, conversion_id)
        compact_title = title.replace('_', '')
        self.subtask_id = "s{}".format(compact_title).lower()[:20]
        self.title = title
        self.text = text


class SubTaskEndEvent(BaseEvent):
    """Marks the end of a named sub-task.

    ``subtask_id`` is derived only from the title: underscores removed,
    prefixed with 's', lower-cased and cut to 20 characters. SubTaskStartEvent
    uses the same formula, so start/end events for one title share an id.
    """

    def __init__(self, task_id, conversion_id, title, text):
        super().__init__(task_id, conversion_id)
        compact_title = title.replace('_', '')
        self.subtask_id = "s{}".format(compact_title).lower()[:20]
        self.title = title
        self.text = text


class ChartEvent(BaseEvent):
    """Chart payload: a title, the two axis field names and the data.

    NOTE(review): nothing in this module constructs ChartEvent; presumably
    ``x_field``/``y_field`` name keys inside ``data`` — confirm with consumers.
    """

    def __init__(self, task_id, conversion_id, title, x_field, y_field, data):
        super().__init__(task_id, conversion_id)
        self.title = title
        self.x_field, self.y_field = x_field, y_field
        self.data = data


class RdsCopilot:
    """Streaming client for the RDS AI Copilot ``ChatMessages`` API.

    ``chat`` turns the server-sent answer stream into event objects:

    * plain text becomes ``MessageEvent``;
    * fenced blocks tagged ``tool_call``, ``doc`` or ``subtask`` become the
      matching interactive events;
    * fences tagged with a known Markdown language are passed through as text.
    """

    def __init__(self, app_id='app-iBuGU1VxEY42zrQRQfNAn3oj'):
        """Build the OpenAPI client.

        Credentials come from the ``ACCESS_KEY_ID`` and ``ACCESS_SECRET``
        environment variables.

        :param app_id: Copilot application id sent as ``ApiId`` on every
            request; defaults to the application this module was written for.
        """
        config = open_api_models.Config(
            access_key_id=os.getenv('ACCESS_KEY_ID'),
            access_key_secret=os.getenv('ACCESS_SECRET'),
            protocol='https',
            region_id='cn-hangzhou',
            endpoint='rdsai.aliyuncs.com',
            # Long read timeout because the answer is a long-lived SSE stream
            # (units presumably milliseconds per the Tea SDK — confirm).
            read_timeout=600_000,
            connect_timeout=10_000
        )
        self.app_id = app_id
        self.client = OpenApiClient(config)
        self.code_mask_start = '```'
        self.code_mask_end = '```\n'

    @staticmethod
    def _fence_payload(block):
        """Return the content of a fenced block.

        The content is everything after the opening ``` line, without the
        closing fence and without the newline directly before it.
        """
        _, _, body = block.partition('\n')
        if body.endswith('```\n'):
            body = body[:-len('```\n')]
        elif body.endswith('```'):
            body = body[:-len('```')]
        return body[:-1] if body.endswith('\n') else body

    def _handle_event(self, task_id, conversion_id, full_response):
        """Convert one complete, non-Markdown fenced block into an event.

        :param full_response: a whole fenced block. ``chat`` only passes
            buffers that start with ``` plus a tag and end with
            ``self.code_mask_end``.
        :returns: an event that depends on the block's tag:

            * ``tool_call`` gives ToolCallStart / ToolCallPending / ToolCallEnd,
              chosen by the payload's ``status``;
            * ``doc`` gives DocumentEvent;
            * ``subtask`` gives SubTaskStartEvent / SubTaskEndEvent, or a
              MessageEvent with the raw block for any other status;
            * any other tag gives a ToolCallStart titled with the tag and
              carrying a random id.

            Never returns None.
        :raises ValueError: (``json.JSONDecodeError``) if a tool_call or
            subtask payload is not valid JSON.
        :raises KeyError: if a required key is missing from such a payload.
        """
        if full_response.startswith('```tool_call'):
            tool_call_msg = self._fence_payload(full_response)
            tool_call_args = json.loads(tool_call_msg)
            tool_call_name = tool_call_args['tool_call_name']
            status = tool_call_args['status']
            if status == 'start':
                event_cls = ToolCallStart
            elif status == 'pending':
                event_cls = ToolCallPending
            else:
                # Any other status is treated as the end of the call.
                event_cls = ToolCallEnd
            return event_cls(
                task_id, conversion_id,
                title=tool_call_name,
                text=tool_call_msg,
                tool_call_id=tool_call_args['tool_call_id']
            )

        if full_response.startswith('```doc'):
            return DocumentEvent(
                task_id, conversion_id,
                title="Documents",
                text=self._fence_payload(full_response)
            )

        if full_response.startswith('```subtask'):
            subtask_msg = self._fence_payload(full_response)
            subtask = json.loads(subtask_msg)
            if subtask['status'] == 'start':
                return SubTaskStartEvent(
                    task_id, conversion_id,
                    title=subtask['name'],
                    text=subtask_msg
                )
            if subtask['status'] == 'end':
                return SubTaskEndEvent(
                    task_id, conversion_id,
                    title=subtask['name'],
                    text=subtask_msg
                )
            # Unknown status: surface the raw block as text rather than
            # handing None to the consumer of chat().
            return MessageEvent(task_id, conversion_id, full_response)

        # Unrecognised tag: show it as a generic tool card titled with the tag.
        # There is no server id here, so a random one is generated.
        title = full_response.partition('\n')[0][len(self.code_mask_start):]
        return ToolCallStart(
            task_id, conversion_id,
            title=title,
            text=self._fence_payload(full_response),
            tool_call_id=str(uuid.uuid4())
        )

    def stop_task(self, task_id):
        """Ask the service to stop the chat task ``task_id``.

        :returns: whatever ``self.client.do_request`` returns.
        """
        stop_query_params = {
            'TaskId': task_id,
            'ApiId': self.app_id
        }
        stop_request = open_api_util_models.OpenApiRequest(query=stop_query_params)
        return self.client.do_request(
            ChatMessagesStopParams(),
            stop_request,
            RuntimeOptions()
        )

    def chat(self, query, conversion_id=''):
        """Send ``query`` and yield events while the answer streams in.

        Answer chunks are accumulated in a buffer and emitted as follows:

        * text in front of a ``` fence is yielded at once as a MessageEvent;
        * a fence whose tag starts with an entry of ``markdown_code_languages``
          is yielded as text as soon as enough of the tag has arrived;
        * any other fenced block is held until it ends with
          ``self.code_mask_end`` and is then converted by ``_handle_event``;
        * other text is yielded at most every 0.1 s, to avoid rate limiting;
        * whatever remains when the stream ends is yielded as text.

        :param query: the user question.
        :param conversion_id: sent as ``ConversationId``; it is updated from
            ``ConversionId`` in the stream.
        :raises Exception: any SDK, network or JSON error is printed and
            re-raised unchanged.
        """
        full_response = ""
        task_id = ""

        try:
            query_params = {
                'Query': query,
                'ConversationId': conversion_id,
                'ApiId': self.app_id
            }

            last_msg_time = time.time()
            chat_message_request = open_api_util_models.OpenApiRequest(query=query_params)
            responses = self.client.call_sseapi(ChatMessageParams(), chat_message_request, RuntimeOptions())
            # Buffer length needed to hold ``` plus the longest known language tag.
            max_language_size = max(len(language) for language in markdown_code_languages) + len(self.code_mask_start)

            for response in responses:
                response_body = json.loads(response.event.data)
                print(response_body)  # debug trace of every SSE payload
                task_id = response_body.get('TaskId', task_id)
                # NOTE(review): the stream key is spelled 'ConversionId' while the
                # request uses 'ConversationId' — confirm against the API.
                conversion_id = response_body.get('ConversionId', conversion_id)

                answer = response_body.get('Answer')
                if not answer:
                    continue
                full_response += answer

                # A buffer that starts with a bare fence ("```\n") is usually the
                # closing fence of a Markdown block already streamed as text. Look
                # past it for the next fence, so a following interactive block
                # is not flushed as plain text.
                search_from = len(self.code_mask_end) if full_response.startswith(self.code_mask_end) else 0
                fence_at = full_response.find(self.code_mask_start, search_from)
                if fence_at > 0:
                    # Flush the text in front of the fence right away.
                    yield MessageEvent(task_id, conversion_id, full_response[:fence_at])
                    full_response = full_response[fence_at:]

                if full_response.startswith(self.code_mask_start) and not full_response.startswith(self.code_mask_end):
                    if len(full_response) < max_language_size:
                        continue  # the tag may still be incomplete
                    # Prefix match: e.g. "```jsx" counts as "js", "```tsx" as "ts".
                    is_markdown_language = any(
                        full_response.startswith(f"```{language}") for language in markdown_code_languages)
                    if is_markdown_language:
                        # Ordinary code block: pass through as text, then reset.
                        yield MessageEvent(task_id, conversion_id, full_response)
                        full_response = ""
                    elif full_response.endswith(self.code_mask_end):
                        # Interactive block is complete.
                        yield self._handle_event(task_id, conversion_id, full_response)
                        full_response = ""
                elif time.time() - last_msg_time > 0.1:
                    # Throttle plain-text pushes to one every 0.1 s.
                    yield MessageEvent(task_id, conversion_id, full_response)
                    last_msg_time = time.time()
                    full_response = ""

            # Flush anything still buffered when the stream ends.
            if full_response:
                yield MessageEvent(task_id, conversion_id, full_response)
        except Exception as e:
            print(e)
            raise
