# Spaces: Paused  (hosting-page status text captured with this file; not part of the module)
import json
import logging
import os

import colorama
import requests

from modules.models.base_model import BaseLLMModel
from modules.presets import STANDARD_ERROR_MSG, GENERAL_ERROR_MSG, TIMEOUT_STREAMING, TIMEOUT_ALL, i18n
# MiniMax group id, read once at import time; falls back to an empty string
# when the MINIMAX_GROUP_ID environment variable is not set.
group_id = os.environ.get("MINIMAX_GROUP_ID", "")
class MiniMax_Client(BaseLLMModel):
    """Chat client for the MiniMax text chat-completion HTTP API.

    API documentation: https://api.minimax.chat/document/guides/chat
    """

    def __init__(self, model_name, api_key, user_name="", system_prompt=None):
        """Store credentials and prepare the request URL and headers.

        Args:
            model_name: Model identifier; every ``minimax-`` occurrence is
                stripped from it before it is sent to the API.
            api_key: Key sent as a ``Bearer`` token in the Authorization header.
            user_name: Forwarded to the base class as ``user``.
            system_prompt: Optional text sent as the request ``prompt``.
        """
        super().__init__(model_name=model_name, user=user_name)
        self.url = f'https://api.minimax.chat/v1/text/chatcompletion?GroupId={group_id}'
        self.history = []
        self.api_key = api_key
        self.system_prompt = system_prompt
        self.headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json"
        }

    def _scaled_temperature(self):
        """Convert the base-model temperature to the value sent to MiniMax.

        Values up to 1 are scaled by 0.9 (so base 1 becomes 0.9); values above
        1 are compressed into the remaining range (base 2 becomes 1.0).
        """
        # NOTE(review): the original comment states MiniMax accepts (0, 1];
        # a base temperature of 0 still maps to 0 here - confirm the API
        # tolerates it.
        if self.temperature <= 1:
            return self.temperature * 0.9
        return 0.9 + (self.temperature - 1) / 10

    @staticmethod
    def _format_api_error(payload):
        """Return ``"<status_code> - <status_msg>"`` for a ``base_resp`` error
        payload, or the payload itself when it has no ``base_resp`` entry."""
        if isinstance(payload, dict) and 'base_resp' in payload:
            base_resp = payload['base_resp']
            return f"{base_resp['status_code']} - {base_resp['status_msg']}"
        return payload

    def get_answer_at_once(self):
        """Send the latest history entry and return the complete reply.

        Returns:
            A ``(answer, total_token_count)`` tuple taken from the response's
            ``reply`` and ``usage.total_tokens`` fields.

        Raises:
            Exception: When the response carries no ``reply``; the message is
                built from ``base_resp`` when that entry is present.
        """
        request_body = {
            "model": self.model_name.replace('minimax-', ''),
            "temperature": self._scaled_temperature(),
            "skip_info_mask": True,
            # Only the most recent history entry is sent on this path.
            'messages': [{"sender_type": "USER", "text": self.history[-1]['content']}]
        }
        if self.n_choices:
            request_body['beam_width'] = self.n_choices
        if self.system_prompt:
            request_body['prompt'] = self.system_prompt
        if self.max_generation_token:
            request_body['tokens_to_generate'] = self.max_generation_token
        if self.top_p:
            request_body['top_p'] = self.top_p
        # A timeout keeps a stalled connection from blocking the caller forever.
        response = requests.post(
            self.url,
            headers=self.headers,
            json=request_body,
            timeout=TIMEOUT_ALL,
        )
        res = response.json()
        if not isinstance(res, dict) or 'reply' not in res:
            # Surface the API's own status instead of an opaque KeyError.
            raise Exception(self._format_api_error(res))
        answer = res['reply']
        total_token_count = res["usage"]["total_tokens"]
        return answer, total_token_count

    def get_answer_stream_iter(self):
        """Yield the reply text accumulated so far after each streamed delta.

        Yields a single generic error message instead when the request could
        not be sent.
        """
        response = self._get_response(stream=True)
        if response is None:
            yield STANDARD_ERROR_MSG + GENERAL_ERROR_MSG
            return
        partial_text = ""
        try:
            for delta in self._decode_chat_response(response):
                partial_text += delta
                yield partial_text
        finally:
            # The decoder stops reading at the "stop" chunk, so release the
            # connection explicitly.
            response.close()

    def _get_response(self, stream=False):
        """POST the whole conversation history to the chat-completion endpoint.

        Args:
            stream: When true, request a streamed response and use the
                streaming timeout.

        Returns:
            The ``requests`` response object, or ``None`` if sending the
            request raised an exception.
        """
        logging.debug(colorama.Fore.YELLOW +
                      f"{self.history}" + colorama.Fore.RESET)
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

        # Every entry whose role is not "user" is sent as a BOT message.
        messages = [
            {
                "sender_type": "USER" if msg['role'] == 'user' else "BOT",
                "text": msg['content'],
            }
            for msg in self.history
        ]

        request_body = {
            "model": self.model_name.replace('minimax-', ''),
            "temperature": self._scaled_temperature(),
            "skip_info_mask": True,
            'messages': messages
        }
        if self.n_choices:
            request_body['beam_width'] = self.n_choices
        if self.system_prompt:
            lines = self.system_prompt.splitlines()
            # A short first line containing ":" is read as "user_name:bot_name".
            if lines[0].find(":") != -1 and len(lines[0]) < 20:
                names = lines[0].split(":")
                request_body["role_meta"] = {
                    "user_name": names[0],
                    "bot_name": names[1]
                }
                # Remove the role header (the FIRST line) so that only the
                # remaining text is sent as the prompt.
                lines.pop(0)
            request_body["prompt"] = "\n".join(lines)
        if self.max_generation_token:
            request_body['tokens_to_generate'] = self.max_generation_token
        else:
            request_body['tokens_to_generate'] = 512
        if self.top_p:
            request_body['top_p'] = self.top_p

        if stream:
            timeout = TIMEOUT_STREAMING
            request_body['stream'] = True
            request_body['use_standard_sse'] = True
        else:
            timeout = TIMEOUT_ALL

        try:
            response = requests.post(
                self.url,
                headers=headers,
                json=request_body,
                stream=stream,
                timeout=timeout,
            )
        except Exception:
            # Best effort: callers treat None as "request failed". Unlike a
            # bare except, this lets KeyboardInterrupt/SystemExit propagate.
            logging.exception("MiniMax request failed")
            return None
        return response

    def _decode_chat_response(self, response):
        """Yield the ``delta`` of each streamed chunk until the "stop" chunk.

        On the "stop" chunk the token usage is appended to
        ``self.all_token_counts`` and iteration ends without yielding.

        Raises:
            Exception: After the stream ends, if any line could not be parsed
                as JSON; the message comes from the collected lines
                (``base_resp`` status when available).
        """
        error_msg = ""
        for raw_line in response.iter_lines():
            if not raw_line:
                continue
            line = raw_line.decode()
            logging.debug(line)
            try:
                # Skip the first six characters before parsing - presumably
                # the SSE "data: " prefix; TODO confirm against the API docs.
                payload = json.loads(line[6:])
            except json.JSONDecodeError:
                logging.error(i18n("JSON解析错误,收到的内容: ") + f"{line}")
                # Unparseable lines are collected and reported after the loop.
                error_msg += line
                continue
            choice = payload["choices"][0]
            if "delta" not in choice:
                continue
            if choice.get("finish_reason") == "stop":
                # Record only the tokens not already accounted for.
                self.all_token_counts.append(
                    payload["usage"]["total_tokens"] - sum(self.all_token_counts)
                )
                break
            yield choice["delta"]
        if error_msg:
            try:
                error_payload = json.loads(error_msg)
            except json.JSONDecodeError:
                # Not JSON at all: report the raw text.
                error_payload = error_msg
            raise Exception(self._format_api_error(error_payload))