Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -66,17 +66,23 @@ def analyze_next_token(input_text, temperature, top_p, top_k):
|
|
66 |
|
67 |
last_token_logits = outputs.logits[0, -1, :]
|
68 |
probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
|
69 |
-
|
|
|
|
|
70 |
top_probs, top_indices = torch.topk(probabilities, top_k)
|
71 |
top_words = [tokenizer.decode([idx.item()]).strip() for idx in top_indices]
|
72 |
prob_data = {word: prob.item() for word, prob in zip(top_words, top_probs)}
|
73 |
-
prob_plot = plot_probabilities(prob_data)
|
74 |
|
75 |
-
|
|
|
|
|
|
|
76 |
|
77 |
-
|
|
|
|
|
78 |
|
79 |
-
return prob_text,
|
80 |
except Exception as e:
|
81 |
return f"Erreur lors de l'analyse : {str(e)}", None, None
|
82 |
|
@@ -108,25 +114,51 @@ def plot_probabilities(prob_data):
|
|
108 |
probs = list(prob_data.values())
|
109 |
|
110 |
fig, ax = plt.subplots(figsize=(10, 5))
|
111 |
-
|
112 |
ax.set_title("Probabilités des tokens suivants les plus probables")
|
113 |
ax.set_xlabel("Tokens")
|
114 |
ax.set_ylabel("Probabilité")
|
115 |
-
plt.xticks(rotation=45)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
plt.tight_layout()
|
117 |
return fig
|
118 |
|
119 |
-
def
|
120 |
input_tokens = tokenizer.convert_ids_to_tokens(input_ids)
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
plt.tight_layout()
|
131 |
return fig
|
132 |
|
@@ -137,7 +169,7 @@ def reset():
|
|
137 |
return "", 1.0, 1.0, 50, None, None, None, None
|
138 |
|
139 |
with gr.Blocks() as demo:
|
140 |
-
gr.Markdown("#
|
141 |
|
142 |
with gr.Accordion("Sélection du modèle"):
|
143 |
model_dropdown = gr.Dropdown(choices=models, label="Choisissez un modèle")
|
@@ -155,7 +187,7 @@ with gr.Blocks() as demo:
|
|
155 |
next_token_probs = gr.Textbox(label="Probabilités du prochain token")
|
156 |
|
157 |
with gr.Row():
|
158 |
-
|
159 |
prob_plot = gr.Plot(label="Probabilités des tokens suivants")
|
160 |
|
161 |
generate_button = gr.Button("Générer le prochain mot")
|
@@ -166,12 +198,12 @@ with gr.Blocks() as demo:
|
|
166 |
load_button.click(load_model, inputs=[model_dropdown], outputs=[load_output])
|
167 |
analyze_button.click(analyze_next_token,
|
168 |
inputs=[input_text, temperature, top_p, top_k],
|
169 |
-
outputs=[next_token_probs,
|
170 |
generate_button.click(generate_text,
|
171 |
inputs=[input_text, temperature, top_p, top_k],
|
172 |
outputs=[generated_text])
|
173 |
reset_button.click(reset,
|
174 |
-
outputs=[input_text, temperature, top_p, top_k, next_token_probs,
|
175 |
|
176 |
if __name__ == "__main__":
|
177 |
demo.launch()
|
|
|
66 |
|
67 |
last_token_logits = outputs.logits[0, -1, :]
|
68 |
probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
|
69 |
+
|
70 |
+
# Obtenir les 10 tokens les plus probables
|
71 |
+
top_k = 10
|
72 |
top_probs, top_indices = torch.topk(probabilities, top_k)
|
73 |
top_words = [tokenizer.decode([idx.item()]).strip() for idx in top_indices]
|
74 |
prob_data = {word: prob.item() for word, prob in zip(top_words, top_probs)}
|
|
|
75 |
|
76 |
+
# Créer un texte explicatif
|
77 |
+
prob_text = "Prochains tokens les plus probables :\n\n"
|
78 |
+
for word, prob in prob_data.items():
|
79 |
+
prob_text += f"{word}: {prob:.2%}\n"
|
80 |
|
81 |
+
# Créer les visualisations
|
82 |
+
prob_plot = plot_probabilities(prob_data)
|
83 |
+
importance_plot = plot_token_importance(inputs["input_ids"][0], last_token_logits)
|
84 |
|
85 |
+
return prob_text, importance_plot, prob_plot
|
86 |
except Exception as e:
|
87 |
return f"Erreur lors de l'analyse : {str(e)}", None, None
|
88 |
|
|
|
114 |
probs = list(prob_data.values())
|
115 |
|
116 |
fig, ax = plt.subplots(figsize=(10, 5))
|
117 |
+
bars = ax.bar(words, probs, color='lightgreen')
|
118 |
ax.set_title("Probabilités des tokens suivants les plus probables")
|
119 |
ax.set_xlabel("Tokens")
|
120 |
ax.set_ylabel("Probabilité")
|
121 |
+
plt.xticks(rotation=45, ha='right')
|
122 |
+
|
123 |
+
# Ajouter les valeurs sur les barres
|
124 |
+
for bar in bars:
|
125 |
+
height = bar.get_height()
|
126 |
+
ax.text(bar.get_x() + bar.get_width()/2., height,
|
127 |
+
f'{height:.2%}',
|
128 |
+
ha='center', va='bottom')
|
129 |
+
|
130 |
plt.tight_layout()
|
131 |
return fig
|
132 |
|
133 |
+
def plot_token_importance(input_ids, last_token_logits):
    """Draw a bar chart with one bar per input token and return the figure.

    NOTE(review): the "importance" computed here is a placeholder. A single
    scalar (the mean absolute value of ``last_token_logits``) is repeated
    once per input token and then normalised, so every bar ends up with the
    same height, 1 / number_of_tokens. A real attribution would need
    gradients or attention weights, which this function does not receive.

    Args:
        input_ids: Token ids of the prompt; handed to
            ``tokenizer.convert_ids_to_tokens`` to obtain the tick labels.
        last_token_logits: Logits for the next-token position. The caller in
            this file passes a 1-D slice (``outputs.logits[0, -1, :]``).

    Returns:
        The matplotlib figure containing the bar chart.
    """
    input_tokens = tokenizer.convert_ids_to_tokens(input_ids)
    positions = range(len(input_tokens))

    # Detach first: matplotlib converts its inputs to NumPy arrays, which
    # fails for tensors that are still attached to the autograd graph.
    logits = last_token_logits.detach()

    # Mean absolute logit, repeated for every input token.
    importances = torch.abs(logits).sum() / len(logits)
    importances = importances.repeat(len(input_tokens))

    # Normalise so that the bars sum to 1.
    importances = importances / importances.sum()

    # Move to CPU and convert to plain floats so plotting also works when the
    # model runs on a GPU.
    heights = importances.cpu().tolist()

    fig, ax = plt.subplots(figsize=(12, 3))
    bars = ax.bar(positions, heights, color='skyblue')

    # Axis labels and title.
    ax.set_xticks(positions)
    ax.set_xticklabels(input_tokens, rotation=45, ha='right')
    ax.set_ylabel('Importance relative')
    ax.set_title('Importance des tokens d\'entrée pour la prédiction')

    # Print each bar's value just above it.
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{height:.2%}',
                ha='center', va='bottom')

    plt.tight_layout()
    return fig
|
164 |
|
|
|
169 |
return "", 1.0, 1.0, 50, None, None, None, None
|
170 |
|
171 |
with gr.Blocks() as demo:
|
172 |
+
gr.Markdown("# LLM & Bias ")
|
173 |
|
174 |
with gr.Accordion("Sélection du modèle"):
|
175 |
model_dropdown = gr.Dropdown(choices=models, label="Choisissez un modèle")
|
|
|
187 |
next_token_probs = gr.Textbox(label="Probabilités du prochain token")
|
188 |
|
189 |
with gr.Row():
|
190 |
+
importance_plot = gr.Plot(label="Importance des tokens d'entrée")
|
191 |
prob_plot = gr.Plot(label="Probabilités des tokens suivants")
|
192 |
|
193 |
generate_button = gr.Button("Générer le prochain mot")
|
|
|
198 |
load_button.click(load_model, inputs=[model_dropdown], outputs=[load_output])
|
199 |
analyze_button.click(analyze_next_token,
|
200 |
inputs=[input_text, temperature, top_p, top_k],
|
201 |
+
outputs=[next_token_probs, importance_plot, prob_plot])
|
202 |
generate_button.click(generate_text,
|
203 |
inputs=[input_text, temperature, top_p, top_k],
|
204 |
outputs=[generated_text])
|
205 |
reset_button.click(reset,
|
206 |
+
outputs=[input_text, temperature, top_p, top_k, next_token_probs, importance_plot, prob_plot, generated_text])
|
207 |
|
208 |
if __name__ == "__main__":
|
209 |
demo.launch()
|