3v324v23 commited on
Commit
bfa6661
·
1 Parent(s): d79dfe2
Files changed (3) hide show
  1. config.py +1 -1
  2. predict.py +6 -2
  3. request_llm/bridge_tgui.py +137 -0
config.py CHANGED
@@ -34,7 +34,7 @@ WEB_PORT = -1
34
  MAX_RETRY = 2
35
 
36
  # OpenAI模型选择是(gpt4现在只对申请成功的人开放)
37
- LLM_MODEL = "gpt-3.5-turbo"
38
 
39
  # OpenAI的API_URL
40
  API_URL = "https://api.openai.com/v1/chat/completions"
 
34
  MAX_RETRY = 2
35
 
36
  # OpenAI模型选择是(gpt4现在只对申请成功的人开放)
37
+ LLM_MODEL = "pygmalion-1.3b@localhost@7860" # "gpt-3.5-turbo"
38
 
39
  # OpenAI的API_URL
40
  API_URL = "https://api.openai.com/v1/chat/completions"
predict.py CHANGED
@@ -112,8 +112,7 @@ def predict_no_ui_long_connection(inputs, top_p, temperature, history=[], sys_pr
112
  return result
113
 
114
 
115
- def predict(inputs, top_p, temperature, chatbot=[], history=[], system_prompt='',
116
- stream = True, additional_fn=None):
117
  """
118
  发送至chatGPT,流式获取输出。
119
  用于基础的对话功能。
@@ -244,3 +243,8 @@ def generate_payload(inputs, top_p, temperature, history, system_prompt, stream)
244
  return headers,payload
245
 
246
 
 
 
 
 
 
 
112
  return result
113
 
114
 
115
+ def predict(inputs, top_p, temperature, chatbot=[], history=[], system_prompt='', stream = True, additional_fn=None):
 
116
  """
117
  发送至chatGPT,流式获取输出。
118
  用于基础的对话功能。
 
243
  return headers,payload
244
 
245
 
246
+ if not LLM_MODEL.startswith('gpt'):
247
+ from request_llm.bridge_tgui import predict_tgui
248
+ predict = predict_tgui
249
+
250
+
request_llm/bridge_tgui.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ Contributed by SagsMug. Modified by binary-husky
3
+ https://github.com/oobabooga/text-generation-webui/pull/175
4
+ '''
5
+
6
+ import asyncio
7
+ import json
8
+ import random
9
+ import string
10
+ import websockets
11
+ import logging
12
+ import time
13
+ import threading
14
+ from toolbox import get_conf
15
+ LLM_MODEL, = get_conf('LLM_MODEL')
16
+
17
+ model_name, addr, port = LLM_MODEL.split('@')
18
+
19
+ def random_hash():
20
+ letters = string.ascii_lowercase + string.digits
21
+ return ''.join(random.choice(letters) for i in range(9))
22
+
23
+ async def run(context):
24
+ params = {
25
+ 'max_new_tokens': 200,
26
+ 'do_sample': True,
27
+ 'temperature': 0.5,
28
+ 'top_p': 0.9,
29
+ 'typical_p': 1,
30
+ 'repetition_penalty': 1.05,
31
+ 'encoder_repetition_penalty': 1.0,
32
+ 'top_k': 0,
33
+ 'min_length': 0,
34
+ 'no_repeat_ngram_size': 0,
35
+ 'num_beams': 1,
36
+ 'penalty_alpha': 0,
37
+ 'length_penalty': 1,
38
+ 'early_stopping': False,
39
+ 'seed': -1,
40
+ }
41
+ session = random_hash()
42
+
43
+ async with websockets.connect(f"ws://{addr}:{port}/queue/join") as websocket:
44
+ while content := json.loads(await websocket.recv()):
45
+ #Python3.10 syntax, replace with if elif on older
46
+ if content["msg"] == "send_hash":
47
+ await websocket.send(json.dumps({
48
+ "session_hash": session,
49
+ "fn_index": 12
50
+ }))
51
+ elif content["msg"] == "estimation":
52
+ pass
53
+ elif content["msg"] == "send_data":
54
+ await websocket.send(json.dumps({
55
+ "session_hash": session,
56
+ "fn_index": 12,
57
+ "data": [
58
+ context,
59
+ params['max_new_tokens'],
60
+ params['do_sample'],
61
+ params['temperature'],
62
+ params['top_p'],
63
+ params['typical_p'],
64
+ params['repetition_penalty'],
65
+ params['encoder_repetition_penalty'],
66
+ params['top_k'],
67
+ params['min_length'],
68
+ params['no_repeat_ngram_size'],
69
+ params['num_beams'],
70
+ params['penalty_alpha'],
71
+ params['length_penalty'],
72
+ params['early_stopping'],
73
+ params['seed'],
74
+ ]
75
+ }))
76
+ elif content["msg"] == "process_starts":
77
+ pass
78
+ elif content["msg"] in ["process_generating", "process_completed"]:
79
+ yield content["output"]["data"][0]
80
+ # You can search for your desired end indicator and
81
+ # stop generation by closing the websocket here
82
+ if (content["msg"] == "process_completed"):
83
+ break
84
+
85
+
86
+
87
+
88
+
89
+ def predict_tgui(inputs, top_p, temperature, chatbot=[], history=[], system_prompt='', stream = True, additional_fn=None):
90
+ """
91
+ 发送至chatGPT,流式获取输出。
92
+ 用于基础的对话功能。
93
+ inputs 是本次问询的输入
94
+ top_p, temperature是chatGPT的内部调优参数
95
+ history 是之前的对话列表(注意无论是inputs还是history,内容太长了都会触发token数量溢出的错误)
96
+ chatbot 为WebUI中显示的对话列表,修改它,然后yeild出去,可以直接修改对话界面内容
97
+ additional_fn代表点击的哪个按钮,按钮见functional.py
98
+ """
99
+ if additional_fn is not None:
100
+ import functional
101
+ importlib.reload(functional) # 热更新prompt
102
+ functional = functional.get_functionals()
103
+ if "PreProcess" in functional[additional_fn]: inputs = functional[additional_fn]["PreProcess"](inputs) # 获取预处理函数(如果有的话)
104
+ inputs = functional[additional_fn]["Prefix"] + inputs + functional[additional_fn]["Suffix"]
105
+
106
+ raw_input = inputs
107
+ logging.info(f'[raw_input] {raw_input}')
108
+ chatbot.append((inputs, ""))
109
+ yield chatbot, history, "等待响应"
110
+
111
+ prompt = inputs
112
+ tgui_say = ""
113
+
114
+ mutable = [""]
115
+ def run_coorotine(mutable):
116
+ async def get_result():
117
+ async for response in run(prompt):
118
+ # Print intermediate steps
119
+ mutable += response
120
+ asyncio.run(get_result())
121
+
122
+ thread_listen = threading.Thread(target=run_coorotine, args=(mutable,))
123
+ thread_listen.start()
124
+
125
+ while thread_listen.is_alive():
126
+ time.sleep(1)
127
+ # Print intermediate steps
128
+ if tgui_say != mutable[0]:
129
+ tgui_say = mutable[0]
130
+ history[-1] = tgui_say
131
+ chatbot[-1] = (history[-2], history[-1])
132
+ yield chatbot, history, status_text
133
+
134
+ logging.info(f'[response] {tgui_say}')
135
+
136
+
137
+