Woziii committed on
Commit
74a6012
1 Parent(s): 82d83fb

Update app.py

Files changed (1):
  1. app.py +14 -29
app.py CHANGED
@@ -66,13 +66,13 @@ def analyze_next_token(input_text, temperature, top_p, top_k):
 
     last_token_logits = outputs.logits[0, -1, :]
     probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
-    top_k = 10
+    top_k = 5
     top_probs, top_indices = torch.topk(probabilities, top_k)
-    top_words = [tokenizer.decode(idx.item()).strip() for idx in top_indices]
+    top_words = [tokenizer.decode([idx.item()]).strip() for idx in top_indices]
     prob_data = {word: prob.item() for word, prob in zip(top_words, top_probs)}
     prob_plot = plot_probabilities(prob_data)
 
-    prob_text = "\n".join([f"{word}: {prob:.2%}" for word, prob in prob_data.items()])
+    prob_text = "\n".join([f"{word}: {prob:.4f}" for word, prob in prob_data.items()])
 
     attention_heatmap = plot_attention_alternative(inputs["input_ids"][0], last_token_logits)
 
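Note on this hunk: decode() is now passed [idx.item()] rather than a bare idx.item(); the tokenizer's decode() is documented to take a sequence of token ids, so the list form matches the documented input rather than relying on the bare-int overload. A minimal end-to-end sketch of the updated logic, assuming a standard transformers causal LM ("gpt2" is used here only as a stand-in, not necessarily the model this Space loads):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# "gpt2" is a stand-in model for illustration only.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The capital of France is", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

last_token_logits = outputs.logits[0, -1, :]
probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
top_probs, top_indices = torch.topk(probabilities, 5)

# decode() takes a sequence of ids, hence the [idx.item()] wrapping.
top_words = [tokenizer.decode([idx.item()]).strip() for idx in top_indices]
print("\n".join(f"{w}: {p:.4f}" for w, p in zip(top_words, top_probs)))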
@@ -107,21 +107,12 @@ def plot_probabilities(prob_data):
     words = list(prob_data.keys())
     probs = list(prob_data.values())
 
-    fig, ax = plt.subplots(figsize=(12, 6))
-    bars = ax.bar(words, probs, color='skyblue')
-    ax.set_title("Probabilités des 10 tokens suivants les plus probables", fontsize=16)
-    ax.set_xlabel("Tokens", fontsize=12)
-    ax.set_ylabel("Probabilité", fontsize=12)
-    plt.xticks(rotation=45, ha='right', fontsize=10)
-    plt.yticks(fontsize=10)
-
-    # Add the percentages above the bars
-    for bar in bars:
-        height = bar.get_height()
-        ax.text(bar.get_x() + bar.get_width()/2., height,
-                f'{height:.2%}',
-                ha='center', va='bottom', fontsize=10)
-
+    fig, ax = plt.subplots(figsize=(10, 5))
+    sns.barplot(x=words, y=probs, ax=ax)
+    ax.set_title("Probabilités des tokens suivants les plus probables")
+    ax.set_xlabel("Tokens")
+    ax.set_ylabel("Probabilité")
+    plt.xticks(rotation=45)
     plt.tight_layout()
     return fig
 
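The plotting helper now delegates to seaborn instead of hand-drawn bars with per-bar percentage labels. A standalone sketch of the simplified version, with made-up probabilities standing in for real model output:

import matplotlib.pyplot as plt
import seaborn as sns

def plot_probabilities(prob_data):
    # Same shape as the committed helper: seaborn styles the bars,
    # and the manual percentage annotations are dropped.
    words = list(prob_data.keys())
    probs = list(prob_data.values())
    fig, ax = plt.subplots(figsize=(10, 5))
    sns.barplot(x=words, y=probs, ax=ax)
    ax.set_title("Probabilités des tokens suivants les plus probables")
    ax.set_xlabel("Tokens")
    ax.set_ylabel("Probabilité")
    plt.xticks(rotation=45)
    plt.tight_layout()
    return fig

# Made-up distribution, purely illustrative.
fig = plot_probabilities({"Paris": 0.62, "Lyon": 0.11, "Marseille": 0.04})
fig.savefig("probabilities.png")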
@@ -131,17 +122,11 @@ def plot_attention_alternative(input_ids, last_token_logits):
     top_k = min(len(input_tokens), 10)  # Limit to 10 tokens for readability
     top_attention_scores, _ = torch.topk(attention_scores, top_k)
 
-    fig, ax = plt.subplots(figsize=(14, 7))
-    sns.heatmap(top_attention_scores.unsqueeze(0).numpy(), annot=True, cmap="YlOrRd", cbar=True, ax=ax, fmt='.2%')
-    ax.set_xticklabels(input_tokens[-top_k:], rotation=45, ha="right", fontsize=10)
-    ax.set_yticklabels(["Attention"], rotation=0, fontsize=10)
-    ax.set_title("Scores d'attention pour les derniers tokens", fontsize=16)
-
-    # Adjust the colorbar
-    cbar = ax.collections[0].colorbar
-    cbar.set_label("Score d'attention", fontsize=12)
-    cbar.ax.tick_params(labelsize=10)
-
+    fig, ax = plt.subplots(figsize=(12, 6))
+    sns.heatmap(top_attention_scores.unsqueeze(0).numpy(), annot=True, cmap="YlOrRd", cbar=False, ax=ax)
+    ax.set_xticklabels(input_tokens[-top_k:], rotation=45, ha="right")
+    ax.set_yticklabels(["Attention"], rotation=0)
+    ax.set_title("Scores d'attention pour les derniers tokens")
     plt.tight_layout()
     return fig
 
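For completeness, the heatmap call can be exercised in isolation; a minimal sketch follows, with a made-up score tensor since the start of plot_attention_alternative is outside this diff. One caveat worth flagging in review: torch.topk returns values sorted in descending order, so the x tick labels taken from input_tokens[-top_k:] are not guaranteed to line up with the reordered scores; the sketch reproduces the committed call unchanged.

import torch
import matplotlib.pyplot as plt
import seaborn as sns

# Made-up tokens and scores standing in for attention_scores, whose
# construction is outside this diff.
input_tokens = ["Le", "chat", "dort", "sur", "le", "tapis"]
attention_scores = torch.tensor([0.05, 0.30, 0.10, 0.05, 0.15, 0.35])

top_k = min(len(input_tokens), 10)  # limit to 10 tokens for readability
# Caveat kept from the committed code: topk() sorts descending, so these
# values no longer align with input_tokens[-top_k:] one-for-one.
top_attention_scores, _ = torch.topk(attention_scores, top_k)

fig, ax = plt.subplots(figsize=(12, 6))
sns.heatmap(top_attention_scores.unsqueeze(0).numpy(), annot=True,
            cmap="YlOrRd", cbar=False, ax=ax)
ax.set_xticklabels(input_tokens[-top_k:], rotation=45, ha="right")
ax.set_yticklabels(["Attention"], rotation=0)
ax.set_title("Scores d'attention pour les derniers tokens")
plt.tight_layout()
fig.savefig("attention.png")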