Woziii commited on
Commit
33f0de1
1 Parent(s): 35079fb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +105 -58
app.py CHANGED
@@ -1,15 +1,16 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import AutoModelForCausalLM, AutoTokenizer
4
  from huggingface_hub import login
 
5
  import matplotlib.pyplot as plt
6
  import numpy as np
7
- import os
8
 
9
- # Login to Hugging Face with token
10
  login(token=os.environ["HF_TOKEN"])
11
 
12
- MODEL_LIST = [
 
13
  "meta-llama/Llama-2-13b-hf",
14
  "meta-llama/Llama-2-7b-hf",
15
  "meta-llama/Llama-2-70b-hf",
@@ -25,66 +26,112 @@ MODEL_LIST = [
25
  "croissantllm/CroissantLLMBase"
26
  ]
27
 
28
- # Dictionnaire pour stocker les modèles et tokenizers déjà chargés
29
- loaded_models = {}
 
30
 
31
- # Charger le modèle
32
  def load_model(model_name):
33
- if model_name not in loaded_models:
34
- tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
35
- model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
36
- loaded_models[model_name] = (model, tokenizer)
37
- return loaded_models[model_name]
38
-
39
- # Génération de texte et attention
40
- def generate_text(model_name, input_text, temperature, top_p, top_k):
41
- model, tokenizer = load_model(model_name)
42
- inputs = tokenizer(input_text, return_tensors="pt").to("cuda")
43
 
44
- # Génération du texte
45
- output = model.generate(**inputs, max_new_tokens=50, temperature=temperature, top_p=top_p, top_k=top_k, output_attentions=True)
46
 
47
- # Décodage de la sortie
48
- generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
49
-
50
- # Affichage des mots les plus probables
51
- last_token_logits = output.scores[-1][0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
53
- top_tokens = torch.topk(probabilities, k=5)
54
- probable_words = [tokenizer.decode([token]) for token in top_tokens.indices]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
- return generated_text, probable_words
 
 
 
 
 
 
 
 
 
 
 
57
 
58
- # Interface utilisateur Gradio
59
- def reset_interface():
60
- return "", "", "", ""
61
 
62
- def main():
63
- with gr.Blocks() as app:
64
- with gr.Accordion("Choix du modèle", open=True):
65
- model_name = gr.Dropdown(choices=MODEL_LIST, label="Modèles disponibles", value=MODEL_LIST[0])
66
-
67
- with gr.Row():
68
- input_text = gr.Textbox(label="Texte d'entrée", placeholder="Saisissez votre texte ici...")
69
-
70
- with gr.Accordion("Paramètres", open=True):
71
- temperature = gr.Slider(minimum=0, maximum=1, value=0.7, step=0.01, label="Température")
72
- top_p = gr.Slider(minimum=0, maximum=1, value=0.9, step=0.01, label="Top_p")
73
- top_k = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Top_k")
74
-
75
- with gr.Row():
76
- generate_button = gr.Button("Lancer la génération")
77
- reset_button = gr.Button("Réinitialiser")
78
-
79
- generated_text_output = gr.Textbox(label="Texte généré", placeholder="Le texte généré s'affichera ici...")
80
- probable_words_output = gr.Textbox(label="Mots les plus probables", placeholder="Les mots les plus probables apparaîtront ici...")
81
-
82
- # Lancer la génération
83
- generate_button.click(generate_text, inputs=[model_name, input_text, temperature, top_p, top_k], outputs=[generated_text_output, probable_words_output])
84
- # Réinitialiser
85
- reset_button.click(reset_interface, outputs=[input_text, generated_text_output, probable_words_output])
86
-
87
- app.launch()
 
 
 
 
88
 
89
- if __name__ == "__main__":
90
- main()
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
  from huggingface_hub import login
5
+ import os
6
  import matplotlib.pyplot as plt
7
  import numpy as np
 
8
 
9
+ # Authentification
10
  login(token=os.environ["HF_TOKEN"])
11
 
12
+ # Liste des modèles
13
+ models = [
14
  "meta-llama/Llama-2-13b-hf",
15
  "meta-llama/Llama-2-7b-hf",
16
  "meta-llama/Llama-2-70b-hf",
 
26
  "croissantllm/CroissantLLMBase"
27
  ]
28
 
29
+ # Variables globales
30
+ model = None
31
+ tokenizer = None
32
 
 
33
  def load_model(model_name):
34
+ global model, tokenizer
35
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
36
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
37
+ if tokenizer.pad_token is None:
38
+ tokenizer.pad_token = tokenizer.eos_token
39
+ return f"Modèle {model_name} chargé avec succès."
 
 
 
 
40
 
41
+ def generate_text(input_text, temperature, top_p, top_k):
42
+ global model, tokenizer
43
 
44
+ inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
45
+
46
+ with torch.no_grad():
47
+ outputs = model.generate(
48
+ **inputs,
49
+ max_new_tokens=50,
50
+ temperature=temperature,
51
+ top_p=top_p,
52
+ top_k=top_k,
53
+ output_attentions=True,
54
+ return_dict_in_generate=True
55
+ )
56
+
57
+ generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
58
+
59
+ # Obtenir les logits pour le dernier token généré
60
+ last_token_logits = outputs.scores[-1][0]
61
+
62
+ # Appliquer softmax pour obtenir les probabilités
63
  probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
64
+
65
+ # Obtenir les top 5 tokens les plus probables
66
+ top_k = 5
67
+ top_probs, top_indices = torch.topk(probabilities, top_k)
68
+ top_words = [tokenizer.decode([idx.item()]) for idx in top_indices]
69
+
70
+ # Préparer les données pour le graphique des probabilités
71
+ prob_data = {word: prob.item() for word, prob in zip(top_words, top_probs)}
72
+
73
+ # Extraire les attentions (moyenne sur toutes les couches et têtes d'attention)
74
+ attentions = torch.mean(torch.stack(outputs.attentions), dim=(0, 1)).cpu().numpy()
75
+
76
+ return generated_text, plot_attention(attentions, tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])), plot_probabilities(prob_data)
77
+
78
+ def plot_attention(attention, tokens):
79
+ fig, ax = plt.subplots(figsize=(10, 10))
80
+ im = ax.imshow(attention, cmap='viridis')
81
+ ax.set_xticks(range(len(tokens)))
82
+ ax.set_yticks(range(len(tokens)))
83
+ ax.set_xticklabels(tokens, rotation=90)
84
+ ax.set_yticklabels(tokens)
85
+ plt.colorbar(im)
86
+ plt.title("Carte d'attention")
87
+ plt.tight_layout()
88
+ return fig
89
 
90
+ def plot_probabilities(prob_data):
91
+ words = list(prob_data.keys())
92
+ probs = list(prob_data.values())
93
+
94
+ fig, ax = plt.subplots(figsize=(10, 5))
95
+ ax.bar(words, probs)
96
+ ax.set_title("Probabilités des tokens suivants les plus probables")
97
+ ax.set_xlabel("Tokens")
98
+ ax.set_ylabel("Probabilité")
99
+ plt.xticks(rotation=45)
100
+ plt.tight_layout()
101
+ return fig
102
 
103
+ def reset():
104
+ return "", 1.0, 1.0, 50, None, None, None
 
105
 
106
+ with gr.Blocks() as demo:
107
+ gr.Markdown("# Générateur de texte avec visualisation d'attention")
108
+
109
+ with gr.Accordion("Sélection du modèle"):
110
+ model_dropdown = gr.Dropdown(choices=models, label="Choisissez un modèle")
111
+ load_button = gr.Button("Charger le modèle")
112
+ load_output = gr.Textbox(label="Statut du chargement")
113
+
114
+ with gr.Row():
115
+ temperature = gr.Slider(0.1, 2.0, value=1.0, label="Température")
116
+ top_p = gr.Slider(0.1, 1.0, value=1.0, label="Top-p")
117
+ top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
118
+
119
+ input_text = gr.Textbox(label="Texte d'entrée")
120
+ generate_button = gr.Button("Générer")
121
+
122
+ output_text = gr.Textbox(label="Texte généré")
123
+
124
+ with gr.Row():
125
+ attention_plot = gr.Plot(label="Visualisation de l'attention")
126
+ prob_plot = gr.Plot(label="Probabilités des tokens suivants")
127
+
128
+ reset_button = gr.Button("Réinitialiser")
129
+
130
+ load_button.click(load_model, inputs=[model_dropdown], outputs=[load_output])
131
+ generate_button.click(generate_text,
132
+ inputs=[input_text, temperature, top_p, top_k],
133
+ outputs=[output_text, attention_plot, prob_plot])
134
+ reset_button.click(reset,
135
+ outputs=[input_text, temperature, top_p, top_k, output_text, attention_plot, prob_plot])
136
 
137
+ demo.launch()