# Hugging Face Space app (Space status shown on the page when captured: Paused).
import gc
import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import spaces
import torch
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM
# Authenticate with the Hugging Face Hub so gated repositories can be
# downloaded. The token is optional: without it only public models load,
# instead of the whole app crashing with KeyError at startup.
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)

# Models offered in the UI. Each entry must be a transformers-format
# checkpoint: the bare "meta-llama/Llama-2-*" repos hold Meta's original
# weights, which AutoModelForCausalLM cannot load, so the "-hf" repos are used.
models = [
    "meta-llama/Llama-2-13b-hf", "meta-llama/Llama-2-7b-hf", "meta-llama/Llama-2-70b-hf",
    "meta-llama/Meta-Llama-3-8B", "meta-llama/Llama-3.2-3B", "meta-llama/Llama-3.1-8B",
    "mistralai/Mistral-7B-v0.1", "mistralai/Mixtral-8x7B-v0.1", "mistralai/Mistral-7B-v0.3",
    "google/gemma-2-2b", "google/gemma-2-9b", "google/gemma-2-27b",
    "croissantllm/CroissantLLMBase",
]

# Currently loaded model and tokenizer, shared by the Gradio callbacks.
# Both are None until load_model() assigns them.
model = None
tokenizer = None
def load_model(model_name):
    """Load *model_name* from the Hub into the module-level ``model``/``tokenizer``.

    The model is moved to the GPU when CUDA is available (otherwise it stays
    on the CPU), in the dtype recorded by the checkpoint, and with eager
    attention. SDPA/flash attention kernels do not return attention weights,
    which ``generate(output_attentions=True)`` needs.

    Args:
        model_name: Hub repository id, normally one of ``models``. ``None``
            or ``""`` when nothing is selected in the dropdown.

    Returns:
        A status message (in French) shown in the UI.
    """
    global model, tokenizer
    if not model_name:
        return "Veuillez choisir un modèle."

    # Release the previous model first so two models are never resident at
    # the same time while switching (avoids running out of GPU memory).
    model = None
    tokenizer = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",  # native checkpoint dtype instead of fp32
        attn_implementation="eager",  # required to get attention weights back
    ).to(device)
    return f"Modèle {model_name} chargé avec succès."
def generate_text(input_text, temperature, top_p, top_k, max_new_tokens=50):
    """Sample a continuation of *input_text* and visualise the model's attention.

    NOTE(review): ``spaces`` is imported by this file but never used. If the
    Space runs on ZeroGPU, this function presumably needs ``@spaces.GPU``.
    Confirm against the Space's hardware.

    Args:
        input_text: Prompt to continue.
        temperature: Sampling temperature (UI slider).
        top_p: Nucleus-sampling threshold (UI slider).
        top_k: Number of candidate tokens kept per step (UI slider).
        max_new_tokens: Maximum number of tokens to generate.

    Returns:
        ``(generated_text, attention_figure, top_words)``:

        - ``generated_text`` is the decoded prompt plus its continuation.
        - ``attention_figure`` is a matplotlib Figure of the last layer's
          head-averaged attention over the prompt.
        - ``top_words`` lists the five most probable tokens at the final
          generation step.

        When no model is loaded, returns a message, ``None`` and ``[]``.
    """
    if model is None or tokenizer is None:
        return "Veuillez d'abord charger un modèle.", None, []

    # Follow the model's device instead of hard-coding "cuda".
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            # Without do_sample=True generate() decodes greedily and silently
            # ignores temperature / top_p / top_k.
            do_sample=True,
            temperature=float(temperature),
            top_p=float(top_p),
            top_k=int(top_k),  # Gradio sliders may deliver floats
            output_attentions=True,
            output_scores=True,  # otherwise outputs.scores is None
            return_dict_in_generate=True,
        )
    generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)

    # outputs.attentions has one entry per generation step. Step 0 is the
    # forward pass over the whole prompt; each of its per-layer tensors is
    # indexed (batch, head, query, key). Take the last layer and the single
    # batch element, then average over heads to obtain a 2-D map. Keep only
    # the prompt columns, since some cache implementations may pre-allocate a
    # longer key axis.
    prompt_tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())
    prompt_len = len(prompt_tokens)
    attention = outputs.attentions[0][-1][0].float().mean(dim=0).cpu().numpy()
    attention = attention[:, :prompt_len]

    # Hold explicit fig/ax handles rather than relying on plt.gcf(), which
    # could pick up another request's figure.
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(attention, cmap="viridis")
    ax.set_xticks(range(prompt_len), labels=prompt_tokens, rotation=90)
    ax.set_yticks(range(prompt_len), labels=prompt_tokens)
    ax.set_title("Carte d'attention")
    plt.close(fig)  # drop pyplot's reference; the Figure object is still returned

    # outputs.scores holds the processed logits of each step. Use the last
    # step to list the most probable tokens.
    probs = torch.softmax(outputs.scores[-1][0].float(), dim=-1)
    top_indices = torch.topk(probs, k=5).indices.tolist()
    top_words = [tokenizer.decode([idx]) for idx in top_indices]
    return generated_text, fig, top_words
def reset():
    """Return the initial value of every component cleared by the reset button.

    Order: input text, temperature, top-p, top-k, then the three output
    components (generated text, attention plot, top words), which are emptied.
    """
    cleared_text = ""
    default_temperature, default_top_p, default_top_k = 1.0, 1.0, 50
    empty_outputs = (None, None, None)
    return (cleared_text, default_temperature, default_top_p, default_top_k) + empty_outputs
with gr.Blocks() as demo:
    gr.Markdown("# Générateur de texte avec visualisation d'attention")

    # Model picker: choose a checkpoint, load it, and show the load status.
    with gr.Accordion("Sélection du modèle"):
        model_dropdown = gr.Dropdown(choices=models, label="Choisissez un modèle")
        load_button = gr.Button("Charger le modèle")
        load_output = gr.Textbox(label="Statut du chargement")

    # Sampling controls forwarded to generate_text.
    with gr.Row():
        temperature = gr.Slider(0.1, 2.0, value=1.0, label="Température")
        top_p = gr.Slider(0.1, 1.0, value=1.0, label="Top-p")
        top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")

    input_text = gr.Textbox(label="Texte d'entrée")
    generate_button = gr.Button("Générer")
    output_text = gr.Textbox(label="Texte généré")

    with gr.Row():
        attention_plot = gr.Plot(label="Visualisation de l'attention")
        top_words = gr.JSON(label="Mots les plus probables")

    reset_button = gr.Button("Réinitialiser")

    # Component groups shared by the event handlers below.
    sampling_inputs = [temperature, top_p, top_k]
    generation_outputs = [output_text, attention_plot, top_words]

    load_button.click(fn=load_model, inputs=[model_dropdown], outputs=[load_output])
    generate_button.click(
        fn=generate_text,
        inputs=[input_text, *sampling_inputs],
        outputs=generation_outputs,
    )
    reset_button.click(
        fn=reset,
        outputs=[input_text, *sampling_inputs, *generation_outputs],
    )

demo.launch()