Woziii committed
Commit 8869d77
1 Parent(s): 74a6012

Update app.py

Files changed (1)
  1. app.py +30 -40
app.py CHANGED
@@ -5,8 +5,6 @@ from huggingface_hub import login
 import os
 import matplotlib.pyplot as plt
 import seaborn as sns
-import numpy as np
-import time
 
 # Authentication
 login(token=os.environ["HF_TOKEN"])
@@ -35,19 +33,23 @@ tokenizer = None
 def load_model(model_name, progress=gr.Progress()):
     global model, tokenizer
     try:
-        for i in progress.tqdm(range(100)):
-            time.sleep(0.01)  # simulate loading
-            if i == 25:
-                tokenizer = AutoTokenizer.from_pretrained(model_name)
-            elif i == 75:
-                model = AutoModelForCausalLM.from_pretrained(
-                    model_name,
-                    torch_dtype=torch.float32,
-                    device_map="cpu",
-                    attn_implementation="eager"
-                )
-        if tokenizer.pad_token is None:
-            tokenizer.pad_token = tokenizer.eos_token
+        progress(0, desc="Chargement du tokenizer")
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        progress(0.3, desc="Tokenizer chargé")
+
+        progress(0.3, desc="Chargement du modèle")
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+            attn_implementation="eager"
+        )
+        progress(0.9, desc="Modèle chargé")
+
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+
+        progress(1.0, desc="Chargement terminé")
         return f"Modèle {model_name} chargé avec succès."
     except Exception as e:
         return f"Erreur lors du chargement du modèle : {str(e)}"
@@ -58,7 +60,7 @@ def analyze_next_token(input_text, temperature, top_p, top_k):
     if model is None or tokenizer is None:
         return "Veuillez d'abord charger un modèle.", None, None
 
-    inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
+    inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
 
     try:
         with torch.no_grad():
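
With `device_map="auto"` the weights may land on a GPU, so the tokenized inputs must be moved to the same device before the forward pass; the `BatchEncoding` returned by the tokenizer supports `.to()` for exactly this. A sketch under that assumption (the model name is a placeholder, not the app's model list):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    name = "gpt2"  # placeholder; any causal LM works the same way
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModelForCausalLM.from_pretrained(name, device_map="auto")

    # BatchEncoding.to() moves every tensor in the encoding at once.
    enc = tokenizer("Bonjour le monde", return_tensors="pt").to(model.device)
    with torch.no_grad():
        logits = model(**enc).logits  # runs on the same device as the weights
    print(logits.shape)  # (batch, seq_len, vocab_size)
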
@@ -74,9 +76,12 @@ def analyze_next_token(input_text, temperature, top_p, top_k):
 
         prob_text = "\n".join([f"{word}: {prob:.4f}" for word, prob in prob_data.items()])
 
-        attention_heatmap = plot_attention_alternative(inputs["input_ids"][0], last_token_logits)
+        # Simplified attention display
+        attention_text = "Attention non disponible pour ce modèle"
+        if hasattr(outputs, 'attentions') and outputs.attentions is not None:
+            attention_text = "Attention disponible"
 
-        return prob_text, attention_heatmap, prob_plot
+        return prob_text, attention_text, prob_plot
     except Exception as e:
         return f"Erreur lors de l'analyse : {str(e)}", None, None
 
@@ -86,20 +91,20 @@ def generate_text(input_text, temperature, top_p, top_k):
     if model is None or tokenizer is None:
         return "Veuillez d'abord charger un modèle."
 
-    inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
+    inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
 
     try:
         with torch.no_grad():
             outputs = model.generate(
                 **inputs,
-                max_new_tokens=1,
+                max_new_tokens=1,  # generate only the next word
                 temperature=temperature,
                 top_p=top_p,
                 top_k=top_k
             )
 
         generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        return generated_text  # returns the input + the new word
+        return generated_text  # returns the input + the newly generated word
     except Exception as e:
         return f"Erreur lors de la génération : {str(e)}"
 
@@ -116,20 +121,6 @@ def plot_probabilities(prob_data):
     plt.tight_layout()
     return fig
 
-def plot_attention_alternative(input_ids, last_token_logits):
-    input_tokens = tokenizer.convert_ids_to_tokens(input_ids)
-    attention_scores = torch.nn.functional.softmax(last_token_logits, dim=-1)
-    top_k = min(len(input_tokens), 10)  # limit to 10 tokens for readability
-    top_attention_scores, _ = torch.topk(attention_scores, top_k)
-
-    fig, ax = plt.subplots(figsize=(12, 6))
-    sns.heatmap(top_attention_scores.unsqueeze(0).numpy(), annot=True, cmap="YlOrRd", cbar=False, ax=ax)
-    ax.set_xticklabels(input_tokens[-top_k:], rotation=45, ha="right")
-    ax.set_yticklabels(["Attention"], rotation=0)
-    ax.set_title("Scores d'attention pour les derniers tokens")
-    plt.tight_layout()
-    return fig
-
 def reset():
     global model, tokenizer
     model = None
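
The removed helper softmaxed the next-token logits and labeled the result "attention", which is a different quantity. If a real attention heatmap is wanted later, it can be built from `outputs.attentions` instead; a sketch under the `output_attentions=True` assumption from the earlier sketch:

    import matplotlib.pyplot as plt
    import seaborn as sns

    with torch.no_grad():
        outputs = model(**enc, output_attentions=True)

    tokens = tokenizer.convert_ids_to_tokens(enc["input_ids"][0])
    # Last layer, averaged over heads: (seq_len, seq_len). .float() is needed
    # because bfloat16 tensors cannot be converted to NumPy directly.
    attn = outputs.attentions[-1][0].mean(dim=0).float().cpu().numpy()

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(attn, xticklabels=tokens, yticklabels=tokens, cmap="YlOrRd", ax=ax)
    ax.set_title("Average attention, last layer")
    plt.tight_layout()
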
@@ -153,10 +144,9 @@ with gr.Blocks() as demo:
     analyze_button = gr.Button("Analyser le prochain token")
 
     next_token_probs = gr.Textbox(label="Probabilités du prochain token")
+    attention_info = gr.Textbox(label="Information sur l'attention")
 
-    with gr.Row():
-        attention_plot = gr.Plot(label="Visualisation de l'attention")
-        prob_plot = gr.Plot(label="Probabilités des tokens suivants")
+    prob_plot = gr.Plot(label="Probabilités des tokens suivants")
 
     generate_button = gr.Button("Générer le prochain mot")
     generated_text = gr.Textbox(label="Texte généré")
@@ -166,12 +156,12 @@ with gr.Blocks() as demo:
     load_button.click(load_model, inputs=[model_dropdown], outputs=[load_output])
     analyze_button.click(analyze_next_token,
                          inputs=[input_text, temperature, top_p, top_k],
-                         outputs=[next_token_probs, attention_plot, prob_plot])
+                         outputs=[next_token_probs, attention_info, prob_plot])
     generate_button.click(generate_text,
                           inputs=[input_text, temperature, top_p, top_k],
                           outputs=[generated_text])
     reset_button.click(reset,
-                       outputs=[input_text, temperature, top_p, top_k, next_token_probs, attention_plot, prob_plot, generated_text])
+                       outputs=[input_text, temperature, top_p, top_k, next_token_probs, attention_info, prob_plot, generated_text])
 
 if __name__ == "__main__":
     demo.launch()
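
For reference, the wiring pattern used in the last two hunks: components are declared inside `gr.Blocks()`, then each `.click()` binds a callback's inputs and outputs to them. A minimal sketch with illustrative names:

    import gradio as gr

    def shout(text):
        return text.upper()

    with gr.Blocks() as demo:
        inp = gr.Textbox(label="Texte d'entrée")
        btn = gr.Button("Analyser")
        out = gr.Textbox(label="Résultat")
        btn.click(shout, inputs=[inp], outputs=[out])  # fn, inputs, outputs

    if __name__ == "__main__":
        demo.launch()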
 