Carlos Rosas commited on
Commit
fcce9ea
·
verified ·
1 Parent(s): e33a35c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -4
app.py CHANGED
@@ -75,6 +75,9 @@ class CassandreChatBot:
75
  attention_mask = torch.ones_like(input_ids)
76
 
77
  try:
 
 
 
78
  output = model.generate(
79
  input_ids,
80
  attention_mask=attention_mask,
@@ -85,11 +88,19 @@ class CassandreChatBot:
85
  temperature=temperature,
86
  repetition_penalty=repetition_penalty,
87
  pad_token_id=tokenizer.pad_token_id,
88
- eos_token_id=tokenizer.eos_token_id
 
 
 
89
  )
90
-
91
- # Just decode and display the generated text directly
92
- generated_text = tokenizer.decode(output[0])
 
 
 
 
 
93
  generated_text = '<h2 style="text-align:center">Réponse</h3>\n<div class="generation">' + format_references(generated_text) + "</div>"
94
  fiches_html = '<h2 style="text-align:center">Sources</h3>\n' + fiches_html
95
  return generated_text, fiches_html
 
75
  attention_mask = torch.ones_like(input_ids)
76
 
77
  try:
78
+ # Add some debug prints
79
+ print("Input length:", len(input_ids[0]))
80
+
81
  output = model.generate(
82
  input_ids,
83
  attention_mask=attention_mask,
 
88
  temperature=temperature,
89
  repetition_penalty=repetition_penalty,
90
  pad_token_id=tokenizer.pad_token_id,
91
+ eos_token_id=tokenizer.eos_token_id,
92
+ # Add return_dict_in_generate=True to see full output info
93
+ return_dict_in_generate=True,
94
+ output_scores=True
95
  )
96
+
97
+ # Print debug info about output
98
+ print("Output sequence length:", len(output.sequences[0]))
99
+ print("New tokens generated:", len(output.sequences[0]) - len(input_ids[0]))
100
+
101
+ # Try decoding only the new tokens
102
+ generated_text = tokenizer.decode(output.sequences[0][len(input_ids[0]):])
103
+
104
  generated_text = '<h2 style="text-align:center">Réponse</h3>\n<div class="generation">' + format_references(generated_text) + "</div>"
105
  fiches_html = '<h2 style="text-align:center">Sources</h3>\n' + fiches_html
106
  return generated_text, fiches_html