Woziii commited on
Commit
3226776
1 Parent(s): 0c7cad3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -32
app.py CHANGED
@@ -46,7 +46,7 @@ def load_model(model_name):
46
  except Exception as e:
47
  return f"Erreur lors du chargement du modèle : {str(e)}"
48
 
49
- def generate_text(input_text, temperature, top_p, top_k):
50
  global model, tokenizer
51
 
52
  if model is None or tokenizer is None:
@@ -56,39 +56,48 @@ def generate_text(input_text, temperature, top_p, top_k):
56
 
57
  try:
58
  with torch.no_grad():
59
- outputs = model.generate(
60
- **inputs,
61
- max_new_tokens=50,
62
- temperature=temperature,
63
- top_p=top_p,
64
- top_k=top_k,
65
- output_attentions=True,
66
- return_dict_in_generate=True,
67
- output_scores=True
68
- )
69
 
70
- generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
 
 
 
 
 
 
71
 
72
- if hasattr(outputs, 'scores') and outputs.scores:
73
- last_token_logits = outputs.scores[-1][0]
74
- probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
75
- top_k = 5
76
- top_probs, top_indices = torch.topk(probabilities, top_k)
77
- top_words = [tokenizer.decode([idx.item()]) for idx in top_indices]
78
- prob_data = {word: prob.item() for word, prob in zip(top_words, top_probs)}
79
- prob_plot = plot_probabilities(prob_data)
80
- else:
81
- prob_plot = None
82
-
83
- if hasattr(outputs, 'attentions') and outputs.attentions:
84
  attention_data = torch.mean(torch.stack(outputs.attentions), dim=(0, 1)).cpu().numpy()
85
  attention_plot = plot_attention(attention_data, tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]))
86
  else:
87
  attention_plot = None
88
 
89
- return generated_text, attention_plot, prob_plot
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  except Exception as e:
91
- return f"Erreur lors de la génération : {str(e)}", None, None
92
 
93
  def plot_attention(attention, tokens):
94
  fig, ax = plt.subplots(figsize=(10, 10))
@@ -119,10 +128,10 @@ def reset():
119
  global model, tokenizer
120
  model = None
121
  tokenizer = None
122
- return "", 1.0, 1.0, 50, None, None, None
123
 
124
  with gr.Blocks() as demo:
125
- gr.Markdown("# Générateur de texte avec visualisation d'attention")
126
 
127
  with gr.Accordion("Sélection du modèle"):
128
  model_dropdown = gr.Dropdown(choices=models, label="Choisissez un modèle")
@@ -135,22 +144,28 @@ with gr.Blocks() as demo:
135
  top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
136
 
137
  input_text = gr.Textbox(label="Texte d'entrée", lines=3)
138
- generate_button = gr.Button("Générer")
 
139
 
140
- output_text = gr.Textbox(label="Texte généré", lines=5)
141
 
142
  with gr.Row():
143
  attention_plot = gr.Plot(label="Visualisation de l'attention")
144
  prob_plot = gr.Plot(label="Probabilités des tokens suivants")
145
 
 
 
146
  reset_button = gr.Button("Réinitialiser")
147
 
148
  load_button.click(load_model, inputs=[model_dropdown], outputs=[load_output])
 
 
 
149
  generate_button.click(generate_text,
150
  inputs=[input_text, temperature, top_p, top_k],
151
- outputs=[output_text, attention_plot, prob_plot])
152
  reset_button.click(reset,
153
- outputs=[input_text, temperature, top_p, top_k, output_text, attention_plot, prob_plot])
154
 
155
  if __name__ == "__main__":
156
  demo.launch()
 
46
  except Exception as e:
47
  return f"Erreur lors du chargement du modèle : {str(e)}"
48
 
49
+ def analyze_next_token(input_text, temperature, top_p, top_k):
50
  global model, tokenizer
51
 
52
  if model is None or tokenizer is None:
 
56
 
57
  try:
58
  with torch.no_grad():
59
+ outputs = model(**inputs)
 
 
 
 
 
 
 
 
 
60
 
61
+ last_token_logits = outputs.logits[0, -1, :]
62
+ probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
63
+ top_k = 5
64
+ top_probs, top_indices = torch.topk(probabilities, top_k)
65
+ top_words = [tokenizer.decode([idx.item()]) for idx in top_indices]
66
+ prob_data = {word: prob.item() for word, prob in zip(top_words, top_probs)}
67
+ prob_plot = plot_probabilities(prob_data)
68
 
69
+ if hasattr(outputs, 'attentions') and outputs.attentions is not None:
 
 
 
 
 
 
 
 
 
 
 
70
  attention_data = torch.mean(torch.stack(outputs.attentions), dim=(0, 1)).cpu().numpy()
71
  attention_plot = plot_attention(attention_data, tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]))
72
  else:
73
  attention_plot = None
74
 
75
+ return "\n".join([f"{word}: {prob:.4f}" for word, prob in prob_data.items()]), attention_plot, prob_plot
76
+ except Exception as e:
77
+ return f"Erreur lors de l'analyse : {str(e)}", None, None
78
+
79
+ def generate_text(input_text, temperature, top_p, top_k):
80
+ global model, tokenizer
81
+
82
+ if model is None or tokenizer is None:
83
+ return "Veuillez d'abord charger un modèle."
84
+
85
+ inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
86
+
87
+ try:
88
+ with torch.no_grad():
89
+ outputs = model.generate(
90
+ **inputs,
91
+ max_new_tokens=50,
92
+ temperature=temperature,
93
+ top_p=top_p,
94
+ top_k=top_k
95
+ )
96
+
97
+ generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
98
+ return generated_text
99
  except Exception as e:
100
+ return f"Erreur lors de la génération : {str(e)}"
101
 
102
  def plot_attention(attention, tokens):
103
  fig, ax = plt.subplots(figsize=(10, 10))
 
128
  global model, tokenizer
129
  model = None
130
  tokenizer = None
131
+ return "", 1.0, 1.0, 50, None, None, None, None
132
 
133
  with gr.Blocks() as demo:
134
+ gr.Markdown("# Analyse et génération de texte")
135
 
136
  with gr.Accordion("Sélection du modèle"):
137
  model_dropdown = gr.Dropdown(choices=models, label="Choisissez un modèle")
 
144
  top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
145
 
146
  input_text = gr.Textbox(label="Texte d'entrée", lines=3)
147
+ analyze_button = gr.Button("Analyser le prochain token")
148
+ generate_button = gr.Button("Générer la suite du texte")
149
 
150
+ next_token_probs = gr.Textbox(label="Probabilités du prochain token")
151
 
152
  with gr.Row():
153
  attention_plot = gr.Plot(label="Visualisation de l'attention")
154
  prob_plot = gr.Plot(label="Probabilités des tokens suivants")
155
 
156
+ generated_text = gr.Textbox(label="Texte généré", lines=5)
157
+
158
  reset_button = gr.Button("Réinitialiser")
159
 
160
  load_button.click(load_model, inputs=[model_dropdown], outputs=[load_output])
161
+ analyze_button.click(analyze_next_token,
162
+ inputs=[input_text, temperature, top_p, top_k],
163
+ outputs=[next_token_probs, attention_plot, prob_plot])
164
  generate_button.click(generate_text,
165
  inputs=[input_text, temperature, top_p, top_k],
166
+ outputs=[generated_text])
167
  reset_button.click(reset,
168
+ outputs=[input_text, temperature, top_p, top_k, next_token_probs, attention_plot, prob_plot, generated_text])
169
 
170
  if __name__ == "__main__":
171
  demo.launch()