Woziii commited on
Commit
6696db2
1 Parent(s): a73e468

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -74
app.py CHANGED
@@ -1,38 +1,75 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import AutoTokenizer, AutoModelForCausalLM
 
4
  import matplotlib.pyplot as plt
5
  import numpy as np
6
- from huggingface_hub import login
7
- import os
8
 
9
- # Authentification Hugging Face avec ton token d'accès
10
- login(token=os.environ["HF_TOKEN"])
 
11
 
12
- # Liste des modèles disponibles
13
- models = [
14
- "meta-llama/Llama-2-13b", "meta-llama/Llama-2-7b", "meta-llama/Llama-2-70b",
15
- "meta-llama/Meta-Llama-3-8B", "meta-llama/Llama-3.2-3B", "meta-llama/Llama-3.1-8B",
16
- "mistralai/Mistral-7B-v0.1", "mistralai/Mixtral-8x7B-v0.1", "mistralai/Mistral-7B-v0.3",
17
- "google/gemma-2-2b", "google/gemma-2-9b", "google/gemma-2-27b",
 
 
 
 
 
 
 
 
18
  "croissantllm/CroissantLLMBase"
19
  ]
20
 
21
- # Variables pour le modèle et le tokenizer
22
  model = None
23
  tokenizer = None
24
 
25
  def load_model(model_name):
26
  global model, tokenizer
 
 
27
  tokenizer = AutoTokenizer.from_pretrained(model_name)
28
- model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
29
 
30
- # Assurer que le token de padding est défini si nécessaire
31
- if tokenizer.pad_token is None:
32
- tokenizer.pad_token = tokenizer.eos_token
33
- model.config.pad_token_id = model.config.eos_token_id
 
 
34
 
35
- return f"Modèle {model_name} chargé avec succès sur GPU."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  def generate_text(input_text, temperature, top_p, top_k):
38
  global model, tokenizer
@@ -46,8 +83,10 @@ def generate_text(input_text, temperature, top_p, top_k):
46
  temperature=temperature,
47
  top_p=top_p,
48
  top_k=top_k,
 
49
  output_attentions=True,
50
- return_dict_in_generate=True
 
51
  )
52
 
53
  generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
@@ -71,69 +110,42 @@ def generate_text(input_text, temperature, top_p, top_k):
71
 
72
  return generated_text, plot_attention(attention_data), plot_probabilities(prob_data)
73
 
74
- def plot_attention(attention_data):
75
- attention = attention_data['attention']
76
- tokens = attention_data['tokens']
77
-
78
- fig, ax = plt.subplots(figsize=(10, 10))
79
- im = ax.imshow(attention, cmap='viridis')
80
- plt.colorbar(im)
81
- ax.set_xticks(range(len(tokens)))
82
- ax.set_yticks(range(len(tokens)))
83
- ax.set_xticklabels(tokens, rotation=90)
84
- ax.set_yticklabels(tokens)
85
- ax.set_title("Carte d'attention")
86
- plt.tight_layout()
87
- return fig
88
-
89
- def plot_probabilities(prob_data):
90
- words = list(prob_data.keys())
91
- probs = list(prob_data.values())
92
-
93
- fig, ax = plt.subplots(figsize=(10, 5))
94
- ax.bar(words, probs)
95
- ax.set_title("Probabilités des tokens suivants les plus probables")
96
- ax.set_xlabel("Tokens")
97
- ax.set_ylabel("Probabilité")
98
- plt.xticks(rotation=45)
99
- plt.tight_layout()
100
- return fig
101
-
102
- def reset():
103
- return "", 1.0, 1.0, 50, None, None, None
104
 
105
- # Interface Gradio
106
  with gr.Blocks() as demo:
107
- gr.Markdown("# Générateur de texte avec visualisation d'attention")
 
 
 
 
 
108
 
109
- with gr.Accordion("Sélection du modèle"):
110
- model_dropdown = gr.Dropdown(choices=models, label="Choisissez un modèle")
111
- load_button = gr.Button("Charger le modèle")
112
- load_output = gr.Textbox(label="Statut du chargement")
113
 
114
  with gr.Row():
115
- temperature = gr.Slider(0.1, 2.0, value=1.0, label="Température")
116
- top_p = gr.Slider(0.1, 1.0, value=1.0, label="Top-p")
117
- top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
118
 
119
- input_text = gr.Textbox(label="Texte d'entrée")
120
- generate_button = gr.Button("Générer")
121
 
122
- output_text = gr.Textbox(label="Texte généré")
 
 
123
 
124
  with gr.Row():
125
- attention_plot = gr.Plot(label="Visualisation de l'attention")
126
- prob_plot = gr.Plot(label="Probabilités des tokens suivants")
127
-
128
- reset_button = gr.Button("Réinitialiser")
129
-
130
- # Association des actions avec les boutons
131
- load_button.click(load_model, inputs=[model_dropdown], outputs=[load_output])
132
- generate_button.click(generate_text,
133
- inputs=[input_text, temperature, top_p, top_k],
134
- outputs=[output_text, attention_plot, prob_plot])
135
- reset_button.click(reset,
136
- outputs=[input_text, temperature, top_p, top_k, output_text, attention_plot, prob_plot])
137
 
138
- # Lancement de l'application
139
  demo.launch()
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ from huggingface_hub import login
5
  import matplotlib.pyplot as plt
6
  import numpy as np
 
 
7
 
8
+ # Login to Hugging Face with token
9
+ HF_TOKEN = "hf_token" # Remplacer par ton token Hugging Face
10
+ login(HF_TOKEN)
11
 
12
+ # Liste des modèles
13
+ model_list = [
14
+ "meta-llama/Llama-2-13b",
15
+ "meta-llama/Llama-2-7b",
16
+ "meta-llama/Llama-2-70b",
17
+ "meta-llama/Meta-Llama-3-8B",
18
+ "meta-llama/Llama-3.2-3B",
19
+ "meta-llama/Llama-3.1-8B",
20
+ "mistralai/Mistral-7B-v0.1",
21
+ "mistralai/Mixtral-8x7B-v0.1",
22
+ "mistralai/Mistral-7B-v0.3",
23
+ "google/gemma-2-2b",
24
+ "google/gemma-2-9b",
25
+ "google/gemma-2-27b",
26
  "croissantllm/CroissantLLMBase"
27
  ]
28
 
29
+ # Charger le modèle et le tokenizer
30
  model = None
31
  tokenizer = None
32
 
33
  def load_model(model_name):
34
  global model, tokenizer
35
+ print(f"Chargement du modèle {model_name}...")
36
+
37
  tokenizer = AutoTokenizer.from_pretrained(model_name)
38
+ model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16, attn_implementation="eager")
39
 
40
+ print("Modèle chargé avec succès.")
41
+ return f"Modèle {model_name} chargé."
42
+
43
+ def plot_attention(attention_data):
44
+ tokens = attention_data['tokens']
45
+ attention = attention_data['attention']
46
 
47
+ fig, ax = plt.subplots(figsize=(10, 10))
48
+ cax = ax.matshow(attention, cmap='viridis')
49
+ fig.colorbar(cax)
50
+
51
+ ax.set_xticklabels([''] + tokens, rotation=90)
52
+ ax.set_yticklabels([''] + tokens)
53
+
54
+ plt.xlabel("Tokens")
55
+ plt.ylabel("Tokens")
56
+ plt.title("Attention Heatmap")
57
+
58
+ plt.tight_layout()
59
+ plt.savefig('attention_plot.png')
60
+ return 'attention_plot.png'
61
+
62
+ def plot_probabilities(prob_data):
63
+ words, probs = zip(*prob_data.items())
64
+
65
+ plt.figure(figsize=(6, 4))
66
+ plt.barh(words, probs, color='skyblue')
67
+ plt.xlabel('Probabilities')
68
+ plt.title('Top Probable Words')
69
+
70
+ plt.tight_layout()
71
+ plt.savefig('probabilities_plot.png')
72
+ return 'probabilities_plot.png'
73
 
74
  def generate_text(input_text, temperature, top_p, top_k):
75
  global model, tokenizer
 
83
  temperature=temperature,
84
  top_p=top_p,
85
  top_k=top_k,
86
+ output_scores=True,
87
  output_attentions=True,
88
+ return_dict_in_generate=True,
89
+ return_legacy_cache=True
90
  )
91
 
92
  generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
 
110
 
111
  return generated_text, plot_attention(attention_data), plot_probabilities(prob_data)
112
 
113
+ def reset_app():
114
+ global model, tokenizer
115
+ model = None
116
+ tokenizer = None
117
+ return "Application réinitialisée."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
+ # Interface utilisateur Gradio
120
  with gr.Blocks() as demo:
121
+ with gr.Row():
122
+ model_selection = gr.Accordion("Sélection du modèle", open=True)
123
+ with model_selection:
124
+ model_name = gr.Dropdown(choices=model_list, label="Choisir un modèle", value=model_list[0])
125
+ load_model_button = gr.Button("Charger le modèle")
126
+ load_status = gr.Textbox(label="Statut du modèle", interactive=False)
127
 
128
+ with gr.Row():
129
+ temperature = gr.Slider(0.0, 1.0, value=0.7, label="Température")
130
+ top_p = gr.Slider(0.0, 1.0, value=0.9, label="Top-p")
131
+ top_k = gr.Slider(1, 100, value=50, label="Top-k")
132
 
133
  with gr.Row():
134
+ input_text = gr.Textbox(label="Entrer le texte")
135
+ generate_button = gr.Button("Générer")
 
136
 
137
+ with gr.Row():
138
+ output_text = gr.Textbox(label="Texte généré", interactive=False)
139
 
140
+ with gr.Row():
141
+ attention_plot = gr.Image(label="Carte de chaleur des attentions")
142
+ prob_plot = gr.Image(label="Probabilités des mots les plus probables")
143
 
144
  with gr.Row():
145
+ reset_button = gr.Button("Réinitialiser l'application")
146
+
147
+ load_model_button.click(load_model, inputs=[model_name], outputs=[load_status])
148
+ generate_button.click(generate_text, inputs=[input_text, temperature, top_p, top_k], outputs=[output_text, attention_plot, prob_plot])
149
+ reset_button.click(reset_app)
 
 
 
 
 
 
 
150
 
 
151
  demo.launch()