Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -32,10 +32,11 @@ def generate_text(input_text, temperature, top_p, top_k):
|
|
32 |
global model, tokenizer
|
33 |
|
34 |
inputs = tokenizer(input_text, return_tensors="pt")
|
|
|
35 |
|
36 |
with torch.no_grad():
|
37 |
outputs = model.generate(
|
38 |
-
|
39 |
max_new_tokens=50,
|
40 |
temperature=temperature,
|
41 |
top_p=top_p,
|
@@ -46,9 +47,11 @@ def generate_text(input_text, temperature, top_p, top_k):
|
|
46 |
|
47 |
generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
|
48 |
|
49 |
-
#
|
50 |
-
|
51 |
-
|
|
|
|
|
52 |
|
53 |
# Visualiser l'attention
|
54 |
plt.figure(figsize=(10, 10))
|
@@ -58,12 +61,13 @@ def generate_text(input_text, temperature, top_p, top_k):
|
|
58 |
plt.close()
|
59 |
|
60 |
# Obtenir les mots les plus probables
|
61 |
-
probs = torch.nn.functional.softmax(
|
62 |
-
top_probs, top_indices = torch.topk(probs, k=5)
|
63 |
-
top_words = [tokenizer.decode([idx]) for idx in top_indices]
|
64 |
|
65 |
return generated_text, attention_plot, top_words
|
66 |
|
|
|
67 |
def reset():
|
68 |
return "", 1.0, 1.0, 50, None, None, None
|
69 |
|
|
|
32 |
global model, tokenizer
|
33 |
|
34 |
inputs = tokenizer(input_text, return_tensors="pt")
|
35 |
+
input_ids = inputs["input_ids"]
|
36 |
|
37 |
with torch.no_grad():
|
38 |
outputs = model.generate(
|
39 |
+
input_ids,
|
40 |
max_new_tokens=50,
|
41 |
temperature=temperature,
|
42 |
top_p=top_p,
|
|
|
47 |
|
48 |
generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
|
49 |
|
50 |
+
# Obtenir les logits pour le dernier token généré
|
51 |
+
last_token_logits = model(outputs.sequences[:, -1:]).logits[:, -1, :]
|
52 |
+
|
53 |
+
# Extraire les attentions
|
54 |
+
attentions = outputs.attentions[-1][-1].mean(dim=0).numpy()
|
55 |
|
56 |
# Visualiser l'attention
|
57 |
plt.figure(figsize=(10, 10))
|
|
|
61 |
plt.close()
|
62 |
|
63 |
# Obtenir les mots les plus probables
|
64 |
+
probs = torch.nn.functional.softmax(last_token_logits, dim=-1)
|
65 |
+
top_probs, top_indices = torch.topk(probs[0], k=5)
|
66 |
+
top_words = [tokenizer.decode([idx.item()]) for idx in top_indices]
|
67 |
|
68 |
return generated_text, attention_plot, top_words
|
69 |
|
70 |
+
|
71 |
def reset():
|
72 |
return "", 1.0, 1.0, 50, None, None, None
|
73 |
|