Hjgugugjhuhjggg
commited on
Commit
•
2280244
1
Parent(s):
1429d43
Update app.py
Browse files
app.py
CHANGED
@@ -74,9 +74,10 @@ class GenerateRequest(BaseModel):
|
|
74 |
return v
|
75 |
|
76 |
class StopOnKeywords(StoppingCriteria):
|
77 |
-
def __init__(self, stop_words_ids: List[List[int]], encounters: int = 1):
|
78 |
super().__init__()
|
79 |
self.stop_words_ids = stop_words_ids
|
|
|
80 |
self.encounters = encounters
|
81 |
self.current_encounters = 0
|
82 |
|
@@ -135,7 +136,7 @@ model_loader = GCSModelLoader(bucket)
|
|
135 |
@app.post("/generate")
|
136 |
async def generate(request: GenerateRequest):
|
137 |
model_name = request.model_name
|
138 |
-
input_text = request.input_text
|
139 |
task_type = request.task_type
|
140 |
requested_max_new_tokens = request.max_new_tokens
|
141 |
generation_params = request.model_dump(
|
@@ -155,7 +156,7 @@ async def generate(request: GenerateRequest):
|
|
155 |
|
156 |
if user_defined_stopping_strings:
|
157 |
stop_words_ids = [tokenizer.encode(stop_string, add_special_tokens=False) for stop_string in user_defined_stopping_strings]
|
158 |
-
stopping_criteria_list.append(StopOnKeywords(stop_words_ids))
|
159 |
|
160 |
if config.eos_token_id is not None:
|
161 |
eos_token_ids = [config.eos_token_id]
|
@@ -164,13 +165,13 @@ async def generate(request: GenerateRequest):
|
|
164 |
elif isinstance(config.eos_token_id, list):
|
165 |
eos_token_ids = [[id] for id in config.eos_token_id]
|
166 |
stop_words_ids_eos = [tokenizer.encode(tokenizer.decode(eos_id), add_special_tokens=False) for eos_id in eos_token_ids]
|
167 |
-
stopping_criteria_list.append(StopOnKeywords(stop_words_ids_eos))
|
168 |
elif tokenizer.eos_token is not None:
|
169 |
stop_words_ids_eos = [tokenizer.encode(tokenizer.eos_token, add_special_tokens=False)]
|
170 |
-
stopping_criteria_list.append(StopOnKeywords(stop_words_ids_eos))
|
171 |
|
172 |
async def generate_responses() -> AsyncIterator[Dict[str, List[Dict[str, str]]]]:
|
173 |
-
nonlocal input_text
|
174 |
all_generated_text = ""
|
175 |
stop_reason = None
|
176 |
|
@@ -195,22 +196,25 @@ async def generate(request: GenerateRequest):
|
|
195 |
result = await output_queue.get()
|
196 |
thread.join()
|
197 |
|
198 |
-
newly_generated_text = result[0]['generated_text']
|
199 |
-
|
200 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
201 |
break
|
202 |
|
203 |
all_generated_text += newly_generated_text
|
204 |
yield {"response": [{'generated_text': newly_generated_text}]}
|
205 |
|
206 |
-
if stopping_criteria_list:
|
207 |
-
for criteria in stopping_criteria_list:
|
208 |
-
if isinstance(criteria, StopOnKeywords) and criteria.current_encounters > 0:
|
209 |
-
stop_reason = "stopping_string"
|
210 |
-
break
|
211 |
-
if stop_reason:
|
212 |
-
break
|
213 |
-
|
214 |
if config.eos_token_id is not None:
|
215 |
eos_tokens = [config.eos_token_id]
|
216 |
if isinstance(config.eos_token_id, int):
|
|
|
74 |
return v
|
75 |
|
76 |
class StopOnKeywords(StoppingCriteria):
|
77 |
+
def __init__(self, stop_words_ids: List[List[int]], tokenizer, encounters: int = 1):
|
78 |
super().__init__()
|
79 |
self.stop_words_ids = stop_words_ids
|
80 |
+
self.tokenizer = tokenizer
|
81 |
self.encounters = encounters
|
82 |
self.current_encounters = 0
|
83 |
|
|
|
136 |
@app.post("/generate")
|
137 |
async def generate(request: GenerateRequest):
|
138 |
model_name = request.model_name
|
139 |
+
input_text = request.input_text
|
140 |
task_type = request.task_type
|
141 |
requested_max_new_tokens = request.max_new_tokens
|
142 |
generation_params = request.model_dump(
|
|
|
156 |
|
157 |
if user_defined_stopping_strings:
|
158 |
stop_words_ids = [tokenizer.encode(stop_string, add_special_tokens=False) for stop_string in user_defined_stopping_strings]
|
159 |
+
stopping_criteria_list.append(StopOnKeywords(stop_words_ids, tokenizer)) # Pass tokenizer
|
160 |
|
161 |
if config.eos_token_id is not None:
|
162 |
eos_token_ids = [config.eos_token_id]
|
|
|
165 |
elif isinstance(config.eos_token_id, list):
|
166 |
eos_token_ids = [[id] for id in config.eos_token_id]
|
167 |
stop_words_ids_eos = [tokenizer.encode(tokenizer.decode(eos_id), add_special_tokens=False) for eos_id in eos_token_ids]
|
168 |
+
stopping_criteria_list.append(StopOnKeywords(stop_words_ids_eos, tokenizer)) # Pass tokenizer
|
169 |
elif tokenizer.eos_token is not None:
|
170 |
stop_words_ids_eos = [tokenizer.encode(tokenizer.eos_token, add_special_tokens=False)]
|
171 |
+
stopping_criteria_list.append(StopOnKeywords(stop_words_ids_eos, tokenizer)) # Pass tokenizer
|
172 |
|
173 |
async def generate_responses() -> AsyncIterator[Dict[str, List[Dict[str, str]]]]:
|
174 |
+
nonlocal input_text
|
175 |
all_generated_text = ""
|
176 |
stop_reason = None
|
177 |
|
|
|
196 |
result = await output_queue.get()
|
197 |
thread.join()
|
198 |
|
199 |
+
newly_generated_text = result[0]['generated_text']
|
200 |
+
|
201 |
+
# Decode tokens to check for stopping strings
|
202 |
+
for criteria in stopping_criteria_list:
|
203 |
+
if isinstance(criteria, StopOnKeywords):
|
204 |
+
for stop_ids in criteria.stop_words_ids:
|
205 |
+
decoded_stop_string = tokenizer.decode(stop_ids)
|
206 |
+
if decoded_stop_string in newly_generated_text:
|
207 |
+
stop_reason = f"stopping_string: {decoded_stop_string}"
|
208 |
+
break
|
209 |
+
if stop_reason:
|
210 |
+
break
|
211 |
+
|
212 |
+
if stop_reason:
|
213 |
break
|
214 |
|
215 |
all_generated_text += newly_generated_text
|
216 |
yield {"response": [{'generated_text': newly_generated_text}]}
|
217 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
218 |
if config.eos_token_id is not None:
|
219 |
eos_tokens = [config.eos_token_id]
|
220 |
if isinstance(config.eos_token_id, int):
|