from flask import Flask, render_template
import os
import uuid
import json
import socketio
import websocket
import threading
from collections import defaultdict
import traceback

app = Flask(__name__)
# NOTE(review): cors_allowed_origins='*' lets any web origin open a Socket.IO
# session, and every 'synthesize' request opens a DashScope connection with
# the server's API key; restrict the allowed origins for real deployments.
sio = socketio.Server(async_mode='threading', cors_allowed_origins='*')
# WSGIApp is python-socketio's documented name for the deprecated Middleware
# wrapper: it answers Socket.IO traffic and hands every other request to Flask.
app.wsgi_app = socketio.WSGIApp(sio, app.wsgi_app)

class ClientManager:
    """Thread-safe registry of the TTS client owned by each Socket.IO session.

    Each entry is a dict ``{'client': TTSClient, 'voice': str}`` keyed by sid.
    """

    def __init__(self):
        self.clients = {}  # sid -> {'client': TTSClient, 'voice': str}
        # Guards self.clients only. It is deliberately never held while a
        # client is closed or started: that work may block on the network, and
        # TTSClient callbacks (send_error / on_close) call back into
        # remove_client(), which would deadlock on this non-reentrant lock.
        self.lock = threading.Lock()

    @staticmethod
    def _close_entry(sid, entry):
        """Best-effort close of the websocket behind a registry entry (may be None)."""
        if not entry:
            return
        client = entry.get('client')
        try:
            # TTSClient.close() takes the websocket to close as its argument.
            client.close(client.ws)
        except Exception as e:
            print(f"关闭TTS客户端失败 (sid: {sid}): {e}")

    def create_client(self, sid, voice):
        """Replace any existing client for ``sid`` with a new, started TTSClient.

        Any previously registered client for ``sid`` is closed and dropped,
        even if creating the new one fails.

        Returns:
            The new TTSClient, or None if constructing or starting it raised.
        """
        try:
            client = TTSClient(os.getenv("DASHSCOPE_API_KEY"),
                               "wss://dashscope.aliyuncs.com/api-ws/v1/inference/",
                               voice=voice,
                               sid=sid)
        except Exception as e:
            print(f"创建TTS客户端失败: {e}")
            client = None

        # Swap atomically, then do the (possibly slow) close outside the lock.
        with self.lock:
            previous = self.clients.pop(sid, None)
            if client is not None:
                self.clients[sid] = {
                    'client': client,
                    'voice': voice
                }
        self._close_entry(sid, previous)

        if client is None:
            return None
        try:
            # Started outside the lock: run()'s error path calls send_error(),
            # which re-enters remove_client().
            client.run()
        except Exception as e:
            print(f"创建TTS客户端失败: {e}")
            self.remove_client(sid)
            return None
        return client

    def get_client(self, sid):
        """Return the registry entry dict for ``sid``, or None if there is none."""
        with self.lock:
            return self.clients.get(sid, None)

    def remove_client(self, sid):
        """Unregister ``sid``'s client (if any) and close its websocket."""
        with self.lock:
            entry = self.clients.pop(sid, None)
        self._close_entry(sid, entry)


client_manager = ClientManager()

class TTSClient:
    """Bridge one DashScope CosyVoice duplex TTS task to one Socket.IO session.

    Lifecycle: ``run()`` opens the websocket on a daemon thread, and
    ``on_open`` sends ``run-task``. When the server replies ``task-started``,
    ``task_started_event`` is set. Callers then stream text with
    ``send_continue_task()`` and end the task with ``send_finish_task()``.
    Binary frames are forwarded to room ``sid`` as ``audio_chunk`` events.
    """

    def __init__(self, api_key, uri, voice="longanyang", sid=None):
        self.api_key = api_key
        self.uri = uri
        self.voice = voice
        self.sid = sid  # Socket.IO session id; also the room events are emitted to
        self.task_id = str(uuid.uuid4())
        self.ws = None  # websocket.WebSocketApp, created by run()
        self.task_started = False
        self.task_finished = False
        # Set on task-started, but ALSO on close (see on_close) so waiters
        # never hang; waiters must check task_started to tell the two apart.
        self.task_started_event = threading.Event()

    def _command(self, action, payload):
        """Return the JSON text of a duplex-streaming command for this task."""
        return json.dumps({
            "header": {
                "action": action,
                "task_id": self.task_id,
                "streaming": "duplex"
            },
            "payload": payload
        })

    def on_open(self, ws):
        """WebSocketApp callback: start the task by sending run-task."""
        try:
            ws.send(self._command("run-task", {
                "task_group": "audio",
                "task": "tts",
                "function": "SpeechSynthesizer",
                "model": "cosyvoice-v3-flash",
                "parameters": {
                    "text_type": "PlainText",
                    "voice": self.voice,
                    "format": "mp3",
                    "sample_rate": 22050,
                    "volume": 50,
                    "rate": 1,
                    "pitch": 1
                },
                "input": {}
            }))
            print(f"已发送 run-task 指令 (sid: {self.sid})")
        except Exception as e:
            print(f"发送run-task指令失败: {e}")
            self.send_error(f"发送指令失败: {e}")

    def on_message(self, ws, message):
        """WebSocketApp callback: text frames are JSON events, others are audio bytes."""
        try:
            if not isinstance(message, str):
                # Non-text frame: forward the audio bytes to this session untouched.
                sio.emit('audio_chunk', {'data': message}, room=self.sid)
                return

            header = json.loads(message).get("header", {})
            event = header.get("event", "")

            if event == "task-started":
                self.task_started = True
                self.task_started_event.set()
                sio.emit('audio_start', room=self.sid)  # tell the browser audio is coming
            elif event == "task-finished":
                self.task_finished = True
                sio.emit('audio_end', room=self.sid)
                self.close(ws)
                print(f"任务完成 (sid: {self.sid})")
            elif event == "task-failed":
                # Report failure as an error (not only audio_end) so the browser
                # can tell it apart from completion. NOTE(review): error_code /
                # error_message are assumed to be where DashScope puts the
                # reason — confirm against its WebSocket API docs.
                reason = header.get("error_message") or header.get("error_code") or event
                self.task_finished = True
                print(f"任务失败 (sid: {self.sid}): {reason}")
                sio.emit('synthesis_error', {'message': f"合成失败: {reason}"}, room=self.sid)
                sio.emit('audio_end', room=self.sid)
                self.close(ws)
        except Exception as e:
            print(f"处理消息失败: {e}")
            self.send_error(f"处理消息失败: {e}")

    def on_error(self, ws, error):
        """WebSocketApp callback: log the error and report it to the browser."""
        print(f"WebSocket 出错 (sid: {self.sid}): {error}")
        self.send_error(f"WebSocket错误: {error}")

    def on_close(self, ws, close_status_code, close_msg):
        """WebSocketApp callback: release waiters, then drop this client."""
        print(f"WebSocket 已关闭 (sid: {self.sid}): {close_msg} ({close_status_code})")
        # Unblock any thread waiting for task-started before touching the registry.
        self.task_started_event.set()
        self._release()

    def send_continue_task(self, text):
        """Stream one piece of ``text`` to the running task.

        Returns True if the command was handed to the websocket. Returns False
        if there is no websocket or sending raised; a raised error is also
        reported via send_error().
        """
        if not self.ws:
            print("WebSocket 未连接，无法发送数据")
            return False

        try:
            self.ws.send(self._command("continue-task", {"input": {"text": text}}))
            return True
        except Exception as e:
            print(f"发送continue-task失败: {e}")
            self.send_error(f"发送文本失败: {e}")
            return False

    def send_finish_task(self):
        """Tell the server no more text is coming.

        Returns the same True/False convention as send_continue_task().
        """
        if not self.ws:
            print("WebSocket 未连接，无法发送完成指令")
            return False

        try:
            self.ws.send(self._command("finish-task", {"input": {}}))
            print(f"已发送 finish-task 指令 (sid: {self.sid})")
            return True
        except Exception as e:
            print(f"发送finish-task失败: {e}")
            self.send_error(f"发送结束指令失败: {e}")
            return False

    def close(self, ws=None):
        """Close ``ws`` (default: this client's own websocket) if it is connected.

        Best-effort: a no-op when already closed, and failures are only logged.
        """
        if ws is None:
            ws = self.ws
        try:
            if ws and ws.sock and ws.sock.connected:
                ws.close()
        except Exception as e:
            print(f"关闭WebSocket失败 (sid: {self.sid}): {e}")

    def run(self):
        """Create the WebSocketApp and run it on a daemon thread (returns immediately)."""
        try:
            header = {"Authorization": f"bearer {self.api_key}", "X-DashScope-DataInspection": "enable"}
            self.ws = websocket.WebSocketApp(
                self.uri,
                header=header,
                on_open=self.on_open,
                on_message=self.on_message,
                on_error=self.on_error,
                on_close=self.on_close
            )
            thread = threading.Thread(target=self.ws.run_forever, daemon=True)
            thread.start()
        except Exception as e:
            print(f"启动WebSocket失败: {e}")
            self.send_error(f"连接失败: {e}")

    def send_error(self, message):
        """Report ``message`` to the browser as 'synthesis_error', then drop this client."""
        try:
            sio.emit('synthesis_error', {'message': message}, room=self.sid)
        except Exception as e:
            print(f"发送错误通知失败 (sid: {self.sid}): {e}")
        finally:
            self._release()

    def _release(self):
        """Close this client's websocket and unregister it from client_manager.

        The registry entry is removed only if it still points at *this*
        client. A newer request from the same session may have replaced it,
        and that replacement must not be torn down by this stale client.
        """
        self.close()
        entry = client_manager.get_client(self.sid)
        # NOTE(review): not atomic with get_client(); a replacement racing in
        # between these two calls could still be removed.
        if entry is not None and entry.get('client') is self:
            client_manager.remove_client(self.sid)

@app.route('/')
def index():
    """Serve the browser-side TTS client page."""
    page = 'my_cosyvoice_client.html'
    return render_template(page)

# Socket.IO connect event
@sio.on('connect')
def handle_connect(sid, environ):
    """Log a newly connected session; ``environ`` is not used."""
    log_line = f"客户端已连接: {sid}"
    print(log_line)

# Socket.IO disconnect event
@sio.on('disconnect')
def handle_disconnect(sid):
    """Log the disconnect and drop whatever TTS client this session owned."""
    log_line = f"客户端断开连接: {sid}"
    print(log_line)
    client_manager.remove_client(sid)

# Characters that end a sentence when cutting the input into fragments.
SENTENCE_DELIMITERS = frozenset('.?!。？！\n')


def _split_sentences(text, delimiters=SENTENCE_DELIMITERS):
    """Split ``text`` after every run of sentence delimiters.

    A run of consecutive delimiters stays with the sentence before it
    ("a?!b" -> ["a?!", "b"]). Trailing text without a delimiter becomes the
    last fragment. Joining the fragments reproduces ``text`` exactly.
    """
    fragments = []
    start = 0
    i = 0
    length = len(text)
    while i < length:
        if text[i] not in delimiters:
            i += 1
            continue
        end = i + 1
        while end < length and text[end] in delimiters:
            end += 1
        fragments.append(text[start:end])
        start = i = end
    if start < length:
        fragments.append(text[start:])
    return fragments


def _discard_client(sid, client):
    """Close ``client``'s websocket and unregister it if it is still ``sid``'s client."""
    try:
        client.close(client.ws)
    except Exception as e:
        print(f"关闭TTS客户端失败 (sid: {sid}): {e}")
    entry = client_manager.get_client(sid)
    # A newer request from the same session may already have replaced it.
    if entry is not None and entry.get('client') is client:
        client_manager.remove_client(sid)


# Socket.IO synthesis request event
@sio.on('synthesize')
def handle_synthesize(sid, data):
    """Synthesize ``data['input']`` with voice ``data['voice']`` for session ``sid``.

    Audio is streamed back by the session's TTSClient ('audio_start' /
    'audio_chunk' / 'audio_end'). Failures detected here are reported as
    'synthesis_error'.
    """
    if not isinstance(data, dict):
        print(f"请求格式错误 (sid: {sid})")
        sio.emit('synthesis_error', {'message': '请求格式错误'}, room=sid)
        return

    input_text = data.get('input', '')
    # `or` also covers an explicit null / empty voice sent by the browser.
    voice = data.get('voice') or 'longxiaochun_v2'

    if not isinstance(input_text, str) or not input_text.strip():
        print(f"收到空文本，忽略 (sid: {sid})")
        sio.emit('synthesis_error', {'message': '输入文本不能为空'}, room=sid)
        return

    print(f"收到合成请求 (sid: {sid}): {input_text[:20]}... 音色: {voice}")

    client = None
    try:
        client = client_manager.create_client(sid, voice)
        if not client:
            sio.emit('synthesis_error', {'message': '创建合成客户端失败'}, room=sid)
            return

        # Wait up to 10 s for task-started. The event is also set when the
        # websocket closes, so task_started must be checked as well.
        if not client.task_started_event.wait(10):
            print(f"任务启动超时 (sid: {sid})")
            sio.emit('synthesis_error', {'message': '任务启动超时'}, room=sid)
            _discard_client(sid, client)  # don't leak the half-open websocket
            return
        if not client.task_started:
            print(f"任务启动失败 (sid: {sid})")
            sio.emit('synthesis_error', {'message': '任务启动失败'}, room=sid)
            return

        # Send the text sentence by sentence; runs of delimiters stay with the
        # preceding sentence.
        for fragment in _split_sentences(input_text):
            if not fragment.strip():
                continue  # whitespace-only pieces have nothing to speak
            if not client.send_continue_task(fragment):
                sio.emit('synthesis_error', {'message': '发送文本失败'}, room=sid)
                return

        if not client.send_finish_task():
            sio.emit('synthesis_error', {'message': '发送结束指令失败'}, room=sid)
            return
    except Exception as e:
        error_msg = f"处理请求失败: {str(e)}"
        print(f"{error_msg} (sid: {sid})\n{traceback.format_exc()}")
        sio.emit('synthesis_error', {'message': error_msg}, room=sid)
        if client is not None:
            _discard_client(sid, client)

if __name__ == '__main__':
    # app.run() starts Flask's built-in development server; host '0.0.0.0'
    # binds every network interface.
    app.run(port=9000, host='0.0.0.0')
