Update app.py
app.py CHANGED
@@ -42,8 +42,8 @@ def load_model(model_name, progress=gr.Progress()):
         elif i == 75:
             model = AutoModelForCausalLM.from_pretrained(
                 model_name,
-                torch_dtype=torch.
-                device_map="
+                torch_dtype=torch.float32,
+                device_map="cpu",
                 attn_implementation="eager"
             )
     if tokenizer.pad_token is None:
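The hunk above pins the model to the CPU in full precision with the eager attention implementation, the safest combination for a Space running without a GPU. Below is a minimal sketch of that load path; "gpt2" is only a placeholder checkpoint, and the pad-token fallback is the usual eos assignment, assumed here because the hunk cuts off before the body of that if.

# Minimal CPU-only load sketch mirroring the settings in the hunk above.
# "gpt2" is a placeholder checkpoint, not necessarily the one the Space loads.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float32,    # full precision; half precision is slow or unsupported on many CPUs
    device_map="cpu",             # keep every weight on the CPU
    attn_implementation="eager",  # plain attention, no SDPA/FlashAttention kernels
)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # assumption: common fallback for GPT-style tokenizers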
@@ -58,7 +58,7 @@ def analyze_next_token(input_text, temperature, top_p, top_k):
     if model is None or tokenizer is None:
         return "Veuillez d'abord charger un modèle.", None, None
 
-    inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
+    inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
 
     try:
         with torch.no_grad():
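The context of this hunk shows that analyze_next_token tokenizes with truncation at 512 tokens and then runs the model under torch.no_grad(). The following sketch continues from the load sketch above (it reuses model and tokenizer) and ends with last_token_logits, the quantity the rest of the function works from; the example prompt is not from the app.

# Sketch of the tokenize-then-forward step implied by this hunk.
input_text = "Le chat est assis sur le"  # example prompt, not from the app

inputs = tokenizer(input_text, return_tensors="pt",
                   padding=True, truncation=True, max_length=512)

with torch.no_grad():                          # inference only, no gradient buffers
    outputs = model(**inputs)

last_token_logits = outputs.logits[0, -1, :]   # scores for the next token
print(last_token_logits.shape)                 # (vocab_size,)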
@@ -74,7 +74,6 @@ def analyze_next_token(input_text, temperature, top_p, top_k):
 
     prob_text = "\n".join([f"{word}: {prob:.4f}" for word, prob in prob_data.items()])
 
-    # Alternative pour le mécanisme d'attention
     attention_heatmap = plot_attention_alternative(inputs["input_ids"][0], last_token_logits)
 
     return prob_text, attention_heatmap, prob_plot
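prob_data is not defined inside this hunk, but the prob_text join above implies a word-to-probability mapping for the most likely next tokens. Here is one way to build it from last_token_logits; the choice of k=10 and the softmax-then-topk order are assumptions, not something the diff shows.

# Assumed construction of prob_data: top candidates for the next token with
# their softmax probabilities, then the same "\n".join formatting as the hunk.
probs = torch.softmax(last_token_logits, dim=-1)
top_probs, top_ids = torch.topk(probs, k=10)   # k=10 is an assumption

prob_data = {
    tokenizer.decode([token_id]): prob.item()
    for token_id, prob in zip(top_ids.tolist(), top_probs)
}
prob_text = "\n".join(f"{word}: {prob:.4f}" for word, prob in prob_data.items())
print(prob_text)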
@@ -87,7 +86,7 @@ def generate_text(input_text, temperature, top_p, top_k):
     if model is None or tokenizer is None:
         return "Veuillez d'abord charger un modèle."
 
-    inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
+    inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
 
     try:
         with torch.no_grad():
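generate_text passes the same sampling knobs (temperature, top_p, top_k) through to generation. A minimal sketch of what that call typically looks like with transformers, reusing the inputs built in the earlier sketch; the slider values, max_new_tokens, and pad_token_id handling are assumptions, since the hunk stops at the torch.no_grad() line.

# Assumed body of the generation step: sample a continuation with the
# user-supplied temperature / top_p / top_k.
temperature, top_p, top_k = 0.7, 0.9, 50   # example values standing in for the UI sliders

with torch.no_grad():
    output_ids = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=50,          # assumption: the diff does not show the real limit
        do_sample=True,             # sampling is required for temperature/top_p/top_k to matter
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        pad_token_id=tokenizer.pad_token_id,
    )

generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(generated_text)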
@@ -124,7 +123,7 @@ def plot_attention_alternative(input_ids, last_token_logits):
     top_attention_scores, _ = torch.topk(attention_scores, top_k)
 
     fig, ax = plt.subplots(figsize=(12, 6))
-    sns.heatmap(top_attention_scores.unsqueeze(0), annot=True, cmap="YlOrRd", cbar=False, ax=ax)
+    sns.heatmap(top_attention_scores.unsqueeze(0).numpy(), annot=True, cmap="YlOrRd", cbar=False, ax=ax)
     ax.set_xticklabels(input_tokens[-top_k:], rotation=45, ha="right")
     ax.set_yticklabels(["Attention"], rotation=0)
     ax.set_title("Scores d'attention pour les derniers tokens")
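The only change in this hunk is the .numpy() call: seaborn's heatmap expects array-like data it can hand to NumPy/pandas, and passing a torch.Tensor directly can fail depending on the seaborn and pandas versions, so converting explicitly is the robust choice. A small self-contained sketch of the fixed call follows, with dummy scores and labels standing in for top_attention_scores and input_tokens.

# Standalone sketch of the fixed heatmap call; the scores and token labels
# here are placeholders, not values produced by the app.
import matplotlib
matplotlib.use("Agg")              # headless backend, as in a Space
import matplotlib.pyplot as plt
import seaborn as sns
import torch

top_k = 5
top_attention_scores = torch.rand(top_k)             # placeholder scores
input_tokens = [f"tok{i}" for i in range(top_k)]      # placeholder labels

fig, ax = plt.subplots(figsize=(12, 6))
sns.heatmap(top_attention_scores.unsqueeze(0).numpy(),  # tensor -> ndarray before plotting
            annot=True, cmap="YlOrRd", cbar=False, ax=ax)
ax.set_xticklabels(input_tokens[-top_k:], rotation=45, ha="right")
ax.set_yticklabels(["Attention"], rotation=0)
ax.set_title("Scores d'attention pour les derniers tokens")
fig.savefig("attention.png")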