Woziii committed on
Commit bdd35f2 · 1 Parent(s): 6c84a6a

Update app.py

Files changed (1): app.py +59 -36
app.py CHANGED
@@ -32,48 +32,68 @@ tokenizer = None
 
 def load_model(model_name):
     global model, tokenizer
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
-    if tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.eos_token
-    return f"Modèle {model_name} chargé avec succès."
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto", attn_implementation="eager")
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+        return f"Modèle {model_name} chargé avec succès."
+    except Exception as e:
+        return f"Erreur lors du chargement du modèle : {str(e)}"
 
 def generate_text(input_text, temperature, top_p, top_k):
     global model, tokenizer
 
+    if model is None or tokenizer is None:
+        return "Veuillez d'abord charger un modèle.", None, None
+
     inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
 
-    with torch.no_grad():
-        outputs = model.generate(
-            **inputs,
-            max_new_tokens=50,
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            output_attentions=True,
-            return_dict_in_generate=True
-        )
-
-    generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
-
-    # Get the logits for the last generated token
-    last_token_logits = outputs.scores[-1][0]
-
-    # Apply softmax to obtain the probabilities
-    probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
-
-    # Get the top 5 most probable tokens
-    top_k = 5
-    top_probs, top_indices = torch.topk(probabilities, top_k)
-    top_words = [tokenizer.decode([idx.item()]) for idx in top_indices]
-
-    # Prepare the data for the probability plot
-    prob_data = {word: prob.item() for word, prob in zip(top_words, top_probs)}
-
-    # Extract the attentions (averaged over all layers and attention heads)
-    attentions = torch.mean(torch.stack(outputs.attentions), dim=(0, 1)).cpu().numpy()
-
-    return generated_text, plot_attention(attentions, tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])), plot_probabilities(prob_data)
+    try:
+        with torch.no_grad():
+            outputs = model.generate(
+                **inputs,
+                max_new_tokens=50,
+                temperature=temperature,
+                top_p=top_p,
+                top_k=top_k,
+                output_attentions=True,
+                return_dict_in_generate=True,
+                output_scores=True
+            )
+
+        generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
+
+        # Get the logits for the last generated token
+        if outputs.scores:
+            last_token_logits = outputs.scores[-1][0]
+
+            # Apply softmax to obtain the probabilities
+            probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
+
+            # Get the top 5 most probable tokens
+            top_k = 5
+            top_probs, top_indices = torch.topk(probabilities, top_k)
+            top_words = [tokenizer.decode([idx.item()]) for idx in top_indices]
+
+            # Prepare the data for the probability plot
+            prob_data = {word: prob.item() for word, prob in zip(top_words, top_probs)}
+
+            # Extract the attentions (averaged over all layers and attention heads)
+            if outputs.attentions:
+                attentions = torch.mean(torch.stack(outputs.attentions), dim=(0, 1)).cpu().numpy()
+                attention_plot = plot_attention(attentions, tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]))
+            else:
+                attention_plot = None
+
+            prob_plot = plot_probabilities(prob_data)
+        else:
+            attention_plot = None
+            prob_plot = None
+
+        return generated_text, attention_plot, prob_plot
+    except Exception as e:
+        return f"Erreur lors de la génération : {str(e)}", None, None
 
 def plot_attention(attention, tokens):
     fig, ax = plt.subplots(figsize=(10, 10))
@@ -101,6 +121,9 @@ def plot_probabilities(prob_data):
     return fig
 
 def reset():
+    global model, tokenizer
+    model = None
+    tokenizer = None
    return "", 1.0, 1.0, 50, None, None, None
 
 with gr.Blocks() as demo:
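The two additions that make the visualizations work are output_scores=True and attn_implementation="eager": generate() only populates outputs.scores when scores are explicitly requested (the previous code read outputs.scores without asking for them), and the SDPA attention used by default in recent transformers releases does not return attention weights, so output_attentions=True needs the eager implementation. Below is a minimal sketch of this call pattern, not the app itself; the sshleifer/tiny-gpt2 checkpoint is an arbitrary small stand-in chosen only so the snippet runs quickly.

    # Minimal sketch of the generate() pattern adopted in this commit.
    # Assumption: "sshleifer/tiny-gpt2" is just a small test checkpoint,
    # not the model this Space loads.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    name = "sshleifer/tiny-gpt2"
    tokenizer = AutoTokenizer.from_pretrained(name)
    # Eager attention is required for output_attentions=True; the default
    # SDPA implementation does not return attention weights.
    model = AutoModelForCausalLM.from_pretrained(name, attn_implementation="eager")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    inputs = tokenizer("Hello", return_tensors="pt")
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=5,
            return_dict_in_generate=True,
            output_scores=True,      # without this flag, outputs.scores is None
            output_attentions=True,
        )

    # scores: one (batch, vocab_size) tensor per generated token
    print(len(outputs.scores), outputs.scores[-1].shape)
    # attentions: one tuple per generated token, each holding per-layer tensors
    print(len(outputs.attentions), len(outputs.attentions[0]))

Note that outputs.attentions is a tuple over generation steps whose elements are per-layer tuples, and the per-step tensors have different key lengths once the cache grows, so the torch.stack(outputs.attentions) call kept in the committed code would still need the steps flattened (or a single forward pass used) before averaging.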