Woziii commited on
Commit
19de71a
1 Parent(s): 6d96117

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -117
app.py CHANGED
@@ -9,11 +9,10 @@ import os
9
  # Login to Hugging Face with token
10
  login(token=os.environ["HF_TOKEN"])
11
 
12
- # Liste des modèles
13
- model_list = [
14
- "meta-llama/Llama-2-13b",
15
- "meta-llama/Llama-2-7b",
16
- "meta-llama/Llama-2-70b",
17
  "meta-llama/Meta-Llama-3-8B",
18
  "meta-llama/Llama-3.2-3B",
19
  "meta-llama/Llama-3.1-8B",
@@ -26,126 +25,66 @@ model_list = [
26
  "croissantllm/CroissantLLMBase"
27
  ]
28
 
29
- # Charger le modèle et le tokenizer
30
- model = None
31
- tokenizer = None
32
 
 
33
  def load_model(model_name):
34
- global model, tokenizer
35
- print(f"Chargement du modèle {model_name}...")
36
-
37
- tokenizer = AutoTokenizer.from_pretrained(model_name)
38
- model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16, attn_implementation="eager")
39
-
40
- print("Modèle chargé avec succès.")
41
- return f"Modèle {model_name} chargé."
42
 
43
- def plot_attention(attention_data):
44
- tokens = attention_data['tokens']
45
- attention = attention_data['attention']
46
-
47
- fig, ax = plt.subplots(figsize=(10, 10))
48
- cax = ax.matshow(attention, cmap='viridis')
49
- fig.colorbar(cax)
50
 
51
- ax.set_xticklabels([''] + tokens, rotation=90)
52
- ax.set_yticklabels([''] + tokens)
53
 
54
- plt.xlabel("Tokens")
55
- plt.ylabel("Tokens")
56
- plt.title("Attention Heatmap")
57
-
58
- plt.tight_layout()
59
- plt.savefig('attention_plot.png')
60
- return 'attention_plot.png'
61
 
62
- def plot_probabilities(prob_data):
63
- words, probs = zip(*prob_data.items())
64
-
65
- plt.figure(figsize=(6, 4))
66
- plt.barh(words, probs, color='skyblue')
67
- plt.xlabel('Probabilities')
68
- plt.title('Top Probable Words')
69
-
70
- plt.tight_layout()
71
- plt.savefig('probabilities_plot.png')
72
- return 'probabilities_plot.png'
73
-
74
- def generate_text(input_text, temperature, top_p, top_k):
75
- global model, tokenizer
76
-
77
- inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
78
-
79
- with torch.no_grad():
80
- outputs = model.generate(
81
- **inputs,
82
- max_new_tokens=50,
83
- temperature=temperature,
84
- top_p=top_p,
85
- top_k=top_k,
86
- output_scores=True,
87
- output_attentions=True,
88
- return_dict_in_generate=True,
89
- return_legacy_cache=True
90
- )
91
-
92
- generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
93
-
94
- # Logits et probabilités du dernier token généré
95
- last_token_logits = outputs.scores[-1][0]
96
  probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
97
-
98
- # Top 5 des mots les plus probables
99
- top_probs, top_indices = torch.topk(probabilities, 5)
100
- top_words = [tokenizer.decode([idx.item()]) for idx in top_indices]
101
-
102
- prob_data = {word: prob.item() for word, prob in zip(top_words, top_probs)}
103
-
104
- # Extraction des attentions
105
- attentions = torch.cat([att[-1].mean(dim=1) for att in outputs.attentions], dim=0).cpu().numpy()
106
- attention_data = {
107
- 'attention': attentions,
108
- 'tokens': tokenizer.convert_ids_to_tokens(outputs.sequences[0])
109
- }
110
-
111
- return generated_text, plot_attention(attention_data), plot_probabilities(prob_data)
112
 
113
- def reset_app():
114
- global model, tokenizer
115
- model = None
116
- tokenizer = None
117
- return "Application réinitialisée."
118
 
119
  # Interface utilisateur Gradio
120
- with gr.Blocks() as demo:
121
- with gr.Row():
122
- model_selection = gr.Accordion("Sélection du modèle", open=True)
123
- with model_selection:
124
- model_name = gr.Dropdown(choices=model_list, label="Choisir un modèle", value=model_list[0])
125
- load_model_button = gr.Button("Charger le modèle")
126
- load_status = gr.Textbox(label="Statut du modèle", interactive=False)
127
-
128
- with gr.Row():
129
- temperature = gr.Slider(0.0, 1.0, value=0.7, label="Température")
130
- top_p = gr.Slider(0.0, 1.0, value=0.9, label="Top-p")
131
- top_k = gr.Slider(1, 100, value=50, label="Top-k")
132
-
133
- with gr.Row():
134
- input_text = gr.Textbox(label="Entrer le texte")
135
- generate_button = gr.Button("Générer")
136
-
137
- with gr.Row():
138
- output_text = gr.Textbox(label="Texte généré", interactive=False)
139
-
140
- with gr.Row():
141
- attention_plot = gr.Image(label="Carte de chaleur des attentions")
142
- prob_plot = gr.Image(label="Probabilités des mots les plus probables")
143
-
144
- with gr.Row():
145
- reset_button = gr.Button("Réinitialiser l'application")
146
-
147
- load_model_button.click(load_model, inputs=[model_name], outputs=[load_status])
148
- generate_button.click(generate_text, inputs=[input_text, temperature, top_p, top_k], outputs=[output_text, attention_plot, prob_plot])
149
- reset_button.click(reset_app)
150
 
151
- demo.launch()
 
 
9
  # Login to Hugging Face with token
10
  login(token=os.environ["HF_TOKEN"])
11
 
12
+ MODEL_LIST = [
13
+ "meta-llama/Llama-2-13b-hf",
14
+ "meta-llama/Llama-2-7b-hf",
15
+ "meta-llama/Llama-2-70b-hf",
 
16
  "meta-llama/Meta-Llama-3-8B",
17
  "meta-llama/Llama-3.2-3B",
18
  "meta-llama/Llama-3.1-8B",
 
25
  "croissantllm/CroissantLLMBase"
26
  ]
27
 
28
+ # Dictionnaire pour stocker les modèles et tokenizers déjà chargés
29
+ loaded_models = {}
 
30
 
31
+ # Charger le modèle
32
  def load_model(model_name):
33
+ if model_name not in loaded_models:
34
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
35
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
36
+ loaded_models[model_name] = (model, tokenizer)
37
+ return loaded_models[model_name]
 
 
 
38
 
39
+ # Génération de texte et attention
40
+ def generate_text(model_name, input_text, temperature, top_p, top_k):
41
+ model, tokenizer = load_model(model_name)
42
+ inputs = tokenizer(input_text, return_tensors="pt").to("cuda")
 
 
 
43
 
44
+ # Génération du texte
45
+ output = model.generate(**inputs, max_new_tokens=50, temperature=temperature, top_p=top_p, top_k=top_k, output_attentions=True)
46
 
47
+ # Décodage de la sortie
48
+ generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
 
 
 
 
 
49
 
50
+ # Affichage des mots les plus probables
51
+ last_token_logits = output.scores[-1][0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
53
+ top_tokens = torch.topk(probabilities, k=5)
54
+ probable_words = [tokenizer.decode([token]) for token in top_tokens.indices]
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
+ return generated_text, probable_words
 
 
 
 
57
 
58
  # Interface utilisateur Gradio
59
+ def reset_interface():
60
+ return "", "", "", ""
61
+
62
+ def main():
63
+ with gr.Blocks() as app:
64
+ with gr.Accordion("Choix du modèle", open=True):
65
+ model_name = gr.Dropdown(choices=MODEL_LIST, label="Modèles disponibles", value=MODEL_LIST[0])
66
+
67
+ with gr.Row():
68
+ input_text = gr.Textbox(label="Texte d'entrée", placeholder="Saisissez votre texte ici...")
69
+
70
+ with gr.Accordion("Paramètres", open=True):
71
+ temperature = gr.Slider(minimum=0, maximum=1, value=0.7, step=0.01, label="Température")
72
+ top_p = gr.Slider(minimum=0, maximum=1, value=0.9, step=0.01, label="Top_p")
73
+ top_k = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Top_k")
74
+
75
+ with gr.Row():
76
+ generate_button = gr.Button("Lancer la génération")
77
+ reset_button = gr.Button("Réinitialiser")
78
+
79
+ generated_text_output = gr.Textbox(label="Texte généré", placeholder="Le texte généré s'affichera ici...")
80
+ probable_words_output = gr.Textbox(label="Mots les plus probables", placeholder="Les mots les plus probables apparaîtront ici...")
81
+
82
+ # Lancer la génération
83
+ generate_button.click(generate_text, inputs=[model_name, input_text, temperature, top_p, top_k], outputs=[generated_text_output, probable_words_output])
84
+ # Réinitialiser
85
+ reset_button.click(reset_interface, outputs=[input_text, generated_text_output, probable_words_output])
86
+
87
+ app.launch()
 
88
 
89
+ if __name__ == "__main__":
90
+ main()