|
from __future__ import annotations |
|
|
|
import base64 |
|
import json |
|
import logging |
|
import os |
|
import uuid |
|
from io import BytesIO |
|
|
|
import requests |
|
from PIL import Image |
|
|
|
from ..index_func import * |
|
from ..presets import * |
|
from ..utils import * |
|
from .base_model import BaseLLMModel |
|
|
|
|
|
class XMChat(BaseLLMModel): |
|
def __init__(self, api_key, user_name=""): |
|
super().__init__(model_name="xmchat", user=user_name) |
|
self.api_key = api_key |
|
self.session_id = None |
|
self.reset() |
|
self.image_bytes = None |
|
self.image_path = None |
|
self.xm_history = [] |
|
self.url = "https://xmbot.net/web" |
|
self.last_conv_id = None |
|
|
|
def reset(self, remain_system_prompt=False): |
|
self.session_id = str(uuid.uuid4()) |
|
self.last_conv_id = None |
|
return super().reset() |
|
|
|
def image_to_base64(self, image_path): |
|
|
|
img = Image.open(image_path) |
|
|
|
|
|
width, height = img.size |
|
|
|
|
|
max_dimension = 2048 |
|
scale_ratio = min(max_dimension / width, max_dimension / height) |
|
|
|
if scale_ratio < 1: |
|
|
|
new_width = int(width * scale_ratio) |
|
new_height = int(height * scale_ratio) |
|
img = img.resize((new_width, new_height), Image.LANCZOS) |
|
|
|
|
|
buffer = BytesIO() |
|
if img.mode == "RGBA": |
|
img = img.convert("RGB") |
|
img.save(buffer, format='JPEG') |
|
binary_image = buffer.getvalue() |
|
|
|
|
|
base64_image = base64.b64encode(binary_image).decode('utf-8') |
|
|
|
return base64_image |
|
|
|
def try_read_image(self, filepath): |
|
def is_image_file(filepath): |
|
|
|
valid_image_extensions = [ |
|
".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"] |
|
file_extension = os.path.splitext(filepath)[1].lower() |
|
return file_extension in valid_image_extensions |
|
|
|
if is_image_file(filepath): |
|
logging.info(f"读取图片文件: {filepath}") |
|
self.image_bytes = self.image_to_base64(filepath) |
|
self.image_path = filepath |
|
else: |
|
self.image_bytes = None |
|
self.image_path = None |
|
|
|
def like(self): |
|
if self.last_conv_id is None: |
|
return "点赞失败,你还没发送过消息" |
|
data = { |
|
"uuid": self.last_conv_id, |
|
"appraise": "good" |
|
} |
|
requests.post(self.url, json=data) |
|
return "👍点赞成功,感谢反馈~" |
|
|
|
def dislike(self): |
|
if self.last_conv_id is None: |
|
return "点踩失败,你还没发送过消息" |
|
data = { |
|
"uuid": self.last_conv_id, |
|
"appraise": "bad" |
|
} |
|
requests.post(self.url, json=data) |
|
return "👎点踩成功,感谢反馈~" |
|
|
|
def prepare_inputs(self, real_inputs, use_websearch, files, reply_language, chatbot): |
|
fake_inputs = real_inputs |
|
display_append = "" |
|
limited_context = False |
|
return limited_context, fake_inputs, display_append, real_inputs, chatbot |
|
|
|
def handle_file_upload(self, files, chatbot, language): |
|
"""if the model accepts multi modal input, implement this function""" |
|
if files: |
|
for file in files: |
|
if file.name: |
|
logging.info(f"尝试读取图像: {file.name}") |
|
self.try_read_image(file.name) |
|
if self.image_path is not None: |
|
chatbot = chatbot + [((self.image_path,), None)] |
|
if self.image_bytes is not None: |
|
logging.info("使用图片作为输入") |
|
|
|
self.reset() |
|
conv_id = str(uuid.uuid4()) |
|
data = { |
|
"user_id": self.api_key, |
|
"session_id": self.session_id, |
|
"uuid": conv_id, |
|
"data_type": "imgbase64", |
|
"data": self.image_bytes |
|
} |
|
response = requests.post(self.url, json=data) |
|
response = json.loads(response.text) |
|
logging.info(f"图片回复: {response['data']}") |
|
return None, chatbot, None |
|
|
|
def get_answer_at_once(self): |
|
question = self.history[-1]["content"] |
|
conv_id = str(uuid.uuid4()) |
|
self.last_conv_id = conv_id |
|
data = { |
|
"user_id": self.api_key, |
|
"session_id": self.session_id, |
|
"uuid": conv_id, |
|
"data_type": "text", |
|
"data": question |
|
} |
|
response = requests.post(self.url, json=data) |
|
try: |
|
response = json.loads(response.text) |
|
return response["data"], len(response["data"]) |
|
except Exception as e: |
|
return response.text, len(response.text) |
|
|