qingxu99 committed
Commit 676fe40
Parent: 0b89673

Improve the truncation strategy for ChatGPT conversations
crazy_functions/谷歌检索小助手.py CHANGED
@@ -98,7 +98,8 @@ def 谷歌检索小助手(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst
     history.extend([ "第一批", gpt_say ])
     meta_paper_info_list = meta_paper_info_list[10:]

-    chatbot.append(["状态?", "已经全部完成"])
+    chatbot.append(["状态?",
+        "已经全部完成,您可以试试让AI写一个Related Works,例如您可以继续输入Write a \"Related Works\" section about \"你搜索的研究领域\" for me."])
     msg = '正常'
     yield from update_ui(chatbot=chatbot, history=history, msg=msg)  # refresh the UI
     res = write_results_to_file(history)
request_llm/bridge_chatgpt.py CHANGED
@@ -21,7 +21,7 @@ import importlib

 # config_private.py holds your private settings such as the API key and proxy URL
 # On load, check first for a private config_private file (not tracked by git); if present, it overrides the original config file
-from toolbox import get_conf, update_ui, is_any_api_key, select_api_key, what_keys
+from toolbox import get_conf, update_ui, is_any_api_key, select_api_key, what_keys, clip_history
 proxies, API_KEY, TIMEOUT_SECONDS, MAX_RETRY = \
     get_conf('proxies', 'API_KEY', 'TIMEOUT_SECONDS', 'MAX_RETRY')

@@ -145,7 +145,7 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
         yield from update_ui(chatbot=chatbot, history=history, msg="api-key不满足要求")  # refresh the UI
         return

-    history.append(inputs); history.append(" ")
+    history.append(inputs); history.append("")

     retry = 0
     while True:

@@ -198,14 +198,17 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
                 chunk_decoded = chunk.decode()
                 error_msg = chunk_decoded
                 if "reduce the length" in error_msg:
-                    chatbot[-1] = (chatbot[-1][0], "[Local Message] Reduce the length. 本次输入过长,或历史数据过长. 历史缓存数据现已释放,您可以请再次尝试.")
-                    history = []  # clear the history
+                    if len(history) >= 2: history[-1] = ""; history[-2] = ""  # clear the overflowing turn: history[-2] is this input, history[-1] is this output
+                    history = clip_history(inputs=inputs, history=history, tokenizer=model_info[llm_kwargs['llm_model']]['tokenizer'],
+                                           max_token_limit=(model_info[llm_kwargs['llm_model']]['max_token'])//2)  # free at least half of the history
+                    chatbot[-1] = (chatbot[-1][0], "[Local Message] Reduce the length. 本次输入过长, 或历史数据过长. 历史缓存数据已部分释放, 您可以请再次尝试. (若再次失败则更可能是因为输入过长.)")
+                    # history = []  # clear the history
                 elif "does not exist" in error_msg:
-                    chatbot[-1] = (chatbot[-1][0], f"[Local Message] Model {llm_kwargs['llm_model']} does not exist. 模型不存在,或者您没有获得体验资格.")
+                    chatbot[-1] = (chatbot[-1][0], f"[Local Message] Model {llm_kwargs['llm_model']} does not exist. 模型不存在, 或者您没有获得体验资格.")
                 elif "Incorrect API key" in error_msg:
-                    chatbot[-1] = (chatbot[-1][0], "[Local Message] Incorrect API key. OpenAI以提供了不正确的API_KEY为由,拒绝服务.")
+                    chatbot[-1] = (chatbot[-1][0], "[Local Message] Incorrect API key. OpenAI以提供了不正确的API_KEY为由, 拒绝服务.")
                 elif "exceeded your current quota" in error_msg:
-                    chatbot[-1] = (chatbot[-1][0], "[Local Message] You exceeded your current quota. OpenAI以账户额度不足为由,拒绝服务.")
+                    chatbot[-1] = (chatbot[-1][0], "[Local Message] You exceeded your current quota. OpenAI以账户额度不足为由, 拒绝服务.")
                 elif "bad forward key" in error_msg:
                     chatbot[-1] = (chatbot[-1][0], "[Local Message] Bad forward key. API2D账户额度不足.")
                 elif "Not enough point" in error_msg:
toolbox.py CHANGED
@@ -551,3 +551,49 @@ def run_gradio_in_subpath(demo, auth, port, custom_path):
         return {"message": f"Gradio is running at: {custom_path}"}
     app = gr.mount_gradio_app(app, demo, path=custom_path)
     uvicorn.run(app, host="0.0.0.0", port=port)  # , auth=auth
+
+
+def clip_history(inputs, history, tokenizer, max_token_limit):
+    """
+    Reduce the length of input/history by clipping.
+    This function searches for the longest entries and clips them, little by little,
+    until the token count of input/history drops below the threshold.
+    """
+    import numpy as np
+    from request_llm.bridge_all import model_info
+    def get_token_num(txt):
+        return len(tokenizer.encode(txt, disallowed_special=()))
+    input_token_num = get_token_num(inputs)
+    if input_token_num < max_token_limit * 3 / 4:
+        # when the input uses less than 3/4 of the limit, reserve headroom for the input while clipping
+        max_token_limit = max_token_limit - input_token_num
+        if max_token_limit < 128:
+            # too little headroom left; just clear the history
+            history = []
+            return history
+    else:
+        # when the input uses more than 3/4 of the limit, just clear the history
+        history = []
+        return history
+
+    everything = ['']
+    everything.extend(history)
+    n_token = get_token_num('\n'.join(everything))
+    everything_token = [get_token_num(e) for e in everything]
+
+    # granularity of each clipping step
+    delta = max(everything_token) // 16
+
+    while n_token > max_token_limit:
+        where = np.argmax(everything_token)
+        encoded = tokenizer.encode(everything[where], disallowed_special=())
+        clipped_encoded = encoded[:len(encoded)-delta]
+        everything[where] = tokenizer.decode(clipped_encoded)[:-1]  # -1 to drop a possibly malformed trailing char
+        everything_token[where] = get_token_num(everything[where])
+        n_token = get_token_num('\n'.join(everything))
+
+    history = everything[1:]
+    return history
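
A quick usage sketch for the new helper, assuming tiktoken is installed; in the project the tokenizer and token limit are resolved through model_info, so the cl100k_base encoding and the 4096-token window below are illustrative choices only:

import tiktoken
from toolbox import clip_history

tokenizer = tiktoken.get_encoding("cl100k_base")  # illustrative encoding choice

inputs = "Summarize the papers above."
history = ["第一批", "a very long model answer " * 800]  # one oversized entry

# clip the longest entries until inputs + history fit within half of a
# hypothetical 4096-token context window
clipped = clip_history(inputs=inputs, history=history,
                       tokenizer=tokenizer, max_token_limit=4096 // 2)
print([len(tokenizer.encode(h)) for h in clipped])  # per-entry token counts after clipping

Because the loop always clips the longest remaining entry, short entries such as "第一批" survive untouched while the oversized answer is trimmed from its tail in steps of roughly 1/16 of its length.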