Hjgugugjhuhjggg committed
Commit 2280244 (1 parent: 1429d43)

Update app.py

Files changed (1): app.py (+21 -17)
app.py CHANGED
@@ -74,9 +74,10 @@ class GenerateRequest(BaseModel):
         return v
 
 class StopOnKeywords(StoppingCriteria):
-    def __init__(self, stop_words_ids: List[List[int]], encounters: int = 1):
+    def __init__(self, stop_words_ids: List[List[int]], tokenizer, encounters: int = 1):
         super().__init__()
         self.stop_words_ids = stop_words_ids
+        self.tokenizer = tokenizer
         self.encounters = encounters
         self.current_encounters = 0
 
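This hunk threads the tokenizer into StopOnKeywords so the stop sequences can later be decoded back to text. The commit only shows __init__; below is a minimal sketch of how the complete class plausibly looks, where the __call__ body is an assumption (it compares the tail of the generated ids against each stop sequence), not code from this repo.

import torch
from typing import List
from transformers import StoppingCriteria

class StopOnKeywords(StoppingCriteria):
    def __init__(self, stop_words_ids: List[List[int]], tokenizer, encounters: int = 1):
        super().__init__()
        self.stop_words_ids = stop_words_ids
        self.tokenizer = tokenizer        # kept so stop ids can be decoded back to text
        self.encounters = encounters      # stop after this many keyword hits
        self.current_encounters = 0

    # Assumed implementation: the commit does not show __call__.
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_ids in self.stop_words_ids:
            if input_ids.shape[1] >= len(stop_ids) and input_ids[0, -len(stop_ids):].tolist() == stop_ids:
                self.current_encounters += 1
                if self.current_encounters >= self.encounters:
                    return True
        return False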
@@ -135,7 +136,7 @@ model_loader = GCSModelLoader(bucket)
 @app.post("/generate")
 async def generate(request: GenerateRequest):
     model_name = request.model_name
-    input_text = request.input_text  # Initialize input_text here
+    input_text = request.input_text
     task_type = request.task_type
     requested_max_new_tokens = request.max_new_tokens
     generation_params = request.model_dump(
@@ -155,7 +156,7 @@ async def generate(request: GenerateRequest):
 
     if user_defined_stopping_strings:
         stop_words_ids = [tokenizer.encode(stop_string, add_special_tokens=False) for stop_string in user_defined_stopping_strings]
-        stopping_criteria_list.append(StopOnKeywords(stop_words_ids))
+        stopping_criteria_list.append(StopOnKeywords(stop_words_ids, tokenizer))  # Pass tokenizer
 
     if config.eos_token_id is not None:
         eos_token_ids = [config.eos_token_id]
@@ -164,13 +165,13 @@ async def generate(request: GenerateRequest):
     elif isinstance(config.eos_token_id, list):
         eos_token_ids = [[id] for id in config.eos_token_id]
         stop_words_ids_eos = [tokenizer.encode(tokenizer.decode(eos_id), add_special_tokens=False) for eos_id in eos_token_ids]
-        stopping_criteria_list.append(StopOnKeywords(stop_words_ids_eos))
+        stopping_criteria_list.append(StopOnKeywords(stop_words_ids_eos, tokenizer))  # Pass tokenizer
     elif tokenizer.eos_token is not None:
         stop_words_ids_eos = [tokenizer.encode(tokenizer.eos_token, add_special_tokens=False)]
-        stopping_criteria_list.append(StopOnKeywords(stop_words_ids_eos))
+        stopping_criteria_list.append(StopOnKeywords(stop_words_ids_eos, tokenizer))  # Pass tokenizer
 
     async def generate_responses() -> AsyncIterator[Dict[str, List[Dict[str, str]]]]:
-        nonlocal input_text  # Allow modification of the outer scope variable
+        nonlocal input_text
         all_generated_text = ""
         stop_reason = None
 
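These two hunks construct every stopping criterion, for user-defined strings and for EOS tokens alike, with the tokenizer attached. A self-contained sketch of how such a criteria list is typically handed to transformers generation, using the StopOnKeywords class above; the model name, prompt, and stop strings are placeholders, not values from this repo.

from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList

tokenizer = AutoTokenizer.from_pretrained("gpt2")   # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2")

stop_strings = ["###"]                              # example user-defined stopping strings
stop_words_ids = [tokenizer.encode(s, add_special_tokens=False) for s in stop_strings]
criteria = StoppingCriteriaList([StopOnKeywords(stop_words_ids, tokenizer)])

inputs = tokenizer("Once upon a time", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=64, stopping_criteria=criteria)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))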
@@ -195,22 +196,25 @@ async def generate(request: GenerateRequest):
             result = await output_queue.get()
             thread.join()
 
-            newly_generated_text = result[0]['generated_text'][len(all_generated_text):]
-
-            if not newly_generated_text:
+            newly_generated_text = result[0]['generated_text']
+
+            # Decode tokens to check for stopping strings
+            for criteria in stopping_criteria_list:
+                if isinstance(criteria, StopOnKeywords):
+                    for stop_ids in criteria.stop_words_ids:
+                        decoded_stop_string = tokenizer.decode(stop_ids)
+                        if decoded_stop_string in newly_generated_text:
+                            stop_reason = f"stopping_string: {decoded_stop_string}"
+                            break
+                if stop_reason:
+                    break
+
+            if stop_reason:
                 break
 
             all_generated_text += newly_generated_text
             yield {"response": [{'generated_text': newly_generated_text}]}
 
-            if stopping_criteria_list:
-                for criteria in stopping_criteria_list:
-                    if isinstance(criteria, StopOnKeywords) and criteria.current_encounters > 0:
-                        stop_reason = "stopping_string"
-                        break
-                if stop_reason:
-                    break
-
             if config.eos_token_id is not None:
                 eos_tokens = [config.eos_token_id]
                 if isinstance(config.eos_token_id, int):
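The old loop sliced result[0]['generated_text'] down to the newly generated delta and stopped on an empty slice; the new code takes the full generated text, stops as soon as a decoded stop string appears in it, and records which string fired. The check, extracted into a standalone helper for clarity (find_stop_reason is illustrative, not part of the commit):

def find_stop_reason(newly_generated_text, stopping_criteria_list, tokenizer):
    # Decode each stop sequence back to text and substring-match it
    # against the generated chunk.
    for criteria in stopping_criteria_list:
        if isinstance(criteria, StopOnKeywords):
            for stop_ids in criteria.stop_words_ids:
                decoded_stop_string = tokenizer.decode(stop_ids)
                if decoded_stop_string in newly_generated_text:
                    return f"stopping_string: {decoded_stop_string}"
    return None

Worth noting: with the slicing removed, each yielded chunk is now the pipeline's full generated_text for that iteration rather than a delta, so all_generated_text can accumulate repeated text across iterations.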
 