Woziii commited on
Commit
f18c3eb
1 Parent(s): 8869d77

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -29
app.py CHANGED
@@ -5,6 +5,8 @@ from huggingface_hub import login
5
  import os
6
  import matplotlib.pyplot as plt
7
  import seaborn as sns
 
 
8
 
9
  # Authentification
10
  login(token=os.environ["HF_TOKEN"])
@@ -35,32 +37,33 @@ def load_model(model_name, progress=gr.Progress()):
35
  try:
36
  progress(0, desc="Chargement du tokenizer")
37
  tokenizer = AutoTokenizer.from_pretrained(model_name)
38
- progress(0.3, desc="Tokenizer chargé")
39
-
40
- progress(0.3, desc="Chargement du modèle")
41
  model = AutoModelForCausalLM.from_pretrained(
42
  model_name,
43
- torch_dtype=torch.bfloat16,
44
- device_map="auto",
45
  attn_implementation="eager"
46
  )
47
- progress(0.9, desc="Modèle chargé")
48
-
49
  if tokenizer.pad_token is None:
50
  tokenizer.pad_token = tokenizer.eos_token
51
-
52
- progress(1.0, desc="Chargement terminé")
53
  return f"Modèle {model_name} chargé avec succès."
54
  except Exception as e:
55
  return f"Erreur lors du chargement du modèle : {str(e)}"
56
 
 
 
 
 
 
 
57
  def analyze_next_token(input_text, temperature, top_p, top_k):
58
  global model, tokenizer
59
 
60
  if model is None or tokenizer is None:
61
  return "Veuillez d'abord charger un modèle.", None, None
62
 
63
- inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
64
 
65
  try:
66
  with torch.no_grad():
@@ -68,20 +71,20 @@ def analyze_next_token(input_text, temperature, top_p, top_k):
68
 
69
  last_token_logits = outputs.logits[0, -1, :]
70
  probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
71
- top_k = 5
 
72
  top_probs, top_indices = torch.topk(probabilities, top_k)
73
- top_words = [tokenizer.decode([idx.item()]).strip() for idx in top_indices]
74
  prob_data = {word: prob.item() for word, prob in zip(top_words, top_probs)}
75
- prob_plot = plot_probabilities(prob_data)
76
 
77
- prob_text = "\n".join([f"{word}: {prob:.4f}" for word, prob in prob_data.items()])
 
 
78
 
79
- # Simplification de l'affichage de l'attention
80
- attention_text = "Attention non disponible pour ce modèle"
81
- if hasattr(outputs, 'attentions') and outputs.attentions is not None:
82
- attention_text = "Attention disponible"
83
 
84
- return prob_text, attention_text, prob_plot
85
  except Exception as e:
86
  return f"Erreur lors de l'analyse : {str(e)}", None, None
87
 
@@ -91,20 +94,20 @@ def generate_text(input_text, temperature, top_p, top_k):
91
  if model is None or tokenizer is None:
92
  return "Veuillez d'abord charger un modèle."
93
 
94
- inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
95
 
96
  try:
97
  with torch.no_grad():
98
  outputs = model.generate(
99
  **inputs,
100
- max_new_tokens=1, # Génère seulement le prochain mot
101
  temperature=temperature,
102
  top_p=top_p,
103
  top_k=top_k
104
  )
105
 
106
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
107
- return generated_text # Retourne l'input + le nouveau mot généré
108
  except Exception as e:
109
  return f"Erreur lors de la génération : {str(e)}"
110
 
@@ -112,12 +115,39 @@ def plot_probabilities(prob_data):
112
  words = list(prob_data.keys())
113
  probs = list(prob_data.values())
114
 
115
- fig, ax = plt.subplots(figsize=(10, 5))
116
- sns.barplot(x=words, y=probs, ax=ax)
117
  ax.set_title("Probabilités des tokens suivants les plus probables")
118
  ax.set_xlabel("Tokens")
119
  ax.set_ylabel("Probabilité")
120
- plt.xticks(rotation=45)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  plt.tight_layout()
122
  return fig
123
 
@@ -144,9 +174,10 @@ with gr.Blocks() as demo:
144
  analyze_button = gr.Button("Analyser le prochain token")
145
 
146
  next_token_probs = gr.Textbox(label="Probabilités du prochain token")
147
- attention_info = gr.Textbox(label="Information sur l'attention")
148
 
149
- prob_plot = gr.Plot(label="Probabilités des tokens suivants")
 
 
150
 
151
  generate_button = gr.Button("Générer le prochain mot")
152
  generated_text = gr.Textbox(label="Texte généré")
@@ -156,12 +187,12 @@ with gr.Blocks() as demo:
156
  load_button.click(load_model, inputs=[model_dropdown], outputs=[load_output])
157
  analyze_button.click(analyze_next_token,
158
  inputs=[input_text, temperature, top_p, top_k],
159
- outputs=[next_token_probs, attention_info, prob_plot])
160
  generate_button.click(generate_text,
161
  inputs=[input_text, temperature, top_p, top_k],
162
  outputs=[generated_text])
163
  reset_button.click(reset,
164
- outputs=[input_text, temperature, top_p, top_k, next_token_probs, attention_info, prob_plot, generated_text])
165
 
166
  if __name__ == "__main__":
167
  demo.launch()
 
5
  import os
6
  import matplotlib.pyplot as plt
7
  import seaborn as sns
8
+ import numpy as np
9
+ import time
10
 
11
  # Authentification
12
  login(token=os.environ["HF_TOKEN"])
 
37
  try:
38
  progress(0, desc="Chargement du tokenizer")
39
  tokenizer = AutoTokenizer.from_pretrained(model_name)
40
+ progress(0.5, desc="Chargement du modèle")
 
 
41
  model = AutoModelForCausalLM.from_pretrained(
42
  model_name,
43
+ torch_dtype=torch.float32,
44
+ device_map="cpu",
45
  attn_implementation="eager"
46
  )
 
 
47
  if tokenizer.pad_token is None:
48
  tokenizer.pad_token = tokenizer.eos_token
49
+ progress(1.0, desc="Modèle chargé")
 
50
  return f"Modèle {model_name} chargé avec succès."
51
  except Exception as e:
52
  return f"Erreur lors du chargement du modèle : {str(e)}"
53
 
54
+ def ensure_token_display(token):
55
+ """Assure que le token est affiché correctement."""
56
+ if token.isdigit() or (token.startswith('-') and token[1:].isdigit()):
57
+ return tokenizer.decode([int(token)])
58
+ return token
59
+
60
  def analyze_next_token(input_text, temperature, top_p, top_k):
61
  global model, tokenizer
62
 
63
  if model is None or tokenizer is None:
64
  return "Veuillez d'abord charger un modèle.", None, None
65
 
66
+ inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
67
 
68
  try:
69
  with torch.no_grad():
 
71
 
72
  last_token_logits = outputs.logits[0, -1, :]
73
  probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
74
+
75
+ top_k = 10
76
  top_probs, top_indices = torch.topk(probabilities, top_k)
77
+ top_words = [ensure_token_display(tokenizer.decode([idx.item()])) for idx in top_indices]
78
  prob_data = {word: prob.item() for word, prob in zip(top_words, top_probs)}
 
79
 
80
+ prob_text = "Prochains tokens les plus probables :\n\n"
81
+ for word, prob in prob_data.items():
82
+ prob_text += f"{word}: {prob:.2%}\n"
83
 
84
+ prob_plot = plot_probabilities(prob_data)
85
+ attention_plot = plot_attention(inputs["input_ids"][0], last_token_logits)
 
 
86
 
87
+ return prob_text, attention_plot, prob_plot
88
  except Exception as e:
89
  return f"Erreur lors de l'analyse : {str(e)}", None, None
90
 
 
94
  if model is None or tokenizer is None:
95
  return "Veuillez d'abord charger un modèle."
96
 
97
+ inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
98
 
99
  try:
100
  with torch.no_grad():
101
  outputs = model.generate(
102
  **inputs,
103
+ max_new_tokens=1,
104
  temperature=temperature,
105
  top_p=top_p,
106
  top_k=top_k
107
  )
108
 
109
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
110
+ return generated_text
111
  except Exception as e:
112
  return f"Erreur lors de la génération : {str(e)}"
113
 
 
115
  words = list(prob_data.keys())
116
  probs = list(prob_data.values())
117
 
118
+ fig, ax = plt.subplots(figsize=(12, 6))
119
+ bars = ax.bar(range(len(words)), probs, color='lightgreen')
120
  ax.set_title("Probabilités des tokens suivants les plus probables")
121
  ax.set_xlabel("Tokens")
122
  ax.set_ylabel("Probabilité")
123
+
124
+ ax.set_xticks(range(len(words)))
125
+ ax.set_xticklabels(words, rotation=45, ha='right')
126
+
127
+ for i, (bar, word) in enumerate(zip(bars, words)):
128
+ height = bar.get_height()
129
+ ax.text(i, height, f'{height:.2%}',
130
+ ha='center', va='bottom', rotation=0)
131
+
132
+ plt.tight_layout()
133
+ return fig
134
+
135
+ def plot_attention(input_ids, last_token_logits):
136
+ input_tokens = [ensure_token_display(tokenizer.decode([id])) for id in input_ids]
137
+ attention_scores = torch.nn.functional.softmax(last_token_logits, dim=-1)
138
+ top_k = min(len(input_tokens), 10)
139
+ top_attention_scores, _ = torch.topk(attention_scores, top_k)
140
+
141
+ fig, ax = plt.subplots(figsize=(14, 7))
142
+ sns.heatmap(top_attention_scores.unsqueeze(0).numpy(), annot=True, cmap="YlOrRd", cbar=True, ax=ax, fmt='.2%')
143
+ ax.set_xticklabels(input_tokens[-top_k:], rotation=45, ha="right", fontsize=10)
144
+ ax.set_yticklabels(["Attention"], rotation=0, fontsize=10)
145
+ ax.set_title("Scores d'attention pour les derniers tokens", fontsize=16)
146
+
147
+ cbar = ax.collections[0].colorbar
148
+ cbar.set_label("Score d'attention", fontsize=12)
149
+ cbar.ax.tick_params(labelsize=10)
150
+
151
  plt.tight_layout()
152
  return fig
153
 
 
174
  analyze_button = gr.Button("Analyser le prochain token")
175
 
176
  next_token_probs = gr.Textbox(label="Probabilités du prochain token")
 
177
 
178
+ with gr.Row():
179
+ attention_plot = gr.Plot(label="Visualisation de l'attention")
180
+ prob_plot = gr.Plot(label="Probabilités des tokens suivants")
181
 
182
  generate_button = gr.Button("Générer le prochain mot")
183
  generated_text = gr.Textbox(label="Texte généré")
 
187
  load_button.click(load_model, inputs=[model_dropdown], outputs=[load_output])
188
  analyze_button.click(analyze_next_token,
189
  inputs=[input_text, temperature, top_p, top_k],
190
+ outputs=[next_token_probs, attention_plot, prob_plot])
191
  generate_button.click(generate_text,
192
  inputs=[input_text, temperature, top_p, top_k],
193
  outputs=[generated_text])
194
  reset_button.click(reset,
195
+ outputs=[input_text, temperature, top_p, top_k, next_token_probs, attention_plot, prob_plot, generated_text])
196
 
197
  if __name__ == "__main__":
198
  demo.launch()