import base64
import json
import uuid
from io import BytesIO

from fastapi import FastAPI, File, Form, UploadFile
from fastapi.responses import JSONResponse
from sse_starlette import EventSourceResponse

import config
from engine import Engine
from utils import logging

# Single module-level Engine instance, built at import time from the values in
# `config`. The request handler below uses it for its ASR, LLM and TTS calls.
engine = Engine(
    asr_base_url=config.ASR_BASE_URL,
    asr_api_key=config.ASR_API_KEY,
    llm_base_url=config.LLM_BASE_URL,
    llm_api_key=config.LLM_API_KEY,
    llm_guide=config.LLM_GUIDE,
    tts_base_url=config.TTS_BASE_URL,
    tts_api_key=config.TTS_API_KEY,
    tts_reference_audio_id=config.TTS_REFERENCE_AUDIO_ID,
)

# FastAPI application object; routes are registered on it via decorators.
app = FastAPI()


def _bad_request(request_uuid, message):
    """Build the 400 JSON error response returned for rejected requests."""
    return JSONResponse(
        status_code=400,
        content={
            "message": message,
            "request_id": request_uuid,
        },
    )


def _safe_name_part(value):
    """Return ``value`` as text that is safe to embed in a local file name.

    Path separators and NUL bytes are replaced with underscores so that
    client-supplied text cannot steer a debug dump outside the working
    directory or make ``open()`` reject the path.
    """
    return str(value).replace("/", "_").replace("\\", "_").replace("\x00", "_")


@app.post("/process_alt")
async def process_alt(file: UploadFile = File(...), metadata: str = Form()):
    """Run an uploaded audio clip through ASR -> LLM -> TTS and stream the audio.

    Args:
        file: Uploaded audio; its file name must end in ``.wav`` or ``.mp3``.
        metadata: JSON object encoded as a form field. ``format`` and
            ``sample_rate`` are required and are forwarded to the TTS call;
            the optional ``debug`` flag dumps the input and output audio to
            files in the current working directory.

    Returns:
        An ``EventSourceResponse`` whose events are JSON documents, one per
        TTS chunk, or a 400 ``JSONResponse`` when the request is rejected.
    """
    request_uuid = str(uuid.uuid4())

    # Validate everything up front: once the SSE response has started there is
    # no way to turn a bad request into a proper 4xx status.
    try:
        metadata = json.loads(metadata)
    except json.JSONDecodeError:
        return _bad_request(request_uuid, "metadata must be valid JSON")
    if not isinstance(metadata, dict):
        return _bad_request(request_uuid, "metadata must be a JSON object")

    missing = [key for key in ("format", "sample_rate") if key not in metadata]
    if missing:
        return _bad_request(
            request_uuid,
            f"metadata is missing required field(s): {', '.join(missing)}",
        )

    # NOTE(review): `debug` is client-controlled, so any caller can trigger
    # disk writes on the server -- confirm this is intended outside development.
    debug = metadata.get("debug", False)
    audio_format = metadata["format"]
    sample_rate = metadata["sample_rate"]

    # parameters check
    filename = file.filename or ""  # multipart parts may omit the file name
    _, dot, suffix = filename.rpartition(".")
    if not dot or suffix not in ("wav", "mp3"):
        return _bad_request(
            request_uuid,
            "file name suffix must be one of the following: .wav .mp3",
        )

    audio_buffer = BytesIO(await file.read())
    if debug:
        input_audio_filename = f"{request_uuid}_input_{_safe_name_part(filename)}"
        with open(input_audio_filename, "wb") as f:
            f.write(audio_buffer.getvalue())

    # asr
    text = await engine.call_asr(audio_buffer)
    logging.info(f"{request_uuid=} asr_result: {text}")

    # llm: the complete reply is collected before TTS starts
    llm_text = "".join([token async for token in engine.call_llm(text)])
    logging.info(f"{request_uuid=} llm_result: {llm_text}")

    async def stream_handler():
        """Yield one JSON event per TTS chunk, then write the debug dump."""
        # Chunks are retained only when a debug dump was requested, so a
        # normal request does not hold the whole audio stream in memory.
        debug_chunks = []
        async for chunk in engine.call_tts(llm_text, audio_format, sample_rate):
            if debug:
                debug_chunks.append(chunk)

            ret_data = {
                "output": {
                    "finish_reason": None,
                    "audio": {
                        "data": chunk,
                    },
                },
                "error": None,
                "request_id": request_uuid,
            }
            yield json.dumps(ret_data, ensure_ascii=False)

        if debug:
            # Chunks are presumably base64 text (they are decoded here before
            # being written) -- confirm against Engine.call_tts.
            output_audio_filename = (
                f"{request_uuid}_output.{_safe_name_part(audio_format)}"
            )
            with open(output_audio_filename, "wb") as f:
                for chunk in debug_chunks:
                    f.write(base64.b64decode(chunk))

    return EventSourceResponse(stream_handler())
