Carlos Rosas commited on
Commit
52e369f
·
verified ·
1 Parent(s): fcce9ea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -14
app.py CHANGED
@@ -75,9 +75,6 @@ class CassandreChatBot:
75
  attention_mask = torch.ones_like(input_ids)
76
 
77
  try:
78
- # Add some debug prints
79
- print("Input length:", len(input_ids[0]))
80
-
81
  output = model.generate(
82
  input_ids,
83
  attention_mask=attention_mask,
@@ -88,18 +85,11 @@ class CassandreChatBot:
88
  temperature=temperature,
89
  repetition_penalty=repetition_penalty,
90
  pad_token_id=tokenizer.pad_token_id,
91
- eos_token_id=tokenizer.eos_token_id,
92
- # Add return_dict_in_generate=True to see full output info
93
- return_dict_in_generate=True,
94
- output_scores=True
95
  )
96
-
97
- # Print debug info about output
98
- print("Output sequence length:", len(output.sequences[0]))
99
- print("New tokens generated:", len(output.sequences[0]) - len(input_ids[0]))
100
-
101
- # Try decoding only the new tokens
102
- generated_text = tokenizer.decode(output.sequences[0][len(input_ids[0]):])
103
 
104
  generated_text = '<h2 style="text-align:center">Réponse</h3>\n<div class="generation">' + format_references(generated_text) + "</div>"
105
  fiches_html = '<h2 style="text-align:center">Sources</h3>\n' + fiches_html
 
75
  attention_mask = torch.ones_like(input_ids)
76
 
77
  try:
 
 
 
78
  output = model.generate(
79
  input_ids,
80
  attention_mask=attention_mask,
 
85
  temperature=temperature,
86
  repetition_penalty=repetition_penalty,
87
  pad_token_id=tokenizer.pad_token_id,
88
+ eos_token_id=tokenizer.eos_token_id
 
 
 
89
  )
90
+
91
+ # Only decode the new tokens by slicing from the input length
92
+ generated_text = tokenizer.decode(output[0][len(input_ids[0]):])
 
 
 
 
93
 
94
  generated_text = '<h2 style="text-align:center">Réponse</h3>\n<div class="generation">' + format_references(generated_text) + "</div>"
95
  fiches_html = '<h2 style="text-align:center">Sources</h3>\n' + fiches_html