Woziii committed on
Commit
0c7cad3
1 Parent(s): bdd35f2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -21
app.py CHANGED
@@ -34,7 +34,12 @@ def load_model(model_name):
34
  global model, tokenizer
35
  try:
36
  tokenizer = AutoTokenizer.from_pretrained(model_name)
37
- model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto", attn_implementation="eager")
 
 
 
 
 
38
  if tokenizer.pad_token is None:
39
  tokenizer.pad_token = tokenizer.eos_token
40
  return f"Modèle {model_name} chargé avec succès."
@@ -64,33 +69,23 @@ def generate_text(input_text, temperature, top_p, top_k):
64
 
65
  generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
66
 
67
- # Obtenir les logits pour le dernier token généré
68
- if outputs.scores:
69
  last_token_logits = outputs.scores[-1][0]
70
-
71
- # Appliquer softmax pour obtenir les probabilités
72
  probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
73
-
74
- # Obtenir les top 5 tokens les plus probables
75
  top_k = 5
76
  top_probs, top_indices = torch.topk(probabilities, top_k)
77
  top_words = [tokenizer.decode([idx.item()]) for idx in top_indices]
78
-
79
- # Préparer les données pour le graphique des probabilités
80
  prob_data = {word: prob.item() for word, prob in zip(top_words, top_probs)}
81
-
82
- # Extraire les attentions (moyenne sur toutes les couches et têtes d'attention)
83
- if outputs.attentions:
84
- attentions = torch.mean(torch.stack(outputs.attentions), dim=(0, 1)).cpu().numpy()
85
- attention_plot = plot_attention(attentions, tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]))
86
- else:
87
- attention_plot = None
88
-
89
  prob_plot = plot_probabilities(prob_data)
90
  else:
91
- attention_plot = None
92
  prob_plot = None
93
 
 
 
 
 
 
 
94
  return generated_text, attention_plot, prob_plot
95
  except Exception as e:
96
  return f"Erreur lors de la génération : {str(e)}", None, None
@@ -139,10 +134,10 @@ with gr.Blocks() as demo:
139
  top_p = gr.Slider(0.1, 1.0, value=1.0, label="Top-p")
140
  top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
141
 
142
- input_text = gr.Textbox(label="Texte d'entrée")
143
  generate_button = gr.Button("Générer")
144
 
145
- output_text = gr.Textbox(label="Texte généré")
146
 
147
  with gr.Row():
148
  attention_plot = gr.Plot(label="Visualisation de l'attention")
@@ -157,4 +152,5 @@ with gr.Blocks() as demo:
157
  reset_button.click(reset,
158
  outputs=[input_text, temperature, top_p, top_k, output_text, attention_plot, prob_plot])
159
 
160
- demo.launch()
 
 
34
  global model, tokenizer
35
  try:
36
  tokenizer = AutoTokenizer.from_pretrained(model_name)
37
+ model = AutoModelForCausalLM.from_pretrained(
38
+ model_name,
39
+ torch_dtype=torch.bfloat16,
40
+ device_map="auto",
41
+ attn_implementation="eager"
42
+ )
43
  if tokenizer.pad_token is None:
44
  tokenizer.pad_token = tokenizer.eos_token
45
  return f"Modèle {model_name} chargé avec succès."
 
69
 
70
  generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
71
 
72
+ if hasattr(outputs, 'scores') and outputs.scores:
 
73
  last_token_logits = outputs.scores[-1][0]
 
 
74
  probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
 
 
75
  top_k = 5
76
  top_probs, top_indices = torch.topk(probabilities, top_k)
77
  top_words = [tokenizer.decode([idx.item()]) for idx in top_indices]
 
 
78
  prob_data = {word: prob.item() for word, prob in zip(top_words, top_probs)}
 
 
 
 
 
 
 
 
79
  prob_plot = plot_probabilities(prob_data)
80
  else:
 
81
  prob_plot = None
82
 
83
+ if hasattr(outputs, 'attentions') and outputs.attentions:
84
+ attention_data = torch.mean(torch.stack(outputs.attentions), dim=(0, 1)).cpu().numpy()
85
+ attention_plot = plot_attention(attention_data, tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]))
86
+ else:
87
+ attention_plot = None
88
+
89
  return generated_text, attention_plot, prob_plot
90
  except Exception as e:
91
  return f"Erreur lors de la génération : {str(e)}", None, None
 
134
  top_p = gr.Slider(0.1, 1.0, value=1.0, label="Top-p")
135
  top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
136
 
137
+ input_text = gr.Textbox(label="Texte d'entrée", lines=3)
138
  generate_button = gr.Button("Générer")
139
 
140
+ output_text = gr.Textbox(label="Texte généré", lines=5)
141
 
142
  with gr.Row():
143
  attention_plot = gr.Plot(label="Visualisation de l'attention")
 
152
  reset_button.click(reset,
153
  outputs=[input_text, temperature, top_p, top_k, output_text, attention_plot, prob_plot])
154
 
155
+ if __name__ == "__main__":
156
+ demo.launch()