Woziii commited on
Commit
60b53a6
1 Parent(s): 3d8c020

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -0
app.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
+ import spaces
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+
8
+ # Liste des modèles
9
+ models = [
10
+ "meta-llama/Llama-2-13b", "meta-llama/Llama-2-7b", "meta-llama/Llama-2-70b",
11
+ "meta-llama/Meta-Llama-3-8B", "meta-llama/Llama-3.2-3B", "meta-llama/Llama-3.1-8B",
12
+ "mistralai/Mistral-7B-v0.1", "mistralai/Mixtral-8x7B-v0.1", "mistralai/Mistral-7B-v0.3",
13
+ "google/gemma-2-2b", "google/gemma-2-9b", "google/gemma-2-27b",
14
+ "croissantllm/CroissantLLMBase"
15
+ ]
16
+
17
+ # Variables globales pour stocker le modèle et le tokenizer
18
+ model = None
19
+ tokenizer = None
20
+
21
+ def load_model(model_name):
22
+ global model, tokenizer
23
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
24
+ model = AutoModelForCausalLM.from_pretrained(model_name)
25
+ return f"Modèle {model_name} chargé avec succès."
26
+
27
+ @spaces.GPU(duration=300)
28
+ def generate_text(input_text, temperature, top_p, top_k):
29
+ global model, tokenizer
30
+
31
+ inputs = tokenizer(input_text, return_tensors="pt").to("cuda")
32
+
33
+ with torch.no_grad():
34
+ outputs = model.generate(
35
+ **inputs,
36
+ max_new_tokens=50,
37
+ temperature=temperature,
38
+ top_p=top_p,
39
+ top_k=top_k,
40
+ output_attentions=True,
41
+ return_dict_in_generate=True
42
+ )
43
+
44
+ generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
45
+
46
+ # Extraire les attentions et les logits
47
+ attentions = outputs.attentions[-1][0][-1].cpu().numpy()
48
+ logits = outputs.scores[-1][0].cpu()
49
+
50
+ # Visualiser l'attention
51
+ plt.figure(figsize=(10, 10))
52
+ plt.imshow(attentions, cmap='viridis')
53
+ plt.title("Carte d'attention")
54
+ attention_plot = plt.gcf()
55
+ plt.close()
56
+
57
+ # Obtenir les mots les plus probables
58
+ probs = torch.nn.functional.softmax(logits, dim=-1)
59
+ top_probs, top_indices = torch.topk(probs, k=5)
60
+ top_words = [tokenizer.decode([idx]) for idx in top_indices]
61
+
62
+ return generated_text, attention_plot, top_words
63
+
64
+ def reset():
65
+ return "", 1.0, 1.0, 50, None, None, None
66
+
67
+ with gr.Blocks() as demo:
68
+ gr.Markdown("# Générateur de texte avec visualisation d'attention")
69
+
70
+ with gr.Accordion("Sélection du modèle"):
71
+ model_dropdown = gr.Dropdown(choices=models, label="Choisissez un modèle")
72
+ load_button = gr.Button("Charger le modèle")
73
+ load_output = gr.Textbox(label="Statut du chargement")
74
+
75
+ with gr.Row():
76
+ temperature = gr.Slider(0.1, 2.0, value=1.0, label="Température")
77
+ top_p = gr.Slider(0.1, 1.0, value=1.0, label="Top-p")
78
+ top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
79
+
80
+ input_text = gr.Textbox(label="Texte d'entrée")
81
+ generate_button = gr.Button("Générer")
82
+
83
+ output_text = gr.Textbox(label="Texte généré")
84
+
85
+ with gr.Row():
86
+ attention_plot = gr.Plot(label="Visualisation de l'attention")
87
+ top_words = gr.JSON(label="Mots les plus probables")
88
+
89
+ reset_button = gr.Button("Réinitialiser")
90
+
91
+ load_button.click(load_model, inputs=[model_dropdown], outputs=[load_output])
92
+ generate_button.click(generate_text,
93
+ inputs=[input_text, temperature, top_p, top_k],
94
+ outputs=[output_text, attention_plot, top_words])
95
+ reset_button.click(reset,
96
+ outputs=[input_text, temperature, top_p, top_k, output_text, attention_plot, top_words])
97
+
98
+ demo.launch()