File size: 3,604 Bytes
60b53a6
 
 
 
 
 
72dbc28
 
 
60b53a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import spaces
import matplotlib.pyplot as plt
import numpy as np
from huggingface_hub import login
import os
# Authenticate against the Hugging Face Hub only when a token is configured.
# Indexing os.environ directly would raise KeyError at import time and keep
# the whole app from starting; without a token, public checkpoints still
# load and gated ones fail at load time with the Hub's own error.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)

# Checkpoints offered in the model dropdown.
# The Llama 2 entries use the "-hf" repositories: the un-suffixed
# "meta-llama/Llama-2-*" repositories hold the original checkpoint format,
# which AutoModelForCausalLM.from_pretrained cannot load.
models = [
    "meta-llama/Llama-2-13b-hf", "meta-llama/Llama-2-7b-hf", "meta-llama/Llama-2-70b-hf",
    "meta-llama/Meta-Llama-3-8B", "meta-llama/Llama-3.2-3B", "meta-llama/Llama-3.1-8B",
    "mistralai/Mistral-7B-v0.1", "mistralai/Mixtral-8x7B-v0.1", "mistralai/Mistral-7B-v0.3",
    "google/gemma-2-2b", "google/gemma-2-9b", "google/gemma-2-27b",
    "croissantllm/CroissantLLMBase"
]

# Module-level holders for the currently loaded model and tokenizer;
# both stay None until load_model() succeeds.
model = None
tokenizer = None

def load_model(model_name):
    """Load a checkpoint and its tokenizer into the module-level globals.

    Args:
        model_name: Hugging Face Hub repository id chosen in the dropdown.
            May be None or empty when nothing has been selected yet.

    Returns:
        A status message (in French, like the rest of the UI) describing
        whether the load succeeded, failed, or was skipped.
    """
    global model, tokenizer

    # The dropdown has no default value, so a click without a selection
    # passes None; report it instead of letting from_pretrained fail.
    if not model_name:
        return "Aucun modèle sélectionné."

    try:
        # Load into locals first so a failure part-way through never leaves
        # a new tokenizer paired with the previously loaded model.
        new_tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Eager attention is requested because the app visualises attention
        # weights, which fused attention kernels do not return.
        new_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            attn_implementation="eager",
        )
    except (OSError, ValueError) as err:
        # from_pretrained reports missing, gated or malformed repositories
        # with these exception types; surface them in the status box.
        return f"Échec du chargement du modèle {model_name} : {err}"

    tokenizer, model = new_tokenizer, new_model
    return f"Modèle {model_name} chargé avec succès."

@spaces.GPU(duration=300)
def generate_text(input_text, temperature, top_p, top_k):
    """Generate a continuation and expose attention and next-token candidates.

    Args:
        input_text: Prompt to continue.
        temperature: Sampling temperature from the UI slider.
        top_p: Nucleus-sampling threshold from the UI slider.
        top_k: Top-k cutoff from the UI slider (cast to int before use).

    Returns:
        A tuple ``(generated_text, attention_plot, top_words)``: the decoded
        sequence, a matplotlib figure showing the attention of the last
        generated token, and the five most probable tokens at the last
        generation step.

    Raises:
        gr.Error: If no model has been loaded or the prompt is empty.
    """
    if model is None or tokenizer is None:
        raise gr.Error("Veuillez d'abord charger un modèle.")
    if not input_text or not input_text.strip():
        raise gr.Error("Veuillez saisir un texte d'entrée.")

    # load_model() leaves the weights on the CPU, so move them to the same
    # device as the inputs; otherwise generate() fails on a device mismatch.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    inputs = tokenizer(input_text, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=50,
            # Without do_sample=True the decoding is greedy and the three
            # sampling parameters below are ignored.
            do_sample=True,
            temperature=float(temperature),
            top_p=float(top_p),
            top_k=int(top_k),
            output_attentions=True,
            # Required for outputs.scores to be populated (it is None otherwise).
            output_scores=True,
            return_dict_in_generate=True,
        )

    generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)

    # Per the transformers generate-output documentation, outputs.attentions
    # is indexed [generation step][layer] and each tensor is laid out as
    # (batch, heads, query_len, key_len). Take the last step, last layer,
    # first batch item and last query position to get a 2-D (heads, key_len)
    # map. Cast to float32 because half-precision tensors may not convert
    # to NumPy.
    last_layer_attention = outputs.attentions[-1][-1]
    attentions = last_layer_attention[0, :, -1, :].float().cpu().numpy()
    logits = outputs.scores[-1][0].float().cpu()

    # Draw on an explicit figure and close it afterwards so figures do not
    # accumulate in pyplot's global state across requests.
    attention_plot, axes = plt.subplots(figsize=(10, 10))
    axes.imshow(attentions, cmap='viridis', aspect='auto')
    axes.set_title("Carte d'attention")
    axes.set_xlabel("Position du token")
    axes.set_ylabel("Tête d'attention")
    plt.close(attention_plot)

    # Most probable tokens at the last generation step.
    probs = torch.nn.functional.softmax(logits, dim=-1)
    _, top_indices = torch.topk(probs, k=min(5, probs.numel()))
    top_words = [tokenizer.decode([idx]) for idx in top_indices.tolist()]

    return generated_text, attention_plot, top_words

def reset():
    """Return the default value for each component wired to the reset button.

    The tuple order follows the ``outputs`` list of ``reset_button.click``:
    input text, temperature, top-p, top-k, generated text, attention plot,
    top words.
    """
    cleared_prompt = ""
    default_temperature = 1.0
    default_top_p = 1.0
    default_top_k = 50
    # None clears the generated text, the plot and the JSON view.
    cleared_outputs = (None, None, None)
    return (
        cleared_prompt,
        default_temperature,
        default_top_p,
        default_top_k,
        *cleared_outputs,
    )

# Build the Gradio interface. Components are created in display order, so
# the statement order below determines the page layout.
with gr.Blocks() as demo:
    gr.Markdown("# Générateur de texte avec visualisation d'attention")

    # Model selection panel.
    with gr.Accordion("Sélection du modèle"):
        model_dropdown = gr.Dropdown(label="Choisissez un modèle", choices=models)
        load_button = gr.Button("Charger le modèle")
        load_output = gr.Textbox(label="Statut du chargement")

    # Sampling parameters, side by side.
    with gr.Row():
        temperature = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, label="Température")
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=1.0, label="Top-p")
        top_k = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k")

    # Prompt and generation trigger.
    input_text = gr.Textbox(label="Texte d'entrée")
    generate_button = gr.Button("Générer")

    output_text = gr.Textbox(label="Texte généré")

    # Attention heat map next to the most probable tokens.
    with gr.Row():
        attention_plot = gr.Plot(label="Visualisation de l'attention")
        top_words = gr.JSON(label="Mots les plus probables")

    reset_button = gr.Button("Réinitialiser")

    # Event wiring. The reset targets are the generation inputs followed by
    # the generation outputs, in the order reset() returns its values.
    generation_inputs = [input_text, temperature, top_p, top_k]
    generation_outputs = [output_text, attention_plot, top_words]

    load_button.click(
        fn=load_model,
        inputs=[model_dropdown],
        outputs=[load_output],
    )
    generate_button.click(
        fn=generate_text,
        inputs=generation_inputs,
        outputs=generation_outputs,
    )
    reset_button.click(
        fn=reset,
        outputs=generation_inputs + generation_outputs,
    )

demo.launch()