Woziii committed (verified)
Commit e1ef0ab · 1 Parent(s): 391d3d3

Update app.py

Files changed (1)
  1. app.py +53 -21
app.py CHANGED
@@ -66,17 +66,23 @@ def analyze_next_token(input_text, temperature, top_p, top_k):
 
         last_token_logits = outputs.logits[0, -1, :]
         probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
-        top_k = 5
+
+        # Get the 10 most probable tokens
+        top_k = 10
         top_probs, top_indices = torch.topk(probabilities, top_k)
         top_words = [tokenizer.decode([idx.item()]).strip() for idx in top_indices]
         prob_data = {word: prob.item() for word, prob in zip(top_words, top_probs)}
-        prob_plot = plot_probabilities(prob_data)
 
-        prob_text = "\n".join([f"{word}: {prob:.4f}" for word, prob in prob_data.items()])
+        # Build an explanatory text
+        prob_text = "Prochains tokens les plus probables :\n\n"
+        for word, prob in prob_data.items():
+            prob_text += f"{word}: {prob:.2%}\n"
 
-        attention_heatmap = plot_attention_alternative(inputs["input_ids"][0], last_token_logits)
+        # Create the visualizations
+        prob_plot = plot_probabilities(prob_data)
+        importance_plot = plot_token_importance(inputs["input_ids"][0], last_token_logits)
 
-        return prob_text, attention_heatmap, prob_plot
+        return prob_text, importance_plot, prob_plot
     except Exception as e:
         return f"Erreur lors de l'analyse : {str(e)}", None, None
 
@@ -108,25 +114,51 @@ def plot_probabilities(prob_data):
     probs = list(prob_data.values())
 
     fig, ax = plt.subplots(figsize=(10, 5))
-    sns.barplot(x=words, y=probs, ax=ax)
+    bars = ax.bar(words, probs, color='lightgreen')
     ax.set_title("Probabilités des tokens suivants les plus probables")
     ax.set_xlabel("Tokens")
     ax.set_ylabel("Probabilité")
-    plt.xticks(rotation=45)
+    plt.xticks(rotation=45, ha='right')
+
+    # Add the values on top of the bars
+    for bar in bars:
+        height = bar.get_height()
+        ax.text(bar.get_x() + bar.get_width()/2., height,
+                f'{height:.2%}',
+                ha='center', va='bottom')
+
     plt.tight_layout()
     return fig
 
-def plot_attention_alternative(input_ids, last_token_logits):
+def plot_token_importance(input_ids, last_token_logits):
     input_tokens = tokenizer.convert_ids_to_tokens(input_ids)
-    attention_scores = torch.nn.functional.softmax(last_token_logits, dim=-1)
-    top_k = min(len(input_tokens), 10)  # Limit to 10 tokens for readability
-    top_attention_scores, _ = torch.topk(attention_scores, top_k)
-
-    fig, ax = plt.subplots(figsize=(12, 6))
-    sns.heatmap(top_attention_scores.unsqueeze(0).numpy(), annot=True, cmap="YlOrRd", cbar=False, ax=ax)
-    ax.set_xticklabels(input_tokens[-top_k:], rotation=45, ha="right")
-    ax.set_yticklabels(["Attention"], rotation=0)
-    ax.set_title("Scores d'attention pour les derniers tokens")
+
+    # Compute an importance score for each token
+    importances = torch.abs(last_token_logits).sum() / len(last_token_logits)
+    importances = importances.repeat(len(input_tokens))
+
+    # Normalize the importances
+    importances = importances / importances.sum()
+
+    # Create the figure
+    fig, ax = plt.subplots(figsize=(12, 3))
+
+    # Draw a bar chart
+    bars = ax.bar(range(len(input_tokens)), importances, color='skyblue')
+
+    # Add labels and title
+    ax.set_xticks(range(len(input_tokens)))
+    ax.set_xticklabels(input_tokens, rotation=45, ha='right')
+    ax.set_ylabel('Importance relative')
+    ax.set_title('Importance des tokens d\'entrée pour la prédiction')
+
+    # Add the values on top of the bars
+    for bar in bars:
+        height = bar.get_height()
+        ax.text(bar.get_x() + bar.get_width()/2., height,
+                f'{height:.2%}',
+                ha='center', va='bottom')
+
     plt.tight_layout()
     return fig
 
@@ -137,7 +169,7 @@ def reset():
     return "", 1.0, 1.0, 50, None, None, None, None
 
 with gr.Blocks() as demo:
-    gr.Markdown("# Analyse et génération de texte")
+    gr.Markdown("# LLM & Bias ")
 
     with gr.Accordion("Sélection du modèle"):
         model_dropdown = gr.Dropdown(choices=models, label="Choisissez un modèle")
@@ -155,7 +187,7 @@ with gr.Blocks() as demo:
         next_token_probs = gr.Textbox(label="Probabilités du prochain token")
 
     with gr.Row():
-        attention_plot = gr.Plot(label="Visualisation de l'attention")
+        importance_plot = gr.Plot(label="Importance des tokens d'entrée")
         prob_plot = gr.Plot(label="Probabilités des tokens suivants")
 
     generate_button = gr.Button("Générer le prochain mot")
@@ -166,12 +198,12 @@ with gr.Blocks() as demo:
     load_button.click(load_model, inputs=[model_dropdown], outputs=[load_output])
     analyze_button.click(analyze_next_token,
                          inputs=[input_text, temperature, top_p, top_k],
-                         outputs=[next_token_probs, attention_plot, prob_plot])
+                         outputs=[next_token_probs, importance_plot, prob_plot])
     generate_button.click(generate_text,
                          inputs=[input_text, temperature, top_p, top_k],
                          outputs=[generated_text])
     reset_button.click(reset,
-                       outputs=[input_text, temperature, top_p, top_k, next_token_probs, attention_plot, prob_plot, generated_text])
+                       outputs=[input_text, temperature, top_p, top_k, next_token_probs, importance_plot, prob_plot, generated_text])
 
 if __name__ == "__main__":
     demo.launch()
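
Reviewer's note: analyze_next_token already receives top_k from the UI slider, and the new hardcoded top_k = 10 (like the old top_k = 5) immediately overwrites it, so the slider value never reaches the analysis. A minimal sketch of one way to honor the slider while still capping the number of bars for readability; the cap of 10 is an arbitrary illustration, not part of the commit:

        # Hypothetical variant: respect the slider value, cap the display size
        display_k = min(int(top_k), 10)  # the cap of 10 is illustrative only
        top_probs, top_indices = torch.topk(probabilities, display_k)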
 
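Reviewer's note: the ax.text loop that writes a percentage above each bar duplicates what Matplotlib's Axes.bar_label helper (available since Matplotlib 3.4) does in one call. A sketch of the equivalent, assuming the same bars container and probs list as in plot_probabilities:

    # Equivalent to the manual ax.text loop, Matplotlib >= 3.4 only
    ax.bar_label(bars, labels=[f'{p:.2%}' for p in probs], padding=2)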
 
 
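Reviewer's note: as committed, plot_token_importance reduces last_token_logits to a single scalar (torch.abs(last_token_logits).sum() / len(last_token_logits)) and repeats it for every input token, so after normalization every bar shows the same value, 1/len(input_tokens), and the chart cannot distinguish tokens. If genuinely per-token scores are wanted, one common substitute is embedding-gradient saliency: the norm of the gradient of the top next-token logit with respect to each input token's embedding. A minimal sketch, assuming the module-level model/tokenizer globals hold a standard Hugging Face causal LM; the function name and details are illustrative, not from the commit:

import torch

def token_saliency(input_text):
    """Hypothetical per-token importance via embedding-gradient saliency."""
    inputs = tokenizer(input_text, return_tensors="pt")
    # Detach the input embeddings into a leaf tensor so .grad gets populated
    embeds = model.get_input_embeddings()(inputs["input_ids"]).detach().requires_grad_(True)
    outputs = model(inputs_embeds=embeds, attention_mask=inputs["attention_mask"])
    # Gradient of the strongest next-token logit w.r.t. the input embeddings
    outputs.logits[0, -1, :].max().backward()
    scores = embeds.grad[0].norm(dim=-1)  # one L2 norm per input token
    return scores / scores.sum()          # normalized like the committed plot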
 
 