import argparse
import asyncio
import base64
import tempfile
from io import BytesIO

import gradio as gr
import numpy as np
import soundfile as sf

import config
from engine import Engine
from utils import logging

# Single module-level Engine, built at import time from the values in `config`
# and used by the ASR / LLM / TTS callbacks defined below.
_ENGINE_SETTINGS = {
    "asr_base_url": config.ASR_BASE_URL,
    "asr_api_key": config.ASR_API_KEY,
    "llm_base_url": config.LLM_BASE_URL,
    "llm_api_key": config.LLM_API_KEY,
    "llm_guide": config.LLM_GUIDE,
    "tts_base_url": config.TTS_BASE_URL,
    "tts_api_key": config.TTS_API_KEY,
    "tts_reference_audio_id": config.TTS_REFERENCE_AUDIO_ID,
}
engine = Engine(**_ENGINE_SETTINGS)


def parse_args(argv=None):
    """Parse the command-line options of the demo server.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            what a plain ``parse_args()`` call has always done.

    Returns:
        argparse.Namespace with two attributes:
            ssl (bool): True only when ``--ssl`` was passed.
            port (int): value of ``--port``; 7860 when omitted.
    """
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("--ssl", action="store_true", help="enable SSL")
    parser.add_argument(
        "--port", type=int, default=7860, help="Port to run the server on"
    )
    return parser.parse_args(argv)


def save_audio(audio_input):
    """Encode a recorded clip as an in-memory WAV file.

    Args:
        audio_input: Expected to be a ``(sample_rate, samples)`` tuple, i.e.
            the value Gradio hands over for ``type="numpy"`` audio, where
            ``samples`` is a numpy array. Any non-tuple value (for example
            ``None``) is treated as "nothing was recorded".

    Returns:
        A ``BytesIO`` containing the WAV data, rewound to offset 0, or
        ``None`` when ``audio_input`` is not a tuple.
    """
    if not isinstance(audio_input, tuple):
        return None

    sample_rate, audio_data = audio_input

    # Floating-point samples are assumed to be normalised to [-1, 1].
    # Convert them to 16-bit PCM ourselves: float WAV files are not supported
    # by every player, and clipping first keeps out-of-range values from
    # wrapping around in the integer cast. Every float dtype is handled, not
    # just float32.
    if np.issubdtype(audio_data.dtype, np.floating):
        audio_data = (np.clip(audio_data, -1, 1) * 32767).astype(np.int16)

    buf = BytesIO()
    sf.write(buf, audio_data, sample_rate, format="WAV")
    # Writing leaves the stream position past the data; rewind so a consumer
    # that simply calls .read() gets the whole file instead of nothing.
    buf.seek(0)
    return buf


def call_asr(audio):
    """Gradio callback: transcribe the clip that was just recorded.

    The clip is converted to an in-memory WAV via ``save_audio`` and sent to
    ``engine.call_asr``. If no WAV could be produced, a fixed fallback message
    is returned instead of a transcript.

    Args:
        audio: Value of the microphone component passed in by Gradio.

    Returns:
        Tuple ``(text, update)``: the transcript (or the fallback message)
        and a ``gr.update(value=None)`` that resets the recorder component.
    """
    logging.info("-" * 50)

    wav_buffer = save_audio(audio)
    if wav_buffer is not None:
        # Synchronous callback, so drive the coroutine to completion here.
        asr_text = asyncio.run(engine.call_asr(wav_buffer))
        logging.info(f"{asr_text=}")
    else:
        asr_text = "听不清, 请重讲一遍"

    return asr_text, gr.update(value=None)


async def call_llm(text):
    """Gradio callback: stream the LLM reply, then synthesise it to speech.

    This is an async generator. While ``engine.call_llm`` produces tokens it
    yields ``(partial_reply, None)`` so the textbox fills in progressively.
    Once the reply is complete, the chunks from ``engine.call_tts`` are
    base64-decoded into an audio file and ``(full_reply, file_path)`` is
    yielded last.

    Args:
        text: Prompt forwarded to ``engine.call_llm``.

    Yields:
        Tuple ``(llm_text, audio_path)``; ``audio_path`` is ``None`` for
        every yield except the final one.
    """
    llm_text = ""
    async for token in engine.call_llm(text):
        llm_text += token
        # Brief pause between updates -- presumably to pace UI refreshes.
        await asyncio.sleep(0.05)
        yield llm_text, None

    logging.info(f"{llm_text=}")

    # Each call gets its own file. A single shared "output.mp3" would be
    # overwritten by any other request running at the same time, and it also
    # required the working directory to be writable.
    # delete=False: the file has to outlive this function so Gradio can read
    # it; nothing here removes it afterwards.
    with tempfile.NamedTemporaryFile(
        prefix="tts_", suffix=".mp3", delete=False
    ) as f:
        async for chunk in engine.call_tts(llm_text):
            f.write(base64.b64decode(chunk))
        filename = f.name
    yield llm_text, filename


def run(launch_args):
    """Build the Gradio interface and start the web server.

    The page holds a microphone recorder, a textbox for the ASR result, a
    textbox for the LLM result and an autoplaying audio player. Finishing a
    recording triggers ``call_asr``; any change of the ASR textbox triggers
    ``call_llm``.

    Args:
        launch_args: Mapping of extra keyword arguments passed on to
            ``launch`` (the server is always bound to ``0.0.0.0``).
    """
    with gr.Blocks(title="Record & Save WAV") as app:
        mic_recorder = gr.Audio(
            label="Mic Recording",
            format="wav",
            streaming=False,
            sources=["microphone"],
        )
        transcript_box = gr.Textbox(label="ASR Result")
        reply_box = gr.Textbox(label="LLM Result", lines=10)
        speech_player = gr.Audio(type="filepath", label="TTS Audio", autoplay=True)

        # Recording finished -> transcribe, and reset the recorder.
        mic_recorder.stop_recording(
            fn=call_asr,
            inputs=mic_recorder,
            outputs=[transcript_box, mic_recorder],
        )

        # New transcript -> stream the reply text, then the spoken reply.
        transcript_box.change(
            fn=call_llm,
            inputs=transcript_box,
            outputs=[reply_box, speech_player],
        )

    app.launch(server_name="0.0.0.0", **launch_args)


if __name__ == "__main__":
    args = parse_args()

    # With --ssl the server uses the key/certificate pair under asset/.
    # Verification is switched off -- presumably because that certificate is
    # a self-signed test one.
    ssl_options = (
        {
            "ssl_keyfile": "asset/test.key",
            "ssl_certfile": "asset/test.crt",
            "ssl_verify": False,
        }
        if args.ssl
        else {}
    )
    launch_args = {"server_port": args.port, **ssl_options}

    run(launch_args)
