from flask import Flask, render_template
import os
import uuid
import json
import socketio
import websocket
import threading
from collections import defaultdict
import traceback

app = Flask(__name__)
# NOTE(review): cors_allowed_origins='*' accepts Socket.IO connections from any
# origin; restrict it to the page's own origin before exposing this publicly.
sio = socketio.Server(async_mode='threading', cors_allowed_origins='*')
# WSGIApp answers Socket.IO traffic and forwards every other request to Flask.
# (socketio.Middleware is the deprecated former name of this class.)
app.wsgi_app = socketio.WSGIApp(sio, app.wsgi_app)

class ClientManager:
    """Thread-safe registry of the TTS client owned by each Socket.IO session.

    ``clients`` maps a Socket.IO ``sid`` to ``{'client': TTSClient, 'voice': str}``.
    Only dictionary bookkeeping happens while ``lock`` is held. Starting and
    closing upstream WebSockets happens outside it, so that slow closes do not
    block other sessions. It also means synchronous error callbacks
    (``TTSClient.send_error`` -> ``remove_client``) cannot deadlock on the
    non-reentrant ``threading.Lock``.
    """

    DASHSCOPE_WS_URI = "wss://dashscope.aliyuncs.com/api-ws/v1/inference/"

    def __init__(self):
        self.clients = defaultdict(dict)  # sid -> {'client': TTSClient, 'voice': str}
        self.lock = threading.Lock()

    @staticmethod
    def _close_quietly(client):
        """Best-effort close of *client*'s WebSocket; failures are logged, not raised."""
        try:
            client.close()
        except Exception as e:
            print(f"Failed to close TTS client: {e}")

    def create_client(self, sid, voice):
        """Start a new TTSClient for *sid* and replace (and close) any previous one.

        Returns the new client, or ``None`` if the API key is missing or the
        client could not be created/started.
        """
        api_key = os.getenv("DASHSCOPE_API_KEY")
        if not api_key:
            print("Failed to create TTS client: DASHSCOPE_API_KEY is not set")
            return None

        try:
            client = TTSClient(api_key, self.DASHSCOPE_WS_URI, voice=voice, sid=sid)
            # run() only spawns the connection thread, so this returns promptly.
            client.run()
        except Exception as e:
            print(f"Failed to create TTS client: {e}")
            return None

        with self.lock:
            previous = self.clients.pop(sid, None)
            self.clients[sid] = {'client': client, 'voice': voice}

        # The replaced client's own on_close uses identity-checked removal,
        # so closing it cannot evict the client registered above.
        if previous is not None and previous.get('client') is not None:
            self._close_quietly(previous['client'])
        return client

    def get_client(self, sid):
        """Return the registry entry dict for *sid*, or ``None`` if there is none."""
        with self.lock:
            return self.clients.get(sid, None)

    def remove_client(self, sid, client=None):
        """Unregister and close the client for *sid*.

        When *client* is given, the entry is removed only if it still belongs
        to that client. A stale client's callbacks therefore cannot tear down
        a newer client created for the same session.
        """
        with self.lock:
            entry = self.clients.get(sid)
            if entry is None:
                return
            if client is not None and entry.get('client') is not client:
                return  # a newer client now owns this sid
            del self.clients[sid]

        owned = entry.get('client')
        if owned is not None:
            self._close_quietly(owned)

# Process-wide registry shared by the Socket.IO handlers and TTSClient callbacks.
client_manager = ClientManager()

class TTSClient:
    """One DashScope CosyVoice streaming-synthesis task bound to a Socket.IO session.

    ``run()`` opens the upstream WebSocket on a daemon thread, and ``on_open``
    sends ``run-task``. After the service replies ``task-started``, the caller
    streams text with ``send_continue_task`` and ends the task with
    ``send_finish_task``. Binary frames from the service are forwarded to the
    session (room ``sid``) as ``audio_chunk`` events. The ``on_*`` methods are
    websocket-client callbacks and run on the connection thread.
    """

    def __init__(self, api_key, uri, voice="longanyang", sid=None):
        self.api_key = api_key
        self.uri = uri
        self.voice = voice
        self.sid = sid  # Socket.IO session id; events are emitted to this room
        self.task_id = str(uuid.uuid4())
        self.ws = None  # websocket.WebSocketApp, created by run()
        self.task_started = False
        self.task_finished = False
        # Set on task-started, and also on failure/close so waiters are never
        # stuck for the full timeout; inspect task_started to tell them apart.
        self.task_started_event = threading.Event()

    def _build_command(self, action, payload):
        """Return the duplex-streaming command envelope for *action* on this task."""
        return {
            "header": {
                "action": action,
                "task_id": self.task_id,
                "streaming": "duplex"
            },
            "payload": payload
        }

    def on_open(self, ws):
        """Send the run-task command as soon as the upstream connection opens."""
        try:
            run_task_cmd = self._build_command("run-task", {
                "task_group": "audio",
                "task": "tts",
                "function": "SpeechSynthesizer",
                "model": "cosyvoice-v3-flash",
                "parameters": {
                    "text_type": "PlainText",
                    "voice": self.voice,
                    "format": "mp3",
                    "sample_rate": 22050,
                    "volume": 50,
                    "rate": 1,
                    "pitch": 1
                },
                "input": {}
            })
            ws.send(json.dumps(run_task_cmd))
            print(f"Sent run-task command (sid: {self.sid})")
        except Exception as e:
            print(f"Failed to send run-task command: {e}")
            self.send_error(f"Failed to send command: {e}")

    def on_message(self, ws, message):
        """Handle a service message: JSON text is a task event, anything else is audio."""
        try:
            if isinstance(message, str):
                header = json.loads(message).get("header", {})
                event = header.get("event", "")

                if event == "task-started":
                    self.task_started = True
                    self.task_started_event.set()
                    sio.emit('audio_start', room=self.sid)  # Notify client to start receiving
                elif event == "task-finished":
                    self._end_task(ws)
                    print(f"Task completed (sid: {self.sid})")
                elif event == "task-failed":
                    # Report the service's reason instead of ending like a success.
                    reason = header.get("error_message") or "unknown error"
                    code = header.get("error_code")
                    if code:
                        reason = f"{reason} ({code})"
                    print(f"Task failed (sid: {self.sid}): {reason}")
                    sio.emit('synthesis_error',
                             {'message': f"Synthesis task failed: {reason}"},
                             room=self.sid)
                    self._end_task(ws)
            else:
                # Binary frame: forward the audio bytes to this session only.
                sio.emit('audio_chunk', {'data': message}, room=self.sid)
        except Exception as e:
            print(f"Failed to process message: {e}")
            self.send_error(f"Failed to process message: {e}")

    def _end_task(self, ws):
        """Mark the task done, tell the browser audio has ended, and close the socket."""
        self.task_finished = True
        self.task_started_event.set()  # release a waiter if the task failed before starting
        sio.emit('audio_end', room=self.sid)
        self.close(ws)

    def on_error(self, ws, error):
        """Log an upstream WebSocket error and report it to the session."""
        print(f"WebSocket error (sid: {self.sid}): {error}")
        self.send_error(f"WebSocket error: {error}")

    def on_close(self, ws, close_status_code, close_msg):
        """Unregister this client (only if still current) and release any waiter."""
        print(f"WebSocket closed (sid: {self.sid}): {close_msg} ({close_status_code})")
        client_manager.remove_client(self.sid, self)
        self.task_started_event.set()  # Ensure waiting threads are not blocked

    def send_continue_task(self, text):
        """Send one text fragment for synthesis; return True if it was sent."""
        if not self.ws:
            print("WebSocket not connected, cannot send data")
            return False

        try:
            cmd = self._build_command("continue-task", {"input": {"text": text}})
            self.ws.send(json.dumps(cmd))
            return True
        except Exception as e:
            print(f"Failed to send continue-task: {e}")
            self.send_error(f"Failed to send text: {e}")
            return False

    def send_finish_task(self):
        """Tell the service no more text is coming; return True if it was sent."""
        if not self.ws:
            print("WebSocket not connected, cannot send finish command")
            return False

        try:
            cmd = self._build_command("finish-task", {"input": {}})
            self.ws.send(json.dumps(cmd))
            print(f"Sent finish-task command (sid: {self.sid})")
            return True
        except Exception as e:
            print(f"Failed to send finish-task: {e}")
            self.send_error(f"Failed to send finish command: {e}")
            return False

    def close(self, ws=None):
        """Close *ws* (defaults to this client's own socket); best-effort.

        ``ws`` is optional so owners can simply call ``client.close()``.
        WebSocketApp.close() is safe to call in any state, so no connected-state
        guard is needed. That lets half-open (still connecting) sockets be
        torn down too.
        """
        if ws is None:
            ws = self.ws
        if ws is None:
            return
        try:
            ws.close()
        except Exception as e:
            print(f"Failed to close WebSocket (sid: {self.sid}): {e}")

    def run(self):
        """Create the upstream WebSocketApp and run it on a daemon thread."""
        try:
            header = {"Authorization": f"bearer {self.api_key}", "X-DashScope-DataInspection": "enable"}
            self.ws = websocket.WebSocketApp(
                self.uri,
                header=header,
                on_open=self.on_open,
                on_message=self.on_message,
                on_error=self.on_error,
                on_close=self.on_close
            )
            threading.Thread(target=self.ws.run_forever, daemon=True).start()
        except Exception as e:
            print(f"Failed to start WebSocket: {e}")
            self.send_error(f"Connection failed: {e}")

    def send_error(self, message):
        """Report *message* to the session, then release and unregister this client."""
        try:
            sio.emit('synthesis_error', {'message': message}, room=self.sid)
        except Exception as e:
            print(f"Failed to emit synthesis_error (sid: {self.sid}): {e}")
        finally:
            self.task_started_event.set()
            client_manager.remove_client(self.sid, self)

@app.route('/')
def index():
    """Serve the browser page that drives the Socket.IO synthesis client."""
    page = 'my_cosyvoice_client.html'
    return render_template(page)

# Socket.IO connect event
@sio.on('connect')
def handle_connect(sid, environ):
    """Log a newly connected Socket.IO session; no per-session state is created here."""
    print("Client connected: {}".format(sid))

# Socket.IO disconnect event
@sio.on('disconnect')
def handle_disconnect(sid):
    """Log the disconnect and tear down any TTS client owned by this session."""
    print("Client disconnected: {}".format(sid))
    client_manager.remove_client(sid)

# Sentence terminators used to chunk text before it is streamed to the TTS
# service: ASCII '.', '?', '!', their CJK forms '。', '？', '！', and newline.
_SENTENCE_DELIMITERS = frozenset('.?!。？！\n')


def _split_sentences(text, delimiters=_SENTENCE_DELIMITERS):
    """Split *text* into sentence fragments.

    A fragment ends after a run of one or more delimiter characters, so
    consecutive delimiters ("?!", "...") stay attached to the preceding
    sentence. Trailing text without a delimiter becomes the final fragment.
    Concatenating the result reproduces *text* exactly; "" yields [].
    """
    fragments = []
    start = 0
    i = 0
    length = len(text)
    while i < length:
        if text[i] in delimiters:
            # Consume the whole delimiter run so it belongs to this sentence.
            while i < length and text[i] in delimiters:
                i += 1
            fragments.append(text[start:i])
            start = i
        else:
            i += 1
    if start < length:
        fragments.append(text[start:])
    return fragments


def _discard_client(sid, client):
    """Close *client* and unregister it, unless a newer client already owns *sid*."""
    try:
        client.close(client.ws)
    except Exception as e:
        print(f"Failed to close TTS client (sid: {sid}): {e}")
    entry = client_manager.get_client(sid)
    if entry and entry.get('client') is client:
        client_manager.remove_client(sid)


# Socket.IO synthesis request event
@sio.on('synthesize')
def handle_synthesize(sid, data):
    """Synthesize ``data['input']`` with ``data['voice']`` and stream audio to *sid*.

    Creates (or replaces) the session's TTS client and waits up to 10 s for
    the upstream task to start. It then sends the text sentence by sentence
    and asks the service to finish the task. Audio and completion events are
    emitted by the client's callbacks. Every failure here is reported to the
    session as a ``synthesis_error`` event, and an abandoned client is closed.
    """
    if not isinstance(data, dict):
        sio.emit('synthesis_error', {'message': 'Invalid request payload'}, room=sid)
        return

    input_text = data.get('input', '')
    voice = data.get('voice') or 'longxiaochun_v2'

    if not isinstance(input_text, str) or not input_text.strip():
        print(f"Received empty text, ignoring (sid: {sid})")
        sio.emit('synthesis_error', {'message': 'Input text cannot be empty'}, room=sid)
        return

    print(f"Received synthesis request (sid: {sid}): {input_text[:20]}... Voice: {voice}")

    try:
        client = client_manager.create_client(sid, voice)
        if not client:
            sio.emit('synthesis_error', {'message': 'Failed to create synthesis client'}, room=sid)
            return

        # Wait for task to start (maximum 10 seconds)
        if not client.task_started_event.wait(10):
            print(f"Task start timeout (sid: {sid})")
            sio.emit('synthesis_error', {'message': 'Task start timeout'}, room=sid)
            _discard_client(sid, client)
            return

        # The event is also set when the connection closes or errors, so
        # confirm the service actually acknowledged the task.
        if not client.task_started:
            print(f"Task failed to start (sid: {sid})")
            sio.emit('synthesis_error', {'message': 'Task failed to start'}, room=sid)
            _discard_client(sid, client)
            return

        for fragment in _split_sentences(input_text):
            if not client.send_continue_task(fragment):
                sio.emit('synthesis_error', {'message': 'Failed to send text'}, room=sid)
                _discard_client(sid, client)
                return

        # Send finish task command
        if not client.send_finish_task():
            sio.emit('synthesis_error', {'message': 'Failed to send finish command'}, room=sid)
            _discard_client(sid, client)
            return
    except Exception as e:
        error_msg = f"Request processing failed: {str(e)}"
        print(f"{error_msg} (sid: {sid})\n{traceback.format_exc()}")
        sio.emit('synthesis_error', {'message': error_msg}, room=sid)

if __name__ == '__main__':
    # Run Flask's built-in server on every interface, port 9000.
    listen_host, listen_port = '0.0.0.0', 9000
    app.run(host=listen_host, port=listen_port)
