import base64
import json
from io import BytesIO

import requests
from openai import OpenAI
from sseclient import SSEClient


class Engine:
    """Bundle the speech-recognition (ASR), chat-LLM and text-to-speech (TTS) calls.

    Each ``call_*`` method talks to one remote service:

    * ``call_asr``  - POSTs an audio buffer and returns the transcript text.
    * ``call_llm``  - streams a chat completion from an OpenAI-compatible API.
    * ``call_tts``  - streams synthesized audio chunks from a TTS service (SSE).

    NOTE(review): all three methods are ``async`` but use synchronous clients
    (``requests`` / ``OpenAI``), so every network wait blocks the event loop.
    Consider an async client or offloading to a thread if concurrency matters.
    """

    def __init__(
        self,
        asr_base_url: str,
        asr_api_key: str,
        llm_base_url: str,
        llm_api_key: str,
        llm_guide: str,
        tts_base_url: str,
        tts_api_key: str,
        tts_reference_audio_id: str,
    ):
        """Store service endpoints and credentials, and pick the LLM model.

        Args:
            asr_base_url: Full URL the ASR request is POSTed to.
            asr_api_key: Value sent verbatim as the ASR ``Authorization`` header.
            llm_base_url: Base URL passed to the ``OpenAI`` client.
            llm_api_key: API key passed to the ``OpenAI`` client.
            llm_guide: System prompt sent with every LLM request.
            tts_base_url: Base URL of the TTS service; ``/api/v1/audio/speech``
                is appended to it.
            tts_api_key: Token sent as ``Bearer <key>`` to the TTS service.
            tts_reference_audio_id: Reference voice ID sent with each TTS request.

        Raises:
            RuntimeError: if the LLM endpoint lists no models.
        """
        # asr
        self.asr_base_url = asr_base_url
        self.asr_api_key = asr_api_key

        # llm
        self._client = OpenAI(
            api_key=llm_api_key,
            base_url=llm_base_url,
        )
        # Network call at construction time: the first model the endpoint
        # lists is used for every completion.
        models = self._client.models.list()
        if not models.data:
            raise RuntimeError(f"LLM endpoint {llm_base_url!r} lists no models")
        self.llm_model = models.data[0].id
        self.llm_guide = llm_guide

        # tts
        self.tts_base_url = tts_base_url
        self.tts_api_key = tts_api_key
        self.tts_reference_audio_id = tts_reference_audio_id

    async def call_asr(self, audio_buffer: BytesIO):
        """Transcribe the audio in *audio_buffer* and return the recognized text.

        The buffer's full contents are base64-encoded and POSTed as the JSON
        body ``{"audio": "<base64>"}``. The transcript is read from the first
        element of the JSON array in the response (``[0]["text"]``).

        Args:
            audio_buffer: In-memory audio. ``getvalue()`` is used, so the
                buffer's current stream position does not matter.

        Returns:
            The ``text`` field of the first result returned by the service.

        Raises:
            requests.HTTPError: if the service answers with a 4xx/5xx status.
        """
        encoded_audio = base64.b64encode(audio_buffer.getvalue()).decode("utf-8")
        request_body = json.dumps({"audio": encoded_audio})

        # Sent as-is (no "Bearer " prefix, unlike the TTS call).
        headers = {"Authorization": self.asr_api_key}
        # NOTE(review): no timeout is passed, so this can wait indefinitely on
        # an unresponsive server; the LLM/TTS calls use 60 s / 10 s.
        resp = requests.post(url=self.asr_base_url, headers=headers, data=request_body)
        # Fail with the HTTP status instead of an opaque KeyError or JSON
        # decode error when an error page comes back.
        resp.raise_for_status()
        # Decode explicitly as UTF-8 rather than trusting a declared charset,
        # since the transcript may be non-ASCII.
        return json.loads(resp.content.decode("utf-8"))[0]["text"]

    async def call_llm(self, text: str):
        """Stream the LLM's reply to *text*, one content fragment at a time.

        ``self.llm_guide`` is sent as the system message and *text* as the user
        message. ``chat_template_kwargs.enable_thinking=False`` is passed through
        ``extra_body``. NOTE(review): that is a server-side extension; confirm the
        backend honours it.

        Yields:
            Non-empty ``str`` content fragments. Chunks with no choices or no
            content (for example the role-only first chunk or the final
            ``finish_reason`` chunk) are skipped instead of yielding ``None``.
        """
        messages = [
            {"role": "system", "content": self.llm_guide},
            {"role": "user", "content": text},
        ]
        stream = self._client.chat.completions.create(
            model=self.llm_model,
            messages=messages,
            max_completion_tokens=100,
            temperature=0.7,
            top_p=0.8,
            presence_penalty=1.5,
            stream=True,
            timeout=60,
            extra_body={"chat_template_kwargs": {"enable_thinking": False}},
        )
        for chunk in stream:
            # Some OpenAI-compatible servers emit chunks with an empty
            # ``choices`` list (e.g. usage-only chunks); indexing would fail.
            if not chunk.choices:
                continue
            token = chunk.choices[0].delta.content
            if token:
                yield token

    async def call_tts(self, text: str, format: str = "mp3", sample_rate: int = 24000):
        """Stream synthesized speech for *text* from the TTS service.

        Args:
            text: Text to synthesize.
            format: Requested audio format, sent as ``output_format``.
            sample_rate: Requested output sample rate, sent as ``sample_rate``.

        Yields:
            The ``output.audio.data`` field of each server-sent event, unchanged.
            It is not decoded here; it is presumably base64-encoded audio, so
            confirm against the TTS API.

        Raises:
            requests.HTTPError: if the service does not answer with HTTP 200.
        """
        url = f"{self.tts_base_url}/api/v1/audio/speech"
        payload = {
            "model": "CosyVoice2-0.5B",
            "input": {
                "mode": "fast_replication",
                "reference_audio_id": self.tts_reference_audio_id,
                "text": text,
                "speed": 1,
                "output_format": format,
                "sample_rate": sample_rate,
                "bit_rate": "48k",
                "volume": 2.0,
                # "Speak in Sichuanese."
                "instruct": "用四川话说",
            },
            "stream": True,
        }

        with requests.post(
            url,
            headers={
                "Authorization": f"Bearer {self.tts_api_key}",
                "Content-Type": "application/json",
            },
            json=payload,
            stream=True,
            timeout=10,
        ) as response:
            # Raise instead of terminating the whole process, so the caller
            # decides how to handle a failed synthesis.
            if response.status_code != 200:
                raise requests.HTTPError(
                    f"TTS request failed with status {response.status_code}: "
                    f"{response.text}",
                    response=response,
                )

            for event in SSEClient(response).events():
                data = json.loads(event.data)
                yield data["output"]["audio"]["data"]
