mateoluksenberg committed
Commit 3954ce3 · verified · 1 Parent(s): a025e11

Update app.py

Files changed (1):
  1. app.py +39 -42
app.py CHANGED
@@ -212,49 +212,46 @@ EXAMPLES = [
 # Define the simple_chat function
 @spaces.GPU()
 def simple_chat(message: dict, temperature: float = 0.8, max_length: int = 4096, top_p: float = 1, top_k: int = 10, penalty: float = 1.0):
-    # Load the pretrained model
-    model = AutoModelForCausalLM.from_pretrained(
-        MODEL_ID,
-        torch_dtype=torch.bfloat16,
-        low_cpu_mem_usage=True,
-        trust_remote_code=True
-    )
-
-    conversation = []
-
-    if "file" in message and message["file"]:
-        file_path = message["file"]
-        choice, contents = mode_load(file_path)
-        if choice == "image":
-            conversation.append({"role": "user", "image": contents, "content": message["text"]})
-        elif choice == "doc":
-            format_msg = contents + "\n\n\n" + "{} files uploaded.\n" + message["text"]
-            conversation.append({"role": "user", "content": format_msg})
-    else:
-        conversation.append({"role": "user", "content": message["text"]})
-
-    # Prepare the input for the model
-    input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True,
-                                              return_tensors="pt", return_dict=True).to(model.device)
-
-    # Configure the generation parameters
-    generate_kwargs = dict(
-        max_length=max_length,
-        do_sample=True,
-        top_p=top_p,
-        top_k=top_k,
-        temperature=temperature,
-        repetition_penalty=penalty,
-        eos_token_id=[151329, 151336, 151338],
-    )
-
-    # Generate the response
-    with torch.no_grad():
-        generated_ids = model.generate(input_ids['input_ids'], **generate_kwargs)
-        generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+    try:
+        model = AutoModelForCausalLM.from_pretrained(
+            MODEL_ID,
+            torch_dtype=torch.bfloat16,
+            low_cpu_mem_usage=True,
+            trust_remote_code=True
+        )
+
+        conversation = []
+
+        if "file" in message and message["file"]:
+            file_path = message["file"]
+            choice, contents = mode_load(file_path)
+            if choice == "image":
+                conversation.append({"role": "user", "image": contents, "content": message["text"]})
+            elif choice == "doc":
+                format_msg = contents + "\n\n\n" + "{} files uploaded.\n" + message["text"]
+                conversation.append({"role": "user", "content": format_msg})
+        else:
+            conversation.append({"role": "user", "content": message["text"]})
 
-    # Return the full response
-    return PlainTextResponse(generated_text)
+        input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
+
+        generate_kwargs = dict(
+            max_length=max_length,
+            do_sample=True,
+            top_p=top_p,
+            top_k=top_k,
+            temperature=temperature,
+            repetition_penalty=penalty,
+            eos_token_id=[151329, 151336, 151338],
+        )
+
+        with torch.no_grad():
+            generated_ids = model.generate(input_ids['input_ids'], **generate_kwargs)
+            generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+
+        return PlainTextResponse(generated_text)
+    except Exception as e:
+        return PlainTextResponse(f"Error: {str(e)}")
 
 @app.post("/chat/")
 async def test_endpoint(message: dict):
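
A note for readers of this hunk: mode_load is defined elsewhere in app.py, so only its call site is visible here. The stub below is a hypothetical sketch of the contract the code above relies on, a (choice, contents) tuple where choice is "image" or "doc"; the file-type detection is assumed for illustration and the real helper may differ.

# Hypothetical stub (not from app.py) showing the (choice, contents)
# contract that simple_chat expects from mode_load.
from pathlib import Path

from PIL import Image

def mode_load_stub(file_path: str):
    suffix = Path(file_path).suffix.lower()
    if suffix in {".png", ".jpg", ".jpeg", ".webp"}:
        # Image branch: simple_chat attaches the object under the "image" key
        return "image", Image.open(file_path).convert("RGB")
    # Document branch: simple_chat prepends the text to the user message
    return "doc", Path(file_path).read_text(errors="ignore")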
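
One subtlety shared by both versions: apply_chat_template(..., return_dict=True) returns a BatchEncoding, a dict-like object rather than a tensor, which is why the variable named input_ids is indexed as input_ids['input_ids'] before model.generate. A minimal standalone illustration, assuming MODEL_ID names the chat-capable checkpoint configured earlier in app.py:

# Sketch only: MODEL_ID is assumed to be the checkpoint used by this Space.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
batch = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Hello"}],
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
)
print(list(batch.keys()))  # typically ['input_ids', 'attention_mask']
ids = batch["input_ids"]   # the tensor simple_chat passes to model.generate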
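
Once the Space is running, the /chat/ route shown at the end of the hunk can be exercised with a plain HTTP request. The sketch below is illustrative only: the body of test_endpoint is not part of this hunk, so forwarding to simple_chat is an assumption, and BASE_URL is a hypothetical local address; the "text" and "file" keys mirror the message dict that simple_chat reads.

# Illustrative client call; BASE_URL is hypothetical, and test_endpoint's
# body is not shown in this hunk, so forwarding to simple_chat is assumed.
import requests

BASE_URL = "http://localhost:7860"

# An empty "file" is falsy, so simple_chat takes the plain-text branch.
payload = {"text": "Hello, what can you do?", "file": ""}
resp = requests.post(f"{BASE_URL}/chat/", json=payload, timeout=120)

# After this commit, failures inside simple_chat come back as "Error: ..."
# text in the response body instead of an unhandled exception.
print(resp.status_code, resp.text)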