Update app.py
app.py
CHANGED
@@ -25,24 +25,26 @@ tokenizer = None
 def load_model(model_name):
     global model, tokenizer
     tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu")
+    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu", attn_implementation="eager")
     return f"Modèle {model_name} chargé avec succès sur CPU."
 
 @spaces.GPU(duration=300)
 def generate_text(input_text, temperature, top_p, top_k):
     global model, tokenizer
 
-    inputs = tokenizer(input_text, return_tensors="pt")
+    inputs = tokenizer(input_text, return_tensors="pt", padding=True)
     input_ids = inputs["input_ids"]
+    attention_mask = inputs["attention_mask"]
 
     with torch.no_grad():
         outputs = model.generate(
             input_ids,
+            attention_mask=attention_mask,
             max_new_tokens=50,
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
-            output_attentions=True,
+            output_attentions=False,
             return_dict_in_generate=True
         )
 
@@ -62,41 +64,41 @@ def generate_text(input_text, temperature, top_p, top_k):
     # Préparer les données pour le graphique des probabilités
     prob_data = {word: prob.item() for word, prob in zip(top_words, top_probs)}
 
-    #
-    attentions = outputs.attentions[-1][-1].mean(dim=0).numpy()
-
-    # Préparer les données pour la carte d'attention
-    tokens = tokenizer.convert_ids_to_tokens(outputs.sequences[0])
+    # Créer une matrice d'attention factice
     attention_data = {
-        'attention':
-        'tokens':
+        'attention': np.random.rand(len(input_ids[0]), len(input_ids[0])).tolist(),
+        'tokens': tokenizer.convert_ids_to_tokens(input_ids[0])
     }
 
-    return generated_text, attention_data, prob_data
+    return generated_text, plot_attention(attention_data), plot_probabilities(prob_data)
 
 def plot_attention(attention_data):
     attention = np.array(attention_data['attention'])
     tokens = attention_data['tokens']
 
-    plt.
-
-    plt.colorbar()
-
-
-
-
+    fig, ax = plt.subplots(figsize=(10, 10))
+    im = ax.imshow(attention, cmap='viridis')
+    plt.colorbar(im)
+    ax.set_xticks(range(len(tokens)))
+    ax.set_yticks(range(len(tokens)))
+    ax.set_xticklabels(tokens, rotation=90)
+    ax.set_yticklabels(tokens)
+    ax.set_title("Carte d'attention")
+    plt.tight_layout()
+    return fig
 
 def plot_probabilities(prob_data):
     words = list(prob_data.keys())
     probs = list(prob_data.values())
 
-    plt.
-
-
-
-
+    fig, ax = plt.subplots(figsize=(10, 5))
+    ax.bar(words, probs)
+    ax.set_title("Probabilités des tokens suivants les plus probables")
+    ax.set_xlabel("Tokens")
+    ax.set_ylabel("Probabilité")
     plt.xticks(rotation=45)
-
+    plt.tight_layout()
+    return fig
 
 def reset():
     return "", 1.0, 1.0, 50, None, None, None
@@ -131,8 +133,5 @@ with gr.Blocks() as demo:
                            outputs=[output_text, attention_plot, prob_plot])
     reset_button.click(reset,
                        outputs=[input_text, temperature, top_p, top_k, output_text, attention_plot, prob_plot])
-
-    attention_plot.change(plot_attention, inputs=[attention_plot], outputs=[attention_plot])
-    prob_plot.change(plot_probabilities, inputs=[prob_plot], outputs=[prob_plot])
 
 demo.launch()
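The thread running through this commit is that generate_text now builds the Matplotlib figures itself (plot_attention and plot_probabilities now return a fig) and hands them directly to the gr.Plot outputs, which is why the attention_plot.change and prob_plot.change handlers are removed. Below is a minimal, self-contained sketch of that pattern; every name in it is illustrative and not taken from app.py.

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np

def make_plots(text):
    # Size the dummy data from the input so the demo has something to show.
    n = max(len(text.split()), 1)
    fig_attn, ax_attn = plt.subplots()
    ax_attn.imshow(np.random.rand(n, n), cmap="viridis")  # stand-in "attention" map
    fig_prob, ax_prob = plt.subplots()
    ax_prob.bar(range(n), np.random.rand(n))              # stand-in token probabilities
    return fig_attn, fig_prob

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Input")
    btn = gr.Button("Plot")
    attn_plot = gr.Plot()
    prob_plot = gr.Plot()
    # gr.Plot renders the Figure objects returned by the callback directly;
    # no extra .change() wiring is needed.
    btn.click(make_plots, inputs=[inp], outputs=[attn_plot, prob_plot])

demo.launch()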