Woziii committed on
Commit 20095d9
1 Parent(s): b14462b

Update app.py

Files changed (1)
  1. app.py +223 -203
app.py CHANGED
@@ -6,7 +6,8 @@ import os
  import matplotlib.pyplot as plt
  import seaborn as sns
  import numpy as np
- import time
  
  # Authentification
  login(token=os.environ["HF_TOKEN"])
@@ -17,12 +18,10 @@ models_info = {
          "Llama 2": {
              "7B": {"name": "meta-llama/Llama-2-7b-hf", "languages": ["en"]},
              "13B": {"name": "meta-llama/Llama-2-13b-hf", "languages": ["en"]},
-             "70B": {"name": "meta-llama/Llama-2-70b-hf", "languages": ["en"]},
          },
          "Llama 3": {
-             "8B": {"name": "meta-llama/Meta-Llama-3-8B", "languages": ["en"]},
              "3.2-3B": {"name": "meta-llama/Llama-3.2-3B", "languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"]},
-             "3.1-8B": {"name": "meta-llama/Llama-3.1-8B", "languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"]},
          },
      },
      "Mistral AI": {
@@ -37,8 +36,7 @@ models_info = {
      "Google": {
          "Gemma": {
              "2B": {"name": "google/gemma-2-2b", "languages": ["en"]},
-             "9B": {"name": "google/gemma-2-9b", "languages": ["en"]},
-             "27B": {"name": "google/gemma-2-27b", "languages": ["en"]},
          },
      },
      "CroissantLLM": {
@@ -50,31 +48,29 @@ models_info = {
  
  # Paramètres recommandés pour chaque modèle
  model_parameters = {
-     "meta-llama/Llama-2-13b-hf": {"temperature": 0.8, "top_p": 0.9, "top_k": 40},
      "meta-llama/Llama-2-7b-hf": {"temperature": 0.8, "top_p": 0.9, "top_k": 40},
-     "meta-llama/Llama-2-70b-hf": {"temperature": 0.8, "top_p": 0.9, "top_k": 40},
-     "meta-llama/Meta-Llama-3-8B": {"temperature": 0.75, "top_p": 0.9, "top_k": 50},
      "meta-llama/Llama-3.2-3B": {"temperature": 0.75, "top_p": 0.9, "top_k": 50},
-     "meta-llama/Llama-3.1-8B": {"temperature": 0.75, "top_p": 0.9, "top_k": 50},
      "mistralai/Mistral-7B-v0.1": {"temperature": 0.7, "top_p": 0.9, "top_k": 50},
-     "mistralai/Mixtral-8x7B-v0.1": {"temperature": 0.8, "top_p": 0.95, "top_k": 50},
      "mistralai/Mistral-7B-v0.3": {"temperature": 0.7, "top_p": 0.9, "top_k": 50},
      "google/gemma-2-2b": {"temperature": 0.7, "top_p": 0.95, "top_k": 40},
-     "google/gemma-2-9b": {"temperature": 0.7, "top_p": 0.95, "top_k": 40},
-     "google/gemma-2-27b": {"temperature": 0.7, "top_p": 0.95, "top_k": 40},
      "croissantllm/CroissantLLMBase": {"temperature": 0.8, "top_p": 0.92, "top_k": 50}
  }
  
  # Variables globales
- model = None
- tokenizer = None
- selected_language = None
  
  def update_model_type(family):
      return gr.Dropdown(choices=list(models_info[family].keys()), value=None, interactive=True)
  
  def update_model_variation(family, model_type):
-     return gr.Dropdown(choices=list(models_info[family][model_type].keys()), value=None, interactive=True)
  
  def update_selected_model(family, model_type, variation):
      if family and model_type and variation:
@@ -82,83 +78,48 @@ def update_selected_model(family, model_type, variation):
          return model_name, gr.Dropdown(choices=models_info[family][model_type][variation]["languages"], value=models_info[family][model_type][variation]["languages"][0], visible=True, interactive=True)
      return "", gr.Dropdown(visible=False)
  
- def load_model(model_name, progress=gr.Progress()):
-     global model, tokenizer
      try:
-         progress(0, desc="Chargement du tokenizer")
-         tokenizer = AutoTokenizer.from_pretrained(model_name)
-         progress(0.5, desc="Chargement du modèle")
- 
-         # Configurations spécifiques par modèle
-         if "mixtral" in model_name.lower():
-             model = AutoModelForCausalLM.from_pretrained(
-                 model_name,
-                 torch_dtype=torch.float16,
-                 device_map="auto",
-                 load_in_8bit=True
-             )
-         else:
-             model = AutoModelForCausalLM.from_pretrained(
-                 model_name,
-                 torch_dtype=torch.float16,
-                 device_map="auto"
-             )
- 
-         if tokenizer.pad_token is None:
-             tokenizer.pad_token = tokenizer.eos_token
- 
-         progress(1.0, desc="Modèle chargé")
- 
-         # Recherche des langues disponibles pour le modèle sélectionné
-         available_languages = next(
-             (info["languages"] for family in models_info.values()
-              for model_type in family.values()
-              for variation in model_type.values()
-              if variation["name"] == model_name),
-             ["en"]  # Défaut à l'anglais si non trouvé
-         )
- 
-         # Mise à jour des sliders avec les valeurs recommandées
-         params = model_parameters[model_name]
-         return (
-             f"Modèle {model_name} chargé avec succès. Langues disponibles : {', '.join(available_languages)}",
-             gr.Dropdown(choices=available_languages, value=available_languages[0], visible=True, interactive=True),
-             params["temperature"],
-             params["top_p"],
-             params["top_k"]
-         )
      except Exception as e:
-         return f"Erreur lors du chargement du modèle : {str(e)}", gr.Dropdown(visible=False), None, None, None
  
  def set_language(lang):
-     global selected_language
-     selected_language = lang
      return f"Langue sélectionnée : {lang}"
  
- def ensure_token_display(token):
-     """Assure que le token est affiché correctement."""
      if token.isdigit() or (token.startswith('-') and token[1:].isdigit()):
          return tokenizer.decode([int(token)])
      return token
  
- def analyze_next_token(input_text, temperature, top_p, top_k):
-     global model, tokenizer, selected_language
  
-     if model is None or tokenizer is None:
-         return "Veuillez d'abord charger un modèle.", None, None
- 
      inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
  
      try:
          with torch.no_grad():
              outputs = model(**inputs)
  
          last_token_logits = outputs.logits[0, -1, :]
          probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
  
-         top_k = 10
          top_probs, top_indices = torch.topk(probabilities, top_k)
-         top_words = [ensure_token_display(tokenizer.decode([idx.item()])) for idx in top_indices]
          prob_data = {word: prob.item() for word, prob in zip(top_words, top_probs)}
  
          prob_text = "Prochains tokens les plus probables :\n\n"
@@ -166,80 +127,92 @@ def analyze_next_token(input_text, temperature, top_p, top_k):
              prob_text += f"{word}: {prob:.2%}\n"
  
          prob_plot = plot_probabilities(prob_data)
-         attention_plot = plot_attention(inputs["input_ids"][0].cpu(), last_token_logits.cpu())
  
          return prob_text, attention_plot, prob_plot
      except Exception as e:
          return f"Erreur lors de l'analyse : {str(e)}", None, None
  
- def generate_text(input_text, temperature, top_p, top_k):
-     global model, tokenizer, selected_language
  
-     if model is None or tokenizer is None:
-         return "Veuillez d'abord charger un modèle."
- 
      inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
  
      try:
          with torch.no_grad():
              outputs = model.generate(
                  **inputs,
-                 max_new_tokens=10,
                  temperature=temperature,
                  top_p=top_p,
                  top_k=top_k
              )
  
          generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
          return generated_text
      except Exception as e:
          return f"Erreur lors de la génération : {str(e)}"
  
  def plot_probabilities(prob_data):
-     words = list(prob_data.keys())
-     probs = list(prob_data.values())
- 
-     fig, ax = plt.subplots(figsize=(12, 6))
-     bars = ax.bar(range(len(words)), probs, color='lightgreen')
-     ax.set_title("Probabilités des tokens suivants les plus probables")
-     ax.set_xlabel("Tokens")
-     ax.set_ylabel("Probabilité")
- 
-     ax.set_xticks(range(len(words)))
-     ax.set_xticklabels(words, rotation=45, ha='right')
- 
-     for i, (bar, word) in enumerate(zip(bars, words)):
-         height = bar.get_height()
-         ax.text(i, height, f'{height:.2%}',
-                 ha='center', va='bottom', rotation=0)
- 
-     plt.tight_layout()
-     return fig
  
- def plot_attention(input_ids, last_token_logits):
-     input_tokens = [ensure_token_display(tokenizer.decode([id])) for id in input_ids]
-     attention_scores = torch.nn.functional.softmax(last_token_logits, dim=-1)
-     top_k = min(len(input_tokens), 10)
-     top_attention_scores, _ = torch.topk(attention_scores, top_k)
- 
-     fig, ax = plt.subplots(figsize=(14, 7))
-     sns.heatmap(top_attention_scores.unsqueeze(0).numpy(), annot=True, cmap="YlOrRd", cbar=True, ax=ax, fmt='.2%')
-     ax.set_xticklabels(input_tokens[-top_k:], rotation=45, ha="right", fontsize=10)
-     ax.set_yticklabels(["Attention"], rotation=0, fontsize=10)
-     ax.set_title("Scores d'attention pour les derniers tokens", fontsize=16)
- 
-     cbar = ax.collections[0].colorbar
-     cbar.set_label("Score d'attention", fontsize=12)
-     cbar.ax.tick_params(labelsize=10)
- 
-     plt.tight_layout()
-     return fig
  
  def reset():
-     global model, tokenizer, selected_language
-     model = None
-     tokenizer = None
-     selected_language = None
      return (
          "", 1.0, 1.0, 50, None, None, None, None,
          gr.Dropdown(choices=list(models_info.keys()), value=None, interactive=True),
@@ -248,92 +221,139 @@ def reset():
          "", gr.Dropdown(visible=False), ""
      )
  
  with gr.Blocks() as demo:
      gr.Markdown("# LLM&BIAS")
  
-     with gr.Accordion("Sélection du modèle", open=True):
-         with gr.Row():
-             model_family = gr.Dropdown(choices=list(models_info.keys()), label="Famille de modèle", interactive=True)
-             model_type = gr.Dropdown(choices=[], label="Type de modèle", interactive=False)
-             model_variation = gr.Dropdown(choices=[], label="Variation du modèle", interactive=False)
- 
-     selected_model = gr.Textbox(label="Modèle sélectionné", interactive=False)
-     load_button = gr.Button("Charger le modèle")
-     load_output = gr.Textbox(label="Statut du chargement")
-     language_dropdown = gr.Dropdown(label="Choisissez une langue", visible=False)
-     language_output = gr.Textbox(label="Langue sélectionnée")
- 
-     with gr.Row():
-         temperature = gr.Slider(0.1, 2.0, value=1.0, label="Température")
-         top_p = gr.Slider(0.1, 1.0, value=1.0, label="Top-p")
-         top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
- 
-     input_text = gr.Textbox(label="Texte d'entrée", lines=3)
-     analyze_button = gr.Button("Analyser le prochain token")
- 
-     next_token_probs = gr.Textbox(label="Probabilités du prochain token")
- 
-     with gr.Row():
-         attention_plot = gr.Plot(label="Visualisation de l'attention")
-         prob_plot = gr.Plot(label="Probabilités des tokens suivants")
- 
-     generate_button = gr.Button("Générer le prochain mot")
-     generated_text = gr.Textbox(label="Texte généré")
- 
-     reset_button = gr.Button("Réinitialiser")
- 
-     # Événements pour la sélection du modèle
-     model_family.change(
-         update_model_type,
-         inputs=[model_family],
-         outputs=[model_type]
-     )
- 
-     model_type.change(
-         update_model_variation,
-         inputs=[model_family, model_type],
-         outputs=[model_variation]
-     )
- 
-     model_variation.change(
-         update_selected_model,
-         inputs=[model_family, model_type, model_variation],
-         outputs=[selected_model, language_dropdown]
-     )
- 
-     load_button.click(
-         load_model,
-         inputs=[selected_model],
-         outputs=[load_output, language_dropdown, temperature, top_p, top_k]
-     )
- 
-     language_dropdown.change(
-         set_language,
-         inputs=[language_dropdown],
-         outputs=[language_output]
-     )
- 
-     analyze_button.click(
-         analyze_next_token,
-         inputs=[input_text, temperature, top_p, top_k],
-         outputs=[next_token_probs, attention_plot, prob_plot]
-     )
- 
-     generate_button.click(
-         generate_text,
-         inputs=[input_text, temperature, top_p, top_k],
-         outputs=[generated_text]
      )
  
-     reset_button.click(
-         reset,
-         outputs=[
-             input_text, temperature, top_p, top_k,
-             next_token_probs, attention_plot, prob_plot, generated_text,
-             model_family, model_type, model_variation,
-             selected_model, language_dropdown, language_output
-         ]
      )
  
  if __name__ == "__main__":
-     demo.launch()
 
  import matplotlib.pyplot as plt
  import seaborn as sns
  import numpy as np
+ import asyncio
+ import gc
  
  # Authentification
  login(token=os.environ["HF_TOKEN"])
 
          "Llama 2": {
              "7B": {"name": "meta-llama/Llama-2-7b-hf", "languages": ["en"]},
              "13B": {"name": "meta-llama/Llama-2-13b-hf", "languages": ["en"]},
          },
          "Llama 3": {
+             "8B": {"name": "meta-llama/Llama-3-8B", "languages": ["en"]},
              "3.2-3B": {"name": "meta-llama/Llama-3.2-3B", "languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"]},
          },
      },
      "Mistral AI": {
 
      "Google": {
          "Gemma": {
              "2B": {"name": "google/gemma-2-2b", "languages": ["en"]},
+             "7B": {"name": "google/gemma-2-7b", "languages": ["en"]},
          },
      },
      "CroissantLLM": {
 
  
  # Paramètres recommandés pour chaque modèle
  model_parameters = {
      "meta-llama/Llama-2-7b-hf": {"temperature": 0.8, "top_p": 0.9, "top_k": 40},
+     "meta-llama/Llama-2-13b-hf": {"temperature": 0.8, "top_p": 0.9, "top_k": 40},
+     "meta-llama/Llama-3-8B": {"temperature": 0.75, "top_p": 0.9, "top_k": 50},
      "meta-llama/Llama-3.2-3B": {"temperature": 0.75, "top_p": 0.9, "top_k": 50},
      "mistralai/Mistral-7B-v0.1": {"temperature": 0.7, "top_p": 0.9, "top_k": 50},
      "mistralai/Mistral-7B-v0.3": {"temperature": 0.7, "top_p": 0.9, "top_k": 50},
+     "mistralai/Mixtral-8x7B-v0.1": {"temperature": 0.8, "top_p": 0.95, "top_k": 50},
      "google/gemma-2-2b": {"temperature": 0.7, "top_p": 0.95, "top_k": 40},
+     "google/gemma-2-7b": {"temperature": 0.7, "top_p": 0.95, "top_k": 40},
      "croissantllm/CroissantLLMBase": {"temperature": 0.8, "top_p": 0.92, "top_k": 50}
  }
  
  # Variables globales
+ model_cache = {}
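model_parameters keeps per-model sampling presets keyed by repo id; in the previous version they seeded the temperature/top-p/top-k sliders when a model was loaded. A minimal sketch of how such a preset could be turned into generate() keyword arguments; the helper generation_kwargs_for and the do_sample flag are illustrative, not part of app.py:

# Hypothetical helper: turn a preset from model_parameters into generate() kwargs.
def generation_kwargs_for(model_name, max_new_tokens=50):
    params = model_parameters.get(model_name, {"temperature": 0.7, "top_p": 0.9, "top_k": 50})
    return {"max_new_tokens": max_new_tokens, "do_sample": True, **params}

# generation_kwargs_for("mistralai/Mistral-7B-v0.1")
# -> {'max_new_tokens': 50, 'do_sample': True, 'temperature': 0.7, 'top_p': 0.9, 'top_k': 50}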
 
 
+ # Fonctions utilitaires
  def update_model_type(family):
      return gr.Dropdown(choices=list(models_info[family].keys()), value=None, interactive=True)
  
  def update_model_variation(family, model_type):
+     if family and model_type:
+         return gr.Dropdown(choices=list(models_info[family][model_type].keys()), value=None, interactive=True)
+     return gr.Dropdown(choices=[], value=None, interactive=False)
  
  def update_selected_model(family, model_type, variation):
      if family and model_type and variation:
 
          return model_name, gr.Dropdown(choices=models_info[family][model_type][variation]["languages"], value=models_info[family][model_type][variation]["languages"][0], visible=True, interactive=True)
      return "", gr.Dropdown(visible=False)
  
+ async def load_model_async(model_name, progress=gr.Progress()):
      try:
+         if model_name not in model_cache:
+             progress(0.1, f"Chargement du tokenizer pour {model_name}...")
+             tokenizer = await asyncio.to_thread(AutoTokenizer.from_pretrained, model_name)
+             progress(0.4, f"Chargement du modèle {model_name}...")
+             model = await asyncio.to_thread(AutoModelForCausalLM.from_pretrained, model_name,
+                                             torch_dtype=torch.float16, device_map="auto", low_cpu_mem_usage=True)
+             if tokenizer.pad_token is None:
+                 tokenizer.pad_token = tokenizer.eos_token
+             model_cache[model_name] = (model, tokenizer)
+         progress(1.0, f"Modèle {model_name} chargé avec succès")
+         return f"Modèle {model_name} chargé avec succès"
      except Exception as e:
+         return f"Erreur lors du chargement du modèle {model_name} : {str(e)}"
  
  def set_language(lang):
      return f"Langue sélectionnée : {lang}"
  
+ def ensure_token_display(token, tokenizer):
      if token.isdigit() or (token.startswith('-') and token[1:].isdigit()):
          return tokenizer.decode([int(token)])
      return token
  
+ async def analyze_next_token(model_name, input_text, temperature, top_p, top_k, progress=gr.Progress()):
+     if model_name not in model_cache:
+         return "Veuillez d'abord charger le modèle", None, None
  
+     model, tokenizer = model_cache[model_name]
      inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
  
      try:
+         progress(0.5, "Analyse en cours...")
          with torch.no_grad():
              outputs = model(**inputs)
  
          last_token_logits = outputs.logits[0, -1, :]
          probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
  
+         top_k = min(10, top_k)
          top_probs, top_indices = torch.topk(probabilities, top_k)
+         top_words = [ensure_token_display(tokenizer.decode([idx.item()]), tokenizer) for idx in top_indices]
          prob_data = {word: prob.item() for word, prob in zip(top_words, top_probs)}
  
          prob_text = "Prochains tokens les plus probables :\n\n"
              prob_text += f"{word}: {prob:.2%}\n"
  
          prob_plot = plot_probabilities(prob_data)
+         attention_plot = plot_attention(inputs["input_ids"][0].cpu(), last_token_logits.cpu(), tokenizer)
  
+         progress(1.0, "Analyse terminée")
          return prob_text, attention_plot, prob_plot
      except Exception as e:
          return f"Erreur lors de l'analyse : {str(e)}", None, None
  
+ async def generate_text(model_name, input_text, temperature, top_p, top_k, progress=gr.Progress()):
+     if model_name not in model_cache:
+         return "Veuillez d'abord charger le modèle"
  
+     model, tokenizer = model_cache[model_name]
      inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
  
      try:
+         progress(0.5, "Génération en cours...")
          with torch.no_grad():
              outputs = model.generate(
                  **inputs,
+                 max_new_tokens=50,
                  temperature=temperature,
                  top_p=top_p,
                  top_k=top_k
              )
  
          generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+         progress(1.0, "Génération terminée")
          return generated_text
      except Exception as e:
          return f"Erreur lors de la génération : {str(e)}"
  
  def plot_probabilities(prob_data):
+     try:
+         words = list(prob_data.keys())
+         probs = list(prob_data.values())
+ 
+         fig, ax = plt.subplots(figsize=(12, 6))
+         bars = ax.bar(range(len(words)), probs, color='lightgreen')
+         ax.set_title("Probabilités des tokens suivants les plus probables")
+         ax.set_xlabel("Tokens")
+         ax.set_ylabel("Probabilité")
+ 
+         ax.set_xticks(range(len(words)))
+         ax.set_xticklabels(words, rotation=45, ha='right')
+ 
+         for i, (bar, word) in enumerate(zip(bars, words)):
+             height = bar.get_height()
+             ax.text(i, height, f'{height:.2%}',
+                     ha='center', va='bottom', rotation=0)
+ 
+         plt.tight_layout()
+         return fig
+     except Exception as e:
+         print(f"Erreur lors de la création du graphique : {str(e)}")
+         return None
  
+ def plot_attention(input_ids, last_token_logits, tokenizer):
+     try:
+         input_tokens = [ensure_token_display(tokenizer.decode([id]), tokenizer) for id in input_ids]
+         attention_scores = torch.nn.functional.softmax(last_token_logits, dim=-1)
+         top_k = min(len(input_tokens), 10)
+         top_attention_scores, _ = torch.topk(attention_scores, top_k)
+ 
+         fig, ax = plt.subplots(figsize=(14, 7))
+         sns.heatmap(top_attention_scores.unsqueeze(0).numpy(), annot=True, cmap="YlOrRd", cbar=True, ax=ax, fmt='.2%')
+         ax.set_xticklabels(input_tokens[-top_k:], rotation=45, ha="right", fontsize=10)
+         ax.set_yticklabels(["Attention"], rotation=0, fontsize=10)
+         ax.set_title("Scores d'attention pour les derniers tokens", fontsize=16)
+ 
+         cbar = ax.collections[0].colorbar
+         cbar.set_label("Score d'attention", fontsize=12)
+         cbar.ax.tick_params(labelsize=10)
+ 
+         plt.tight_layout()
+         return fig
+     except Exception as e:
+         print(f"Erreur lors de la création du graphique d'attention : {str(e)}")
+         return None
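Despite its name, plot_attention draws the softmax of the next-token logits (a distribution over the vocabulary, truncated to the length of the input), so the heatmap shows next-token probabilities rather than transformer attention weights. If genuine attention were wanted, a hedged sketch of how it could be read out, assuming the same model and inputs (some attention backends require loading the model with attn_implementation="eager" for this to be populated):

with torch.no_grad():
    out = model(**inputs, output_attentions=True)
# out.attentions is a tuple with one tensor per layer, each (batch, heads, seq_len, seq_len)
last_layer = out.attentions[-1]
# Attention paid by the final position to every input token, averaged over heads:
attn_to_input = last_layer[0, :, -1, :].mean(dim=0)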
  
  def reset():
+     global model_cache
+     for model in model_cache.values():
+         del model
+     model_cache.clear()
+     torch.cuda.empty_cache()
+     gc.collect()
      return (
          "", 1.0, 1.0, 50, None, None, None, None,
          gr.Dropdown(choices=list(models_info.keys()), value=None, interactive=True),
          "", gr.Dropdown(visible=False), ""
      )
  
+ def reset_comparison():
+     return [gr.Dropdown(choices=[], value=None) for _ in range(4)] + ["", "", gr.Dropdown(choices=[], value=None), 1.0, 1.0, 50, "", "", None, None, None, None]
+ 
+ async def compare_models(model1, model2, input_text, temp, top_p, top_k, progress=gr.Progress()):
+     if model1 not in model_cache or model2 not in model_cache:
+         return "Veuillez d'abord charger les deux modèles", "", None, None, None, None
+ 
+     progress(0.1, "Analyse du premier modèle...")
+     results1 = await analyze_next_token(model1, input_text, temp, top_p, top_k)
+     progress(0.4, "Analyse du second modèle...")
+     results2 = await analyze_next_token(model2, input_text, temp, top_p, top_k)
+     progress(1.0, "Comparaison terminée")
+     return (
+         results1[0], results2[0],  # Probabilités du prochain token
+         results1[2], results2[2],  # Graphiques de probabilités
+         results1[1], results2[1]   # Graphiques d'attention
+     )
+ 
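compare_models simply reuses analyze_next_token for each selected model and reorders the returned (text, attention figure, probability figure) tuples into the two output columns. Since analyze_next_token is a coroutine, the two calls could also be awaited together; a minimal sketch (the torch forward passes themselves stay synchronous, so this tidies the control flow rather than adding real parallelism):

results1, results2 = await asyncio.gather(
    analyze_next_token(model1, input_text, temp, top_p, top_k),
    analyze_next_token(model2, input_text, temp, top_p, top_k),
)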
  with gr.Blocks() as demo:
      gr.Markdown("# LLM&BIAS")
  
+     with gr.Tabs():
+         with gr.Tab("Analyse individuelle"):
+             with gr.Accordion("Sélection du modèle", open=True):
+                 with gr.Row():
+                     model_family = gr.Dropdown(choices=list(models_info.keys()), label="Famille de modèle", interactive=True)
+                     model_type = gr.Dropdown(choices=[], label="Type de modèle", interactive=False)
+                     model_variation = gr.Dropdown(choices=[], label="Variation du modèle", interactive=False)
+ 
+             selected_model = gr.Textbox(label="Modèle sélectionné", interactive=False)
+             load_button = gr.Button("Charger le modèle")
+             load_output = gr.Textbox(label="Statut du chargement")
+             language_dropdown = gr.Dropdown(label="Choisissez une langue", visible=False)
+             language_output = gr.Textbox(label="Langue sélectionnée")
+ 
+             with gr.Row():
+                 temperature = gr.Slider(0.1, 2.0, value=1.0, label="Température")
+                 top_p = gr.Slider(0.1, 1.0, value=1.0, label="Top-p")
+                 top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
+ 
+             input_text = gr.Textbox(label="Texte d'entrée", lines=3)
+             analyze_button = gr.Button("Analyser le prochain token")
+ 
+             next_token_probs = gr.Textbox(label="Probabilités du prochain token")
+ 
+             with gr.Row():
+                 attention_plot = gr.Plot(label="Visualisation de l'attention")
+                 prob_plot = gr.Plot(label="Probabilités des tokens suivants")
+ 
+             generate_button = gr.Button("Générer le texte")
+             generated_text = gr.Textbox(label="Texte généré")
+ 
+             reset_button = gr.Button("Réinitialiser")
+ 
+         with gr.Tab("Comparaison de modèles"):
+             with gr.Row():
+                 model1_family = gr.Dropdown(choices=list(models_info.keys()), label="Famille du modèle 1", interactive=True)
+                 model1_type = gr.Dropdown(choices=[], label="Type du modèle 1", interactive=False)
+                 model1_variation = gr.Dropdown(choices=[], label="Variation du modèle 1", interactive=False)
+ 
+             with gr.Row():
+                 model2_family = gr.Dropdown(choices=list(models_info.keys()), label="Famille du modèle 2", interactive=True)
+                 model2_type = gr.Dropdown(choices=[], label="Type du modèle 2", interactive=False)
+                 model2_variation = gr.Dropdown(choices=[], label="Variation du modèle 2", interactive=False)
+ 
+             model1_selected = gr.Textbox(label="Modèle 1 sélectionné", interactive=False)
+             model2_selected = gr.Textbox(label="Modèle 2 sélectionné", interactive=False)
+ 
+             load_models_button = gr.Button("Charger les modèles")
+             load_models_output = gr.Textbox(label="Statut du chargement des modèles")
+ 
+             comparison_language = gr.Dropdown(label="Langue pour la comparaison", choices=[], interactive=False)
+ 
+             with gr.Row():
+                 comp_temperature = gr.Slider(0.1, 2.0, value=1.0, label="Température")
+                 comp_top_p = gr.Slider(0.1, 1.0, value=1.0, label="Top-p")
+                 comp_top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
+ 
+             comp_input_text = gr.Textbox(label="Texte d'entrée pour la comparaison", lines=3)
+             compare_button = gr.Button("Comparer les modèles")
+ 
+             with gr.Row():
+                 model1_output = gr.Textbox(label="Probabilités du Modèle 1", lines=10)
+                 model2_output = gr.Textbox(label="Probabilités du Modèle 2", lines=10)
+ 
+             with gr.Row():
+                 model1_prob_plot = gr.Plot(label="Probabilités des tokens (Modèle 1)")
+                 model2_prob_plot = gr.Plot(label="Probabilités des tokens (Modèle 2)")
+ 
+             with gr.Row():
+                 model1_attention_plot = gr.Plot(label="Attention (Modèle 1)")
+                 model2_attention_plot = gr.Plot(label="Attention (Modèle 2)")
+ 
+             comp_reset_button = gr.Button("Réinitialiser la comparaison")
+ 
+     # Événements pour l'onglet d'analyse individuelle
+     model_family.change(update_model_type, inputs=[model_family], outputs=[model_type])
+     model_type.change(update_model_variation, inputs=[model_family, model_type], outputs=[model_variation])
+     model_variation.change(update_selected_model, inputs=[model_family, model_type, model_variation], outputs=[selected_model, language_dropdown])
+     load_button.click(load_model_async, inputs=[selected_model], outputs=[load_output])
+     language_dropdown.change(set_language, inputs=[language_dropdown], outputs=[language_output])
+     analyze_button.click(analyze_next_token, inputs=[selected_model, input_text, temperature, top_p, top_k], outputs=[next_token_probs, attention_plot, prob_plot])
+     generate_button.click(generate_text, inputs=[selected_model, input_text, temperature, top_p, top_k], outputs=[generated_text])
+     reset_button.click(reset, outputs=[input_text, temperature, top_p, top_k, next_token_probs, attention_plot, prob_plot, generated_text, model_family, model_type, model_variation, selected_model, language_dropdown, language_output])
+ 
+     # Événements pour l'onglet de comparaison
+     model1_family.change(update_model_type, inputs=[model1_family], outputs=[model1_type])
+     model1_type.change(update_model_variation, inputs=[model1_family, model1_type], outputs=[model1_variation])
+     model1_variation.change(update_selected_model, inputs=[model1_family, model1_type, model1_variation], outputs=[model1_selected, comparison_language])
+ 
+     model2_family.change(update_model_type, inputs=[model2_family], outputs=[model2_type])
+     model2_type.change(update_model_variation, inputs=[model2_family, model2_type], outputs=[model2_variation])
+     model2_variation.change(update_selected_model, inputs=[model2_family, model2_type, model2_variation], outputs=[model2_selected, comparison_language])
+ 
+     async def load_both_models(model1, model2):
+         result1 = await load_model_async(model1)
+         result2 = await load_model_async(model2)
+         return f"Modèle 1: {result1}\nModèle 2: {result2}"
+ 
+     load_models_button.click(load_both_models, inputs=[model1_selected, model2_selected], outputs=[load_models_output])
+ 
+     compare_button.click(
+         compare_models,
+         inputs=[model1_selected, model2_selected, comp_input_text, comp_temperature, comp_top_p, comp_top_k],
+         outputs=[model1_output, model2_output, model1_prob_plot, model2_prob_plot, model1_attention_plot, model2_attention_plot]
      )
+ 
+     comp_reset_button.click(
+         reset_comparison,
+         outputs=[model1_type, model1_variation, model2_type, model2_variation, model1_selected, model2_selected, comparison_language,
+                  comp_temperature, comp_top_p, comp_top_k, comp_input_text, model1_output, model2_output,
+                  model1_prob_plot, model2_prob_plot, model1_attention_plot, model2_attention_plot]
      )
  
  if __name__ == "__main__":
+     demo.launch()