Staticaliza committed on
Commit 7822b29
1 Parent(s): 14026a4

Update app.py

Files changed (1)
  1. app.py +14 -12
app.py CHANGED
@@ -30,8 +30,8 @@ for model_name, model_endpoint in API_ENDPOINTS.items():
 def format(instruction, history, input, preoutput):
     sy_l, sy_r = SPECIAL_SYMBOLS[0], SPECIAL_SYMBOLS[1]
     formatted_history = "".join(f"{sy_l}{message[0]}{sy_r}\n{sy_l}{message[1]}{sy_r}\n" for message in history)
-    formatted_input = f"{sy_l}System: {instruction}{sy_r}\n{formatted_history}{sy_l}{input}{sy_r}\n{sy_l}{preoutput}"
-    return formatted_input
+    formatted_input = f"{sy_l}System: {instruction}{sy_r}\n{formatted_history}{sy_l}{input}{sy_r}\n"
+    return formatted_input, formatted_input_base = f"{formatted_input}{sy_l}{preoutput}"

 def predict(instruction, history, input, preoutput, access_key, model, temperature, top_p, top_k, rep_p, max_tokens, stop_seqs, seed):

@@ -46,7 +46,7 @@ def predict(instruction, history, input, preoutput, access_key, model, temperatu

     stops = json.loads(stop_seqs)

-    formatted_input = format(instruction, history, input, preoutput)
+    formatted_input, formatted_input_base = format(instruction, history, input, preoutput)

     response = CLIENTS[model].text_generation(
         formatted_input,
@@ -64,17 +64,19 @@ def predict(instruction, history, input, preoutput, access_key, model, temperatu
     )

     sy_l, sy_r = SPECIAL_SYMBOLS[0], SPECIAL_SYMBOLS[1]
-    pre_result = f"{sy_l}{response}{sy_r}{''.join(SPECIAL_SYMBOLS)}"
-    pattern = re.compile(f"{sy_l}(.*){sy_r}", re.DOTALL)
-    match = pattern.search(pre_result)
-    get_result = preoutput + match.group(1).strip() if match else ""
-
-    history = history + [[input, get_result]]
+    result = ""
+
+    for stop in stops:
+        result = response.split(stop, 1)[0]
+    for symbol in stops:
+        result = response.replace(symbol, '')
+
+    history = history + [[input, result]]

-    print(formatted_input + get_result)
-    print(f"---\nUSER: {input}\nBOT: {get_result}\n---")
+    print(formatted_input_base + result)
+    print(f"---\nUSER: {input}\nBOT: {result}\n---")

-    return (get_result, input, history)
+    return (result, input, history)

 def clear_history():
     print(">>> HISTORY CLEARED!")