|
import _thread as thread |
|
import base64 |
|
import datetime |
|
import hashlib |
|
import hmac |
|
import json |
|
from collections import deque |
|
from urllib.parse import urlparse |
|
import ssl |
|
from datetime import datetime |
|
from time import mktime |
|
from urllib.parse import urlencode |
|
from wsgiref.handlers import format_date_time |
|
from threading import Condition |
|
import websocket |
|
import logging |
|
|
|
from .base_model import BaseLLMModel, CallbackToIterator |
|
|
|
|
|
class Ws_Param(object):
    """Produce signed WebSocket URLs for the iFlytek Spark chat endpoint.

    The signature is an HMAC-SHA256 (keyed by ``APISecret``) over three
    newline-separated lines: the host, the current HTTP date and the
    ``GET <path> HTTP/1.1`` request line.
    """

    def __init__(self, APPID, APIKey, APISecret, Spark_url):
        """Remember the credentials and split ``Spark_url`` into host and path.

        Args:
            APPID: Application id (stored, not used for signing here).
            APIKey: Key embedded in the authorization header.
            APISecret: Secret used as the HMAC key.
            Spark_url: Full ``wss://`` endpoint URL.
        """
        parts = urlparse(Spark_url)
        self.APPID = APPID
        self.APIKey = APIKey
        self.APISecret = APISecret
        self.host = parts.netloc
        self.path = parts.path
        self.Spark_url = Spark_url

    def create_url(self):
        """Return ``Spark_url`` with ``authorization``, ``date`` and ``host`` appended.

        The date is the current local time converted with ``mktime`` and
        rendered by ``format_date_time`` (RFC 1123, GMT).  ``authorization``
        is the base64 encoding of a header string that embeds the base64
        HMAC-SHA256 signature.
        """
        stamp = format_date_time(mktime(datetime.now().timetuple()))

        # The three lines covered by the signature, in the server-expected order.
        signed_text = "\n".join(
            (
                "host: " + self.host,
                "date: " + stamp,
                "GET " + self.path + " HTTP/1.1",
            )
        )

        mac = hmac.new(
            self.APISecret.encode("utf-8"),
            signed_text.encode("utf-8"),
            digestmod=hashlib.sha256,
        )
        signature = base64.b64encode(mac.digest()).decode("utf-8")

        auth_header = (
            f'api_key="{self.APIKey}", algorithm="hmac-sha256", '
            f'headers="host date request-line", signature="{signature}"'
        )
        encoded_auth = base64.b64encode(auth_header.encode("utf-8")).decode("utf-8")

        query = urlencode(
            {"authorization": encoded_auth, "date": stamp, "host": self.host}
        )
        return self.Spark_url + "?" + query
|
|
|
|
|
class Spark_Client(BaseLLMModel):
    """Streaming chat client for the iFlytek (讯飞) Spark WebSocket API.

    Each call to ``get_answer_stream_iter`` opens a new WebSocket connection,
    sends the request built by ``gen_params`` once the socket opens, and
    yields the accumulated answer as frames arrive.
    """

    def __init__(self, model_name, appid, api_key, api_secret, user_name="") -> None:
        """Store credentials and select the endpoint/domain for ``model_name``.

        Args:
            model_name: Model identifier. A "3.0" substring selects the v3.1
                endpoint, otherwise "2.0" selects v2.1, otherwise v1.1.
            appid: iFlytek APP ID.
            api_key: iFlytek API key.
            api_secret: iFlytek API secret.
            user_name: Forwarded to ``BaseLLMModel`` as ``user``.

        Raises:
            Exception: if any of the three credentials is ``None``.
        """
        super().__init__(model_name=model_name, user=user_name)
        self.api_key = api_key
        self.appid = appid
        self.api_secret = api_secret
        if None in [self.api_key, self.appid, self.api_secret]:
            raise Exception("请在配置文件或者环境变量中设置讯飞的API Key、APP ID和API Secret")
        # The branches must be mutually exclusive. Previously the trailing
        # `else` belonged only to the "3.0" check, so "2.0" models had their
        # v2.1 settings overwritten by the v1.1 fallback. "3.0" is tested
        # first to keep its original priority for names containing both.
        if "3.0" in self.model_name:
            self.spark_url = "wss://spark-api.xf-yun.com/v3.1/chat"
            self.domain = "generalv3"
        elif "2.0" in self.model_name:
            self.spark_url = "wss://spark-api.xf-yun.com/v2.1/chat"
            self.domain = "generalv2"
        else:
            self.spark_url = "wss://spark-api.xf-yun.com/v1.1/chat"
            self.domain = "general"

    def on_error(self, ws, error):
        """Forward a connection error to the consumer as a plain-text message."""
        # websocket-client passes the exception object, not a str; concatenating
        # it directly raised TypeError inside the callback and lost the error.
        ws.iterator.callback("出现了错误:" + str(error))

    def on_close(self, ws, one=None, two=None):
        """Mark the stream finished when the socket closes.

        ``one`` and ``two`` receive the close status code and message
        (websocket-client >= 1.0); they are unused and default to ``None`` so
        older versions that pass only ``ws`` also work.
        """
        # Ends iteration even if the server closes before sending a status==2
        # frame; otherwise the consumer could keep waiting on the iterator.
        # get_answer_stream_iter also calls finish() on status 2; this assumes
        # CallbackToIterator.finish() tolerates a second call — TODO confirm.
        ws.iterator.finish()

    def on_open(self, ws):
        """Send the request from a separate thread once the socket is open."""
        thread.start_new_thread(self.run, (ws,))

    def run(self, ws, *args):
        """Serialize the request built by ``gen_params`` and send it over ``ws``."""
        data = json.dumps(self.gen_params())
        ws.send(data)

    def on_message(self, ws, message):
        """Push each raw server frame to the consumer's iterator."""
        ws.iterator.callback(message)

    def gen_params(self):
        """Build the request payload from the appid and the user's conversation.

        Uses ``self.history`` as the message list and ``self.temperature`` as
        ``random_threshold``; both are presumably provided by ``BaseLLMModel``
        (confirm there). ``uid`` is a fixed placeholder.
        """
        data = {
            "header": {"app_id": self.appid, "uid": "1234"},
            "parameter": {
                "chat": {
                    "domain": self.domain,
                    "random_threshold": self.temperature,
                    "max_tokens": 4096,
                    "auditing": "default",
                }
            },
            "payload": {"message": {"text": self.history}},
        }
        return data

    def get_answer_stream_iter(self):
        """Stream the reply for ``self.history``.

        Yields:
            ``(answer, total_tokens)``: the answer text accumulated so far and
            the most recent ``total_tokens`` reported in a ``usage`` block
            (0 until one arrives).

        Raises:
            Exception: if the server returns a non-zero ``header.code`` or the
                connection reports an error.
        """
        wsParam = Ws_Param(self.appid, self.api_key, self.api_secret, self.spark_url)
        websocket.enableTrace(False)
        wsUrl = wsParam.create_url()
        ws = websocket.WebSocketApp(
            wsUrl,
            on_message=self.on_message,
            on_error=self.on_error,
            on_close=self.on_close,
            on_open=self.on_open,
        )
        ws.appid = self.appid
        ws.domain = self.domain

        # The callbacks push frames into this iterator; the loop below drains it.
        ws.iterator = CallbackToIterator()

        # NOTE(review): ssl.CERT_NONE disables TLS certificate verification for
        # this connection. Consider dropping `sslopt` to use the default
        # verifying context.
        thread.start_new_thread(
            ws.run_forever, (), {"sslopt": {"cert_reqs": ssl.CERT_NONE}}
        )

        answer = ""
        total_tokens = 0
        for message in ws.iterator:
            try:
                data = json.loads(message)
            except json.JSONDecodeError:
                # on_error pushes a plain-text error string, not a JSON frame;
                # surface it instead of an opaque decode error.
                ws.close()
                raise Exception(message) from None
            code = data["header"]["code"]
            if code != 0:
                ws.close()
                raise Exception(f"请求错误: {code}, {data}")
            choices = data["payload"]["choices"]
            status = choices["status"]
            content = choices["text"][0]["content"]
            if "usage" in data["payload"]:
                total_tokens = data["payload"]["usage"]["text"]["total_tokens"]
            answer += content
            if status == 2:
                # status 2 is treated as the final frame: stop iterating and
                # close the socket.
                ws.iterator.finish()
                ws.close()
            yield answer, total_tokens
|
|