Woziii commited on
Commit
bc7e16f
1 Parent(s): 3467291

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -30
app.py CHANGED
@@ -62,28 +62,21 @@ def analyze_next_token(input_text, temperature, top_p, top_k):
62
 
63
  try:
64
  with torch.no_grad():
65
- outputs = model(**inputs, output_attentions=True)
66
 
67
  last_token_logits = outputs.logits[0, -1, :]
68
  probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
69
-
70
- # Obtenir les 10 tokens les plus probables
71
  top_k = 10
72
  top_probs, top_indices = torch.topk(probabilities, top_k)
73
  top_words = [tokenizer.decode([idx.item()]).strip() for idx in top_indices]
74
  prob_data = {word: prob.item() for word, prob in zip(top_words, top_probs)}
 
75
 
76
- # Créer un texte explicatif
77
- prob_text = "Prochains tokens les plus probables :\n\n"
78
- for word, prob in prob_data.items():
79
- escaped_word = word.replace("<", "&lt;").replace(">", "&gt;")
80
- prob_text += f"{escaped_word}: {prob:.2%}\n"
81
 
82
- # Créer les visualisations
83
- prob_plot = plot_probabilities(prob_data)
84
- attention_plot = plot_attention(inputs["input_ids"][0], outputs.attentions)
85
 
86
- return prob_text, attention_plot, prob_plot
87
  except Exception as e:
88
  return f"Erreur lors de l'analyse : {str(e)}", None, None
89
 
@@ -115,32 +108,40 @@ def plot_probabilities(prob_data):
115
  probs = list(prob_data.values())
116
 
117
  fig, ax = plt.subplots(figsize=(12, 6))
118
- bars = ax.bar(range(len(words)), probs, color='lightgreen')
119
- ax.set_title("Probabilités des tokens suivants les plus probables")
120
- ax.set_xlabel("Tokens")
121
- ax.set_ylabel("Probabilité")
122
-
123
- ax.set_xticks(range(len(words)))
124
- ax.set_xticklabels(words, rotation=45, ha='right')
125
-
126
- for i, (bar, word) in enumerate(zip(bars, words)):
127
  height = bar.get_height()
128
- ax.text(i, height, f'{word}\n{height:.2%}',
129
- ha='center', va='bottom', rotation=0)
 
130
 
131
  plt.tight_layout()
132
  return fig
133
 
134
- def plot_attention(input_ids, attention_outputs):
135
  input_tokens = tokenizer.convert_ids_to_tokens(input_ids)
 
 
 
136
 
137
- # Prendre la moyenne des attentions sur toutes les couches et têtes
138
- attention = torch.mean(torch.cat(attention_outputs), dim=(0, 1)).cpu().numpy()
 
 
 
139
 
140
- fig, ax = plt.subplots(figsize=(12, 10))
141
- sns.heatmap(attention, annot=True, cmap="YlOrRd", xticklabels=input_tokens, yticklabels=input_tokens, ax=ax)
 
 
142
 
143
- ax.set_title("Carte d'attention moyenne")
144
  plt.tight_layout()
145
  return fig
146
 
@@ -151,7 +152,7 @@ def reset():
151
  return "", 1.0, 1.0, 50, None, None, None, None
152
 
153
  with gr.Blocks() as demo:
154
- gr.Markdown("# LLM & Bias")
155
 
156
  with gr.Accordion("Sélection du modèle"):
157
  model_dropdown = gr.Dropdown(choices=models, label="Choisissez un modèle")
 
62
 
63
  try:
64
  with torch.no_grad():
65
+ outputs = model(**inputs)
66
 
67
  last_token_logits = outputs.logits[0, -1, :]
68
  probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
 
 
69
  top_k = 10
70
  top_probs, top_indices = torch.topk(probabilities, top_k)
71
  top_words = [tokenizer.decode([idx.item()]).strip() for idx in top_indices]
72
  prob_data = {word: prob.item() for word, prob in zip(top_words, top_probs)}
73
+ prob_plot = plot_probabilities(prob_data)
74
 
75
+ prob_text = "\n".join([f"{word}: {prob:.2%}" for word, prob in prob_data.items()])
 
 
 
 
76
 
77
+ attention_heatmap = plot_attention_alternative(inputs["input_ids"][0], last_token_logits)
 
 
78
 
79
+ return prob_text, attention_heatmap, prob_plot
80
  except Exception as e:
81
  return f"Erreur lors de l'analyse : {str(e)}", None, None
82
 
 
108
  probs = list(prob_data.values())
109
 
110
  fig, ax = plt.subplots(figsize=(12, 6))
111
+ bars = ax.bar(words, probs, color='skyblue')
112
+ ax.set_title("Probabilités des 10 tokens suivants les plus probables", fontsize=16)
113
+ ax.set_xlabel("Tokens", fontsize=12)
114
+ ax.set_ylabel("Probabilité", fontsize=12)
115
+ plt.xticks(rotation=45, ha='right', fontsize=10)
116
+ plt.yticks(fontsize=10)
117
+
118
+ # Ajouter les pourcentages au-dessus des barres
119
+ for bar in bars:
120
  height = bar.get_height()
121
+ ax.text(bar.get_x() + bar.get_width()/2., height,
122
+ f'{height:.2%}',
123
+ ha='center', va='bottom', fontsize=10)
124
 
125
  plt.tight_layout()
126
  return fig
127
 
128
+ def plot_attention_alternative(input_ids, last_token_logits):
129
  input_tokens = tokenizer.convert_ids_to_tokens(input_ids)
130
+ attention_scores = torch.nn.functional.softmax(last_token_logits, dim=-1)
131
+ top_k = min(len(input_tokens), 10) # Limiter à 10 tokens pour la lisibilité
132
+ top_attention_scores, _ = torch.topk(attention_scores, top_k)
133
 
134
+ fig, ax = plt.subplots(figsize=(14, 7))
135
+ sns.heatmap(top_attention_scores.unsqueeze(0).numpy(), annot=True, cmap="YlOrRd", cbar=True, ax=ax, fmt='.2%')
136
+ ax.set_xticklabels(input_tokens[-top_k:], rotation=45, ha="right", fontsize=10)
137
+ ax.set_yticklabels(["Attention"], rotation=0, fontsize=10)
138
+ ax.set_title("Scores d'attention pour les derniers tokens", fontsize=16)
139
 
140
+ # Ajuster la colorbar
141
+ cbar = ax.collections[0].colorbar
142
+ cbar.set_label("Score d'attention", fontsize=12)
143
+ cbar.ax.tick_params(labelsize=10)
144
 
 
145
  plt.tight_layout()
146
  return fig
147
 
 
152
  return "", 1.0, 1.0, 50, None, None, None, None
153
 
154
  with gr.Blocks() as demo:
155
+ gr.Markdown("# Analyse et génération de texte")
156
 
157
  with gr.Accordion("Sélection du modèle"):
158
  model_dropdown = gr.Dropdown(choices=models, label="Choisissez un modèle")