# (Web-page header residue "Spaces: Running" removed; it is not part of this module.)
from __future__ import annotations

import base64
import json
import logging
import os
import uuid
from io import BytesIO

import requests
from PIL import Image

from ..index_func import *
from ..presets import *
from ..utils import *
from .base_model import BaseLLMModel


class XMChat(BaseLLMModel):
    """Chat model client for the XMChat web endpoint, with optional image input.

    Every request is a JSON POST to ``self.url``. Text questions, uploaded
    images and like/dislike feedback are all sent to that same URL and are
    distinguished only by the keys of the JSON payload.
    """

    # File extensions that ``try_read_image`` accepts as images.
    _IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff")

    # Pillow modes the JPEG writer can encode directly; any other mode
    # (e.g. "RGBA", "P", "LA") is converted to "RGB" before saving.
    _JPEG_MODES = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")

    def __init__(self, api_key, user_name=""):
        """Create a client.

        Args:
            api_key: Sent as ``user_id`` in every chat/image request.
            user_name: Forwarded to the base class as ``user``.
        """
        super().__init__(model_name="xmchat", user=user_name)
        self.api_key = api_key
        self.session_id = None
        self.reset()
        # Base64 text of the most recently accepted image, or None.
        self.image_bytes = None
        # Path of the most recently accepted image, or None.
        self.image_path = None
        self.xm_history = []
        self.url = "https://xmbot.net/web"
        # UUID of the last text message sent; like()/dislike() refer to it.
        self.last_conv_id = None

    def reset(self, remain_system_prompt=False):
        """Start a fresh session id, forget the last message id, reset the base.

        Args:
            remain_system_prompt: Accepted for signature compatibility only.
                NOTE(review): it is not forwarded to ``super().reset()``;
                confirm against the base class whether it should be.

        Returns:
            Whatever the base class ``reset()`` returns.
        """
        self.session_id = str(uuid.uuid4())
        self.last_conv_id = None
        return super().reset()

    def image_to_base64(self, image_path, max_dimension=2048):
        """Load an image, shrink it if needed, and return it as base64 JPEG text.

        Args:
            image_path: Path of the image file to read.
            max_dimension: Upper bound, in pixels, for both width and height.
                Larger images are scaled down with the aspect ratio kept;
                smaller images are never enlarged.

        Returns:
            The JPEG-encoded image as a base64 ``str``.
        """
        # The context manager closes the underlying file even if encoding fails.
        with Image.open(image_path) as source:
            image = source
            width, height = image.size
            # Ratio that brings the longest side down to ``max_dimension``.
            scale_ratio = min(max_dimension / width, max_dimension / height)
            if scale_ratio < 1:
                # Clamp to 1 px: truncation can otherwise give a zero-sized
                # side for extreme aspect ratios, which resize() rejects.
                new_size = (
                    max(1, int(width * scale_ratio)),
                    max(1, int(height * scale_ratio)),
                )
                image = image.resize(new_size, Image.LANCZOS)
            # JPEG has no alpha channel or palette; convert unsupported modes
            # (not only "RGBA") so palette GIF/BMP/PNG inputs do not fail.
            if image.mode not in self._JPEG_MODES:
                image = image.convert("RGB")
            # Encode while the source file is still open (Pillow loads lazily).
            buffer = BytesIO()
            image.save(buffer, format="JPEG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    def try_read_image(self, filepath):
        """Store ``filepath`` as the pending image if its extension is an image.

        On a match, ``image_bytes`` and ``image_path`` are set from the file;
        otherwise both are cleared to ``None``.
        """
        extension = os.path.splitext(filepath)[1].lower()
        if extension in self._IMAGE_EXTENSIONS:
            logging.info(f"读取图片文件: {filepath}")
            self.image_bytes = self.image_to_base64(filepath)
            self.image_path = filepath
        else:
            self.image_bytes = None
            self.image_path = None

    def _send_appraise(self, appraise):
        """POST a feedback verdict (``"good"``/``"bad"``) for the last message."""
        data = {
            "uuid": self.last_conv_id,
            "appraise": appraise,
        }
        requests.post(self.url, json=data)

    def like(self):
        """Send positive feedback for the last message; return a status string."""
        if self.last_conv_id is None:
            return "点赞失败,你还没发送过消息"
        self._send_appraise("good")
        return "👍点赞成功,感谢反馈~"

    def dislike(self):
        """Send negative feedback for the last message; return a status string."""
        if self.last_conv_id is None:
            return "点踩失败,你还没发送过消息"
        self._send_appraise("bad")
        return "👎点踩成功,感谢反馈~"

    def prepare_inputs(self, real_inputs, use_websearch, files, reply_language, chatbot):
        """Pass the user input through unchanged.

        ``use_websearch``, ``files`` and ``reply_language`` are not used here.

        Returns:
            ``(limited_context, fake_inputs, display_append, real_inputs,
            chatbot)`` where ``limited_context`` is ``False``, ``fake_inputs``
            is ``real_inputs`` and ``display_append`` is the empty string.
        """
        fake_inputs = real_inputs
        display_append = ""
        limited_context = False
        return limited_context, fake_inputs, display_append, real_inputs, chatbot

    def handle_file_upload(self, files, chatbot, language):
        """Read uploaded files and, if one is an image, send it to the service.

        Each named file is passed to ``try_read_image`` in order, so the last
        file decides which image (if any) is pending. ``language`` is unused.

        Returns:
            ``(None, chatbot, None)``; ``chatbot`` gains an image entry when
            an image was accepted.
        """
        if files:
            for file in files:
                if file.name:
                    logging.info(f"尝试读取图像: {file.name}")
                    self.try_read_image(file.name)
            if self.image_path is not None:
                chatbot = chatbot + [((self.image_path,), None)]
            if self.image_bytes is not None:
                logging.info("使用图片作为输入")
                # XMChat can only handle one image per conversation round,
                # so start a new session before uploading it.
                self.reset()
                conv_id = str(uuid.uuid4())
                data = {
                    "user_id": self.api_key,
                    "session_id": self.session_id,
                    "uuid": conv_id,
                    "data_type": "imgbase64",
                    "data": self.image_bytes,
                }
                raw_response = requests.post(self.url, json=data)
                # The reply is only logged, so an unparsable body must not
                # abort the upload.
                try:
                    reply = json.loads(raw_response.text)["data"]
                except (ValueError, KeyError, TypeError):
                    logging.warning(f"图片回复解析失败: {raw_response.text}")
                else:
                    logging.info(f"图片回复: {reply}")
        return None, chatbot, None

    def get_answer_at_once(self):
        """Send the latest history entry as a text question and return the reply.

        Returns:
            ``(answer, len(answer))``. When the body is not JSON or lacks a
            usable ``"data"`` field, the raw response text is returned instead.
        """
        question = self.history[-1]["content"]
        conv_id = str(uuid.uuid4())
        self.last_conv_id = conv_id
        data = {
            "user_id": self.api_key,
            "session_id": self.session_id,
            "uuid": conv_id,
            "data_type": "text",
            "data": question,
        }
        # Keep the HTTP response under its own name: the fallback below needs
        # ``.text`` even after the JSON body has been parsed.
        raw_response = requests.post(self.url, json=data)
        try:
            answer = json.loads(raw_response.text)["data"]
            return answer, len(answer)
        except (ValueError, KeyError, TypeError):
            return raw_response.text, len(raw_response.text)