dreamerdeo committed on
Commit
07c2aaa
1 Parent(s): dbc2154

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -9
app.py CHANGED
@@ -42,20 +42,31 @@ Your responses should be friendly, unbiased, informative, detailed, and faithful
42
  system_prompt = f"<|im_start|>{system_role}\n{system_prompt}<|im_end|>"
43
 
44
  # Function to generate model predictions.
45
-
46
  @spaces.GPU()
47
  def predict(message, history):
48
- # history = []
 
 
 
 
49
  history_transformer_format = history + [[message, ""]]
50
  stop = StopOnTokens()
51
 
52
- # Formatting the input for the model.
53
- messages = system_prompt + sft_end_token.join([sft_end_token.join([f"\n{sft_start_token}{user_role}\n" + item[0], f"\n{sft_start_token}{assistant_role}\n" + item[1]])
54
- for item in history_transformer_format])
 
 
 
 
 
 
 
55
  model_inputs = tokenizer([messages], return_tensors="pt").to(device)
56
  streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
57
  generate_kwargs = dict(
58
- model_inputs,
 
59
  streamer=streamer,
60
  max_new_tokens=1024,
61
  do_sample=True,
@@ -66,14 +77,23 @@ def predict(message, history):
66
  stopping_criteria=StoppingCriteriaList([stop]),
67
  repetition_penalty=1.1,
68
  )
 
 
69
  t = Thread(target=model.generate, kwargs=generate_kwargs)
70
- t.start() # Starting the generation in a separate thread.
 
 
71
  partial_message = ""
72
  for new_token in streamer:
73
  partial_message += new_token
74
- if sft_end_token in partial_message: # Breaking the loop if the stop token is generated.
75
  break
76
- yield partial_message
 
 
 
 
 
77
 
78
 
79
  css = """
 
42
  system_prompt = f"<|im_start|>{system_role}\n{system_prompt}<|im_end|>"
43
 
44
  # Function to generate model predictions.
 
45
  @spaces.GPU()
46
  def predict(message, history):
47
+ # 初始化对话历史格式
48
+ if history is None:
49
+ history = []
50
+
51
+ # 在历史中添加当前用户输入,临时设置机器人的回复为空
52
  history_transformer_format = history + [[message, ""]]
53
  stop = StopOnTokens()
54
 
55
+ # 格式化输入为模型需要的格式
56
+ messages = (
57
+ system_prompt
58
+ + sft_end_token.join([
59
+ sft_end_token.join([
60
+ f"\n{sft_start_token}{user_role}\n" + item[0],
61
+ f"\n{sft_start_token}{assistant_role}\n" + item[1]
62
+ ]) for item in history_transformer_format
63
+ ])
64
+ )
65
  model_inputs = tokenizer([messages], return_tensors="pt").to(device)
66
  streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
67
  generate_kwargs = dict(
68
+ input_ids=model_inputs["input_ids"],
69
+ attention_mask=model_inputs["attention_mask"],
70
  streamer=streamer,
71
  max_new_tokens=1024,
72
  do_sample=True,
 
77
  stopping_criteria=StoppingCriteriaList([stop]),
78
  repetition_penalty=1.1,
79
  )
80
+
81
+ # 使用线程来运行生成过程
82
  t = Thread(target=model.generate, kwargs=generate_kwargs)
83
+ t.start()
84
+
85
+ # 实时生成部分消息
86
  partial_message = ""
87
  for new_token in streamer:
88
  partial_message += new_token
89
+ if sft_end_token in partial_message: # 检测到停止标志
90
  break
91
+ yield history + [[message, partial_message]] # 输出流式数据
92
+
93
+ # 处理生成的最终回复
94
+ final_message = partial_message.replace(sft_end_token, "").strip()
95
+ history.append([message, final_message]) # 更新历史记录
96
+ yield history # 返回完整对话历史
97
 
98
 
99
  css = """