import json
import ast
import math
import os

import requests
import xml.etree.ElementTree as ET

from openai import OpenAI
from io import BytesIO
from PIL import Image, ImageDraw, ImageFont
from PIL import ImageColor



additional_colors = [colorname for (colorname, colorcode) in ImageColor.colormap.items()]


def smart_resize_qwen2_5_vl(img,min_pixels=32 * 32 * 4,max_pixels=2560 * 32 * 32):

    """
    :param img:
    :param min_pixels: 图像的Token下限：一般为4个token的像素大小，
    :param max_pixels: 图像的Token上限
          关闭vl_high_resolution_images时，
                qwen-vl-max-0813以后和qwen-vl-plus-0815以后更新的模型：将宽高都调整为32的整数倍，图像的Token上限为2560
                其余模型：图像的Token上限为1280
          开启vl_high_resolution_images时，所有模型的图像的Token上限为16384
    return: h_bar, w_bar：缩放后的图像长宽
    """

    # 获取图片的原始尺寸
    img = img
    width, height = img.size

    # qwen-vl-max-0813以后和qwen-vl-plus-0815以后更新的模型：将宽高都调整为32的整数倍
    # 其余模型：将宽高都调整为28的整数倍
    h_bar = round(height / 32) * 32
    w_bar = round(width / 32) * 32

    # 对图像进行缩放处理，调整像素的总数在范围[min_pixels,max_pixels]内
    if h_bar * w_bar > max_pixels:
        # 计算缩放因子beta，使得缩放后的图像总像素数不超过max_pixels
        beta = math.sqrt((height * width) / max_pixels)
        # 重新计算调整后的宽高，对于qwen-vl-max-0815以后及qwen-vl-plus-0815以后更新的模型：将宽高都调整为32的整数倍，对于其他模型，确保为28的整数倍
        h_bar = math.floor(height / beta / 32) * 32
        w_bar = math.floor(width / beta / 32) * 32
    elif h_bar * w_bar < min_pixels:
        # 计算缩放因子beta，使得缩放后的图像总像素数不低于min_pixels
        beta = math.sqrt(min_pixels / (height * width))
        # 重新计算调整后的高度，对于qwen-vl-max-0815以后及qwen-vl-plus-0815以后更新的模型，确保为32的整数倍，对于其他模型，确保为28的整数倍
        h_bar = math.ceil(height * beta / 32) * 32
        w_bar = math.ceil(width * beta / 32) * 32
    return h_bar, w_bar


def decode_json_points(text: str):
    """解析文本中的坐标点"""
    try:
        # 去除JSON代码块标记
        if "```json" in text:
            text = text.split("```json")[1].split("```")[0]

        # 解析JSON
        data = json.loads(text)
        points = []
        labels = []

        # 遍历所有点并提取坐标与标签
        for item in data:
            if "point_2d" in item:
                x, y = item["point_2d"]
                points.append([x, y])

                # 获取label，如果没有则使用默认值
                label = item.get("label", f"point_{len(points)}")
                labels.append(label)

        return points, labels

    except Exception as e:
        print(f"Error: {e}")
        return [], []


def plot_bounding_boxes(im, bounding_boxes):
    """
        在图像上绘制边界框，并标注名称
    Args:
        img_path: 图像的路径
        bounding_boxes: 包含对象名称的边界框列表，并且位置为标准化的[y1 x1 y2 x2]格式。
    """

    # 加载图像并创建绘图对象
    img = im
    width, height = img.size

    draw = ImageDraw.Draw(img)

    # 定义颜色列表用于区分不同对象
    colors = [
                 'red',
                 'green',
                 'blue',
                 'yellow',
                 'orange',
                 'pink',
                 'purple',
                 'brown',
                 'gray',
                 'beige',
                 'turquoise',
                 'cyan',
                 'magenta',
                 'lime',
                 'navy',
                 'maroon',
                 'teal',
                 'olive',
                 'coral',
                 'lavender',
                 'violet',
                 'gold',
                 'silver',
             ] + additional_colors

    # 解析边界框信息
    bounding_boxes = parse_json(bounding_boxes)

    font = ImageFont.truetype("NotoSansCJK-Regular.ttc", size=25)

    try:
        json_output = ast.literal_eval(bounding_boxes)
    except Exception as e:
        end_idx = bounding_boxes.rfind('"}') + len('"}')
        truncated_text = bounding_boxes[:end_idx] + "]"
        json_output = ast.literal_eval(truncated_text)

    if not isinstance(json_output, list):
        json_output = [json_output]

    # 绘制每个边界框
    for i, bounding_box in enumerate(json_output):

        color = colors[i % len(colors)]

        # 将坐标映射到原图上
        input_height,input_width = smart_resize_qwen2_5_vl(im, min_pixels=32 * 32 * 4,max_pixels=2560 * 32 * 32)
        abs_y1 = int(bounding_box["bbox_2d"][1] / input_height * height)
        abs_x1 = int(bounding_box["bbox_2d"][0] / input_width * width)
        abs_y2 = int(bounding_box["bbox_2d"][3] / input_height * height)
        abs_x2 = int(bounding_box["bbox_2d"][2] / input_width * width)

        if abs_x1 > abs_x2:
            abs_x1, abs_x2 = abs_x2, abs_x1

        if abs_y1 > abs_y2:
            abs_y1, abs_y2 = abs_y2, abs_y1

        # 绘制矩形框
        draw.rectangle(
            ((abs_x1, abs_y1), (abs_x2, abs_y2)), outline=color, width=3
        )

        # 添加标签文字
        if "label" in bounding_box:
            draw.text((abs_x1 + 8, abs_y1 + 6), bounding_box["label"], fill=color, font=font)

    # 显示最终图像
    img.show()


def plot_points(im, text, input_width, input_height):
    img = im
    width, height = img.size
    draw = ImageDraw.Draw(img)
    colors = [
                 'red', 'green', 'blue', 'yellow', 'orange', 'pink', 'purple', 'brown', 'gray',
                 'beige', 'turquoise', 'cyan', 'magenta', 'lime', 'navy', 'maroon', 'teal',
                 'olive', 'coral', 'lavender', 'violet', 'gold', 'silver',
             ] + additional_colors
    xml_text = text.replace('```xml', '')
    xml_text = xml_text.replace('```', '')
    data = decode_xml_points(xml_text)
    if data is None:
        img.show()
        return
    points = data['points']
    description = data['phrase']

    font = ImageFont.truetype("/Library/Fonts/NotoSansCJK-Regular.ttf", size=14)

    for i, point in enumerate(points):
        color = colors[i % len(colors)]
        abs_x1 = int(point[0]) / input_width * width
        abs_y1 = int(point[1]) / input_height * height
        radius = 2
        draw.ellipse([(abs_x1 - radius, abs_y1 - radius), (abs_x1 + radius, abs_y1 + radius)], fill=color)
        draw.text((abs_x1 + 8, abs_y1 + 6), description, fill=color, font=font)

    img.show()


# @title 解析JSON输出
def parse_json(json_output):
    # 移除Markdown代码块标记
    lines = json_output.splitlines()
    for i, line in enumerate(lines):
        if line == "```json":
            json_output = "\n".join(lines[i + 1:])  # 删除 "```json"之前的所有内容
            json_output = json_output.split("```")[0]  # 删除 "```"之后的所有内容
            break  # 找到"```json"后退出循环
    return json_output


# 解析 XML 格式的文本，从中提取二维坐标点及相关信息。
def decode_xml_points(text):
    try:
        root = ET.fromstring(text)
        num_points = (len(root.attrib) - 1) // 2
        points = []
        for i in range(num_points):
            x = root.attrib.get(f'x{i+1}')
            y = root.attrib.get(f'y{i+1}')
            points.append([x, y])
        alt = root.attrib.get('alt')
        phrase = root.text.strip() if root.text else None
        return {
            "points": points,
            "alt": alt,
            "phrase": phrase
        }
    except Exception as e:
        print(e)
        return None


# 调用Qwen2.5-VL的 API
def inference_with_api(prompt, sys_prompt="You are a helpful assistant.", model_id="qwen-vl-max",
                       min_pixels=4 * 32 * 32, max_pixels=2560 * 32 * 32):
    client = OpenAI(
        # 若没有配置环境变量，请用阿里云百炼API Key将下行替换为：api_key="sk-xxx",
        api_key=os.getenv("DASHSCOPE_API_KEY"),
        base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
    )

    messages = [
        {
            "role": "system",
            "content": [{"type": "text", "text": sys_prompt}]},
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "min_pixels": min_pixels,
                    "max_pixels": max_pixels,
                    "image_url": {"url": "https://help-static-aliyun-doc.aliyuncs.com/file-manage-files/zh-CN/20251031/dhsvgy/img_2.png"},
                },
                {"type": "text", "text": prompt},
            ],
        }
    ]
    completion = client.chat.completions.create(
        model=model_id,
        messages=messages,

    )
    return completion.choices[0].message.content



def run_object_detection(img_url,model_response):
    # 从URL下载图像
    response = requests.get(img_url)
    response.raise_for_status()
    image = Image.open(BytesIO(response.content))
    # 调用函数绘制边界框
    plot_bounding_boxes(image, model_response)


if __name__ == "__main__":
    url = "https://help-static-aliyun-doc.aliyuncs.com/file-manage-files/zh-CN/20251031/dhsvgy/img_2.png"
    prompt = """识别图片中的所有食物，并以JSON格式输出其bbox的坐标及其中文名称"""

    # 注意：inference_with_api 函数中的 min_pixels 和 max_pixels 参数值必须与 smart_resize_qwen2_5_vl 函数中的对应参数保持一致，以确保图像处理尺寸的统一性。
    response = inference_with_api(url,  prompt,min_pixels=4 * 32 * 32, max_pixels=2560 * 32 * 32)

    # 调用run_object_detection函数，传入图像URL和模型响应数据
    run_object_detection(url, response)



