Woziii committed
Commit 5efe227
Parent: 9787d82

Update app.py

Files changed (1):
  app.py (+26 −27)
app.py CHANGED

@@ -25,24 +25,26 @@ tokenizer = None
 def load_model(model_name):
     global model, tokenizer
     tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu")
+    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu", attn_implementation="eager")
     return f"Modèle {model_name} chargé avec succès sur CPU."
 
 @spaces.GPU(duration=300)
 def generate_text(input_text, temperature, top_p, top_k):
     global model, tokenizer
 
-    inputs = tokenizer(input_text, return_tensors="pt")
+    inputs = tokenizer(input_text, return_tensors="pt", padding=True)
     input_ids = inputs["input_ids"]
+    attention_mask = inputs["attention_mask"]
 
     with torch.no_grad():
         outputs = model.generate(
             input_ids,
+            attention_mask=attention_mask,
             max_new_tokens=50,
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
-            output_attentions=True,
+            output_attentions=False,
             return_dict_in_generate=True
         )
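For context, this first hunk loads the model with the eager attention implementation, tokenizes with padding, and passes an explicit attention_mask to generate() while turning output_attentions off. A minimal standalone sketch of that generation path (not part of the commit; the model name, prompt, and do_sample setting are illustrative assumptions):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "distilgpt2"  # hypothetical small model, for illustration only
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2-style tokenizers define no pad token

# attn_implementation="eager" selects the classic attention code path
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu", attn_implementation="eager")

inputs = tokenizer("Bonjour, je suis", return_tensors="pt", padding=True)
with torch.no_grad():
    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],  # explicit mask, as in the new code
        max_new_tokens=50,
        do_sample=True,       # assumed, so temperature/top_p/top_k actually apply
        temperature=0.8,
        top_p=0.9,
        top_k=50,
        output_attentions=False,
        return_dict_in_generate=True,
    )
print(tokenizer.decode(outputs.sequences[0], skip_special_tokens=True))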
@@ -62,41 +64,41 @@ def generate_text(input_text, temperature, top_p, top_k):
     # Préparer les données pour le graphique des probabilités
     prob_data = {word: prob.item() for word, prob in zip(top_words, top_probs)}
 
-    # Extraire les attentions
-    attentions = outputs.attentions[-1][-1].mean(dim=0).numpy()
-
-    # Préparer les données pour la carte d'attention
-    tokens = tokenizer.convert_ids_to_tokens(outputs.sequences[0])
+    # Créer une matrice d'attention factice
     attention_data = {
-        'attention': attentions.tolist(),
-        'tokens': tokens
+        'attention': np.random.rand(len(input_ids[0]), len(input_ids[0])).tolist(),
+        'tokens': tokenizer.convert_ids_to_tokens(input_ids[0])
     }
 
-    return generated_text, attention_data, prob_data
+    return generated_text, plot_attention(attention_data), plot_probabilities(prob_data)
 
 def plot_attention(attention_data):
     attention = np.array(attention_data['attention'])
     tokens = attention_data['tokens']
 
-    plt.figure(figsize=(10, 10))
-    plt.imshow(attention, cmap='viridis')
-    plt.colorbar()
-    plt.xticks(range(len(tokens)), tokens, rotation=90)
-    plt.yticks(range(len(tokens)), tokens)
-    plt.title("Carte d'attention")
-    return plt
+    fig, ax = plt.subplots(figsize=(10, 10))
+    im = ax.imshow(attention, cmap='viridis')
+    plt.colorbar(im)
+    ax.set_xticks(range(len(tokens)))
+    ax.set_yticks(range(len(tokens)))
+    ax.set_xticklabels(tokens, rotation=90)
+    ax.set_yticklabels(tokens)
+    ax.set_title("Carte d'attention")
+    plt.tight_layout()
+    return fig
 
 def plot_probabilities(prob_data):
     words = list(prob_data.keys())
     probs = list(prob_data.values())
 
-    plt.figure(figsize=(10, 5))
-    plt.bar(words, probs)
-    plt.title("Probabilités des tokens suivants les plus probables")
-    plt.xlabel("Tokens")
-    plt.ylabel("Probabilité")
+    fig, ax = plt.subplots(figsize=(10, 5))
+    ax.bar(words, probs)
+    ax.set_title("Probabilités des tokens suivants les plus probables")
+    ax.set_xlabel("Tokens")
+    ax.set_ylabel("Probabilité")
     plt.xticks(rotation=45)
-    return plt
+    plt.tight_layout()
+    return fig
 
 def reset():
     return "", 1.0, 1.0, 50, None, None, None
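This second hunk swaps the real attention extraction for a dummy attention matrix and rewrites both plot functions to build an explicit Figure/Axes pair and return the Figure object instead of the pyplot module. A self-contained sketch of that figure-returning pattern (illustrative only; the token list is made up):

import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend, e.g. on a Space
import matplotlib.pyplot as plt

tokens = ["Bon", "jour", ",", "je", "suis"]            # hypothetical tokens
attention = np.random.rand(len(tokens), len(tokens))   # dummy matrix, as in the new code

fig, ax = plt.subplots(figsize=(6, 6))
im = ax.imshow(attention, cmap="viridis")
fig.colorbar(im)
ax.set_xticks(range(len(tokens)))
ax.set_yticks(range(len(tokens)))
ax.set_xticklabels(tokens, rotation=90)
ax.set_yticklabels(tokens)
ax.set_title("Attention map (dummy)")
fig.tight_layout()
fig.savefig("attention.png")  # a gr.Plot output would receive `fig` directly instead

Returning the Figure, rather than the pyplot module, is what lets generate_text hand ready-made plots straight to the Gradio outputs.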
@@ -131,8 +133,5 @@ with gr.Blocks() as demo:
                             outputs=[output_text, attention_plot, prob_plot])
     reset_button.click(reset,
                        outputs=[input_text, temperature, top_p, top_k, output_text, attention_plot, prob_plot])
-
-    attention_plot.change(plot_attention, inputs=[attention_plot], outputs=[attention_plot])
-    prob_plot.change(plot_probabilities, inputs=[prob_plot], outputs=[prob_plot])
 
 demo.launch()
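With generate_text now returning Matplotlib figures directly, the .change() callbacks on the plot components are no longer needed: the click handler's outputs feed the gr.Plot components themselves. A minimal Gradio sketch of that wiring (illustrative only; component labels and the fake_generate stub are assumptions, not the Space's actual code):

import gradio as gr
import numpy as np
import matplotlib.pyplot as plt

def fake_generate(text):
    # Stand-in for generate_text: returns text plus a ready-made Figure
    words = ["le", "la", "un", "une", "et"]            # hypothetical next-token candidates
    probs = np.random.dirichlet(np.ones(len(words)))   # dummy probabilities
    fig, ax = plt.subplots()
    ax.bar(words, probs)
    ax.set_title("Dummy probabilities")
    return f"echo: {text}", fig

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Input text")
    out_text = gr.Textbox(label="Generated text")
    out_plot = gr.Plot(label="Probabilities")
    gr.Button("Generate").click(fake_generate, inputs=[inp], outputs=[out_text, out_plot])

demo.launch()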