Woziii committed on
Commit
63afc3f
1 Parent(s): 3226776

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -28
app.py CHANGED
@@ -62,17 +62,18 @@ def analyze_next_token(input_text, temperature, top_p, top_k):
62
  probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
63
  top_k = 5
64
  top_probs, top_indices = torch.topk(probabilities, top_k)
65
- top_words = [tokenizer.decode([idx.item()]) for idx in top_indices]
66
  prob_data = {word: prob.item() for word, prob in zip(top_words, top_probs)}
67
  prob_plot = plot_probabilities(prob_data)
68
 
 
 
 
 
69
  if hasattr(outputs, 'attentions') and outputs.attentions is not None:
70
- attention_data = torch.mean(torch.stack(outputs.attentions), dim=(0, 1)).cpu().numpy()
71
- attention_plot = plot_attention(attention_data, tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]))
72
- else:
73
- attention_plot = None
74
 
75
- return "\n".join([f"{word}: {prob:.4f}" for word, prob in prob_data.items()]), attention_plot, prob_plot
76
  except Exception as e:
77
  return f"Erreur lors de l'analyse : {str(e)}", None, None
78
 
@@ -88,29 +89,19 @@ def generate_text(input_text, temperature, top_p, top_k):
88
  with torch.no_grad():
89
  outputs = model.generate(
90
  **inputs,
91
- max_new_tokens=50,
92
  temperature=temperature,
93
  top_p=top_p,
94
  top_k=top_k
95
  )
96
 
97
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
98
- return generated_text
 
 
99
  except Exception as e:
100
  return f"Erreur lors de la génération : {str(e)}"
101
 
102
- def plot_attention(attention, tokens):
103
- fig, ax = plt.subplots(figsize=(10, 10))
104
- im = ax.imshow(attention, cmap='viridis')
105
- ax.set_xticks(range(len(tokens)))
106
- ax.set_yticks(range(len(tokens)))
107
- ax.set_xticklabels(tokens, rotation=90)
108
- ax.set_yticklabels(tokens)
109
- plt.colorbar(im)
110
- plt.title("Carte d'attention")
111
- plt.tight_layout()
112
- return fig
113
-
114
  def plot_probabilities(prob_data):
115
  words = list(prob_data.keys())
116
  probs = list(prob_data.values())
@@ -145,27 +136,26 @@ with gr.Blocks() as demo:
145
 
146
  input_text = gr.Textbox(label="Texte d'entrée", lines=3)
147
  analyze_button = gr.Button("Analyser le prochain token")
148
- generate_button = gr.Button("Générer la suite du texte")
149
 
150
  next_token_probs = gr.Textbox(label="Probabilités du prochain token")
 
151
 
152
- with gr.Row():
153
- attention_plot = gr.Plot(label="Visualisation de l'attention")
154
- prob_plot = gr.Plot(label="Probabilités des tokens suivants")
155
 
156
- generated_text = gr.Textbox(label="Texte généré", lines=5)
 
157
 
158
  reset_button = gr.Button("Réinitialiser")
159
 
160
  load_button.click(load_model, inputs=[model_dropdown], outputs=[load_output])
161
  analyze_button.click(analyze_next_token,
162
  inputs=[input_text, temperature, top_p, top_k],
163
- outputs=[next_token_probs, attention_plot, prob_plot])
164
  generate_button.click(generate_text,
165
  inputs=[input_text, temperature, top_p, top_k],
166
- outputs=[generated_text])
167
  reset_button.click(reset,
168
- outputs=[input_text, temperature, top_p, top_k, next_token_probs, attention_plot, prob_plot, generated_text])
169
 
170
  if __name__ == "__main__":
171
  demo.launch()
 
62
  probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
63
  top_k = 5
64
  top_probs, top_indices = torch.topk(probabilities, top_k)
65
+ top_words = [tokenizer.decode([idx.item()]).strip() for idx in top_indices]
66
  prob_data = {word: prob.item() for word, prob in zip(top_words, top_probs)}
67
  prob_plot = plot_probabilities(prob_data)
68
 
69
+ prob_text = "\n".join([f"{word}: {prob:.4f}" for word, prob in prob_data.items()])
70
+
71
+ # Simplification de l'affichage de l'attention
72
+ attention_text = "Attention non disponible pour ce modèle"
73
  if hasattr(outputs, 'attentions') and outputs.attentions is not None:
74
+ attention_text = "Attention disponible"
 
 
 
75
 
76
+ return prob_text, attention_text, prob_plot
77
  except Exception as e:
78
  return f"Erreur lors de l'analyse : {str(e)}", None, None
79
 
 
89
  with torch.no_grad():
90
  outputs = model.generate(
91
  **inputs,
92
+ max_new_tokens=1, # Génère seulement le prochain mot
93
  temperature=temperature,
94
  top_p=top_p,
95
  top_k=top_k
96
  )
97
 
98
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
99
+ # Ne retourne que le nouveau mot généré
100
+ new_word = generated_text[len(input_text):].strip()
101
+ return new_word
102
  except Exception as e:
103
  return f"Erreur lors de la génération : {str(e)}"
104
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  def plot_probabilities(prob_data):
106
  words = list(prob_data.keys())
107
  probs = list(prob_data.values())
 
136
 
137
  input_text = gr.Textbox(label="Texte d'entrée", lines=3)
138
  analyze_button = gr.Button("Analyser le prochain token")
 
139
 
140
  next_token_probs = gr.Textbox(label="Probabilités du prochain token")
141
+ attention_info = gr.Textbox(label="Information sur l'attention")
142
 
143
+ prob_plot = gr.Plot(label="Probabilités des tokens suivants")
 
 
144
 
145
+ generate_button = gr.Button("Générer le prochain mot")
146
+ generated_word = gr.Textbox(label="Mot généré")
147
 
148
  reset_button = gr.Button("Réinitialiser")
149
 
150
  load_button.click(load_model, inputs=[model_dropdown], outputs=[load_output])
151
  analyze_button.click(analyze_next_token,
152
  inputs=[input_text, temperature, top_p, top_k],
153
+ outputs=[next_token_probs, attention_info, prob_plot])
154
  generate_button.click(generate_text,
155
  inputs=[input_text, temperature, top_p, top_k],
156
+ outputs=[generated_word])
157
  reset_button.click(reset,
158
+ outputs=[input_text, temperature, top_p, top_k, next_token_probs, attention_info, prob_plot, generated_word])
159
 
160
  if __name__ == "__main__":
161
  demo.launch()