arjunanand13 commited on
Commit
f834e93
1 Parent(s): c9a9aee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -2
app.py CHANGED
@@ -95,7 +95,7 @@ class DocumentRetrievalAndGeneration:
95
  messages = [{"role": "user", "content": prompt}]
96
  encodeds = self.llm.tokenizer.apply_chat_template(messages, return_tensors="pt")
97
  model_inputs = encodeds.to(self.llm.device)
98
-
99
  # Perform inference and measure time
100
  start_time = datetime.now()
101
  generated_ids = self.llm.model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
@@ -104,11 +104,22 @@ class DocumentRetrievalAndGeneration:
104
  # Decode and return output
105
  decoded = self.llm.tokenizer.batch_decode(generated_ids)
106
  generated_response = decoded[0]
 
 
 
 
 
 
 
 
 
 
 
107
  print("Generated response:", generated_response)
108
  print("Time elapsed:", elapsed_time)
109
  print("Device in use:", self.llm.device)
110
 
111
- return generated_response, content
112
 
113
  def qa_infer_gradio(self, query):
114
  response = self.query_and_generate_response(query)
 
95
  messages = [{"role": "user", "content": prompt}]
96
  encodeds = self.llm.tokenizer.apply_chat_template(messages, return_tensors="pt")
97
  model_inputs = encodeds.to(self.llm.device)
98
+
99
  # Perform inference and measure time
100
  start_time = datetime.now()
101
  generated_ids = self.llm.model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
 
104
  # Decode and return output
105
  decoded = self.llm.tokenizer.batch_decode(generated_ids)
106
  generated_response = decoded[0]
107
+ match1 = re.search(r'\[/INST\](.*?)</s>', generated_response, re.DOTALL)
108
+ match2 = re.search(r'Solution:(.*?)</s>', text, re.DOTALL | re.IGNORECASE)
109
+ if match1:
110
+ solution_text = match1.group(1).strip()
111
+ print(solution_text)
112
+ elif match2:
113
+ solution_text = match2.group(1).strip()
114
+ print(solution_text)
115
+
116
+ else:
117
+ solution_text=generated_response
118
  print("Generated response:", generated_response)
119
  print("Time elapsed:", elapsed_time)
120
  print("Device in use:", self.llm.device)
121
 
122
+ return solution_text, content
123
 
124
  def qa_infer_gradio(self, query):
125
  response = self.query_and_generate_response(query)