Woziii committed
Commit 391d3d3 · verified · 1 Parent(s): 3c28324

Update app.py

Files changed (1)
  1. app.py +5 -6
app.py CHANGED
@@ -42,8 +42,8 @@ def load_model(model_name, progress=gr.Progress()):
         elif i == 75:
             model = AutoModelForCausalLM.from_pretrained(
                 model_name,
-                torch_dtype=torch.bfloat16,
-                device_map="auto",
+                torch_dtype=torch.float32,
+                device_map="cpu",
                 attn_implementation="eager"
             )
     if tokenizer.pad_token is None:
@@ -58,7 +58,7 @@ def analyze_next_token(input_text, temperature, top_p, top_k):
     if model is None or tokenizer is None:
         return "Veuillez d'abord charger un modèle.", None, None

-    inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
+    inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512)

     try:
         with torch.no_grad():
@@ -74,7 +74,6 @@ def analyze_next_token(input_text, temperature, top_p, top_k):

     prob_text = "\n".join([f"{word}: {prob:.4f}" for word, prob in prob_data.items()])

-    # Alternative pour le mécanisme d'attention
     attention_heatmap = plot_attention_alternative(inputs["input_ids"][0], last_token_logits)

     return prob_text, attention_heatmap, prob_plot
@@ -87,7 +86,7 @@ def generate_text(input_text, temperature, top_p, top_k):
     if model is None or tokenizer is None:
         return "Veuillez d'abord charger un modèle."

-    inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
+    inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512)

     try:
         with torch.no_grad():
@@ -124,7 +123,7 @@ def plot_attention_alternative(input_ids, last_token_logits):
     top_attention_scores, _ = torch.topk(attention_scores, top_k)

     fig, ax = plt.subplots(figsize=(12, 6))
-    sns.heatmap(top_attention_scores.unsqueeze(0), annot=True, cmap="YlOrRd", cbar=False, ax=ax)
+    sns.heatmap(top_attention_scores.unsqueeze(0).numpy(), annot=True, cmap="YlOrRd", cbar=False, ax=ax)
     ax.set_xticklabels(input_tokens[-top_k:], rotation=45, ha="right")
     ax.set_yticklabels(["Attention"], rotation=0)
     ax.set_title("Scores d'attention pour les derniers tokens")
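For context on the first hunk, here is a minimal, standalone sketch of the same loading pattern. It is not app.py itself: the model name is a placeholder and device_map assumes the accelerate package is installed. The point is that with every module pinned to the CPU in full precision, tokenized inputs can stay on the CPU, which is why the .to(model.device) calls disappear in the later hunks.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # placeholder; app.py selects its own model_name
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float32,    # full precision instead of bfloat16
    device_map="cpu",             # keep every module on the CPU
    attn_implementation="eager",  # eager attention kernels, as in the diff
)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Inputs no longer need .to(model.device): the model already lives on the CPU.
inputs = tokenizer("Bonjour", return_tensors="pt", padding=True,
                   truncation=True, max_length=512)
with torch.no_grad():
    outputs = model(**inputs)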
 
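The last hunk converts the attention scores to a NumPy array before plotting. A minimal sketch of that pattern with made-up data (the variable names below are placeholders, not taken from app.py): sns.heatmap wants 2-D array-like input, so the row vector is built with unsqueeze(0) and handed over as a plain NumPy array via .numpy() rather than relying on seaborn to coerce a torch tensor.

import torch
import seaborn as sns
import matplotlib.pyplot as plt

# Stand-in for the attention_scores computed in plot_attention_alternative.
scores = torch.softmax(torch.randn(8), dim=-1)
top_scores, _ = torch.topk(scores, k=5)

fig, ax = plt.subplots(figsize=(12, 6))
# unsqueeze(0) -> shape (1, 5); .numpy() -> a plain array seaborn accepts directly.
sns.heatmap(top_scores.unsqueeze(0).numpy(), annot=True, cmap="YlOrRd", cbar=False, ax=ax)
ax.set_yticklabels(["Attention"], rotation=0)
plt.close(fig)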