Florian valade committed on
Commit a0417ab · 1 Parent(s): 2fb9772

Update prompt formatting

Files changed (1)
  1. app.py +8 -3
app.py CHANGED
@@ -82,9 +82,14 @@ def generate_response(message, chat_history, epsilon):
     # Set model thresholds based on epsilon
     model.head_thresholds = torch.tensor(epsilon_thresholds[epsilon])
 
+    # Format the prompt with chat history
+    formatted_prompt = ""
+    for user_msg, assistant_msg in chat_history:
+        formatted_prompt += f"User: {user_msg}\nAssistant: {assistant_msg}\n"
+    formatted_prompt += f"User: {message}\nAssistant:"
+
     full_response = ""
-    chat_history = chat_history or []
-    inputs = tokenizer.encode(message, return_tensors="pt").to(device)
+    inputs = tokenizer.encode(formatted_prompt, return_tensors="pt").to(device)
 
     while not stop_generation:
         inputs = truncate_context(inputs)
@@ -115,7 +120,7 @@
 
     new_history = chat_history + [(message, full_response)]
     yield new_history, new_history, gr.update(value=create_plot())
-
+
 def stop_gen():
     global stop_generation
     stop_generation = True
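
For readers skimming the diff, the new formatting logic reads cleanly in isolation. The sketch below pulls it out into a standalone helper; format_prompt is a hypothetical name (the commit inlines this loop directly in generate_response), and the example history is purely illustrative.

    # A minimal sketch of the prompt formatting this commit adds, extracted
    # into a hypothetical helper for readability.
    from typing import List, Tuple

    def format_prompt(message: str, chat_history: List[Tuple[str, str]]) -> str:
        """Flatten prior (user, assistant) turns, then append the new user message."""
        formatted_prompt = ""
        for user_msg, assistant_msg in chat_history:
            formatted_prompt += f"User: {user_msg}\nAssistant: {assistant_msg}\n"
        # End with an open "Assistant:" so the model completes the next reply.
        formatted_prompt += f"User: {message}\nAssistant:"
        return formatted_prompt

    # Illustrative usage with one prior turn:
    history = [("Hi", "Hello! How can I help?")]
    print(format_prompt("What does epsilon control?", history))
    # User: Hi
    # Assistant: Hello! How can I help?
    # User: What does epsilon control?
    # Assistant:

Note that the commit also drops the old chat_history = chat_history or [] guard, so the formatting loop assumes the caller always supplies a list for chat_history.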