# LLMnBiasV2 / app.py
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import spaces
import matplotlib.pyplot as plt
import numpy as np
from huggingface_hub import login
import os
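# Authenticate with the Hugging Face Hub; the gated checkpoints below (Llama,
# Gemma) require a valid HF_TOKEN secret to be configured on the Space.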
login(token=os.environ["HF_TOKEN"])
# Available models
models = [
    "meta-llama/Llama-2-13b", "meta-llama/Llama-2-7b", "meta-llama/Llama-2-70b",
    "meta-llama/Meta-Llama-3-8B", "meta-llama/Llama-3.2-3B", "meta-llama/Llama-3.1-8B",
    "mistralai/Mistral-7B-v0.1", "mistralai/Mixtral-8x7B-v0.1", "mistralai/Mistral-7B-v0.3",
    "google/gemma-2-2b", "google/gemma-2-9b", "google/gemma-2-27b",
    "croissantllm/CroissantLLMBase"
]
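# Note: the largest checkpoints (Llama-2-70b, Mixtral-8x7B, gemma-2-27b) may not
# fit in the memory of a single Spaces GPU; the smaller models are safer choices.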
# Global variables holding the currently loaded model and tokenizer
model = None
tokenizer = None
def load_model(model_name):
    global model, tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # float16 halves the memory footprint; the model is moved to the GPU
    # inside the @spaces.GPU-decorated function below.
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
    return f"Model {model_name} loaded successfully."
@spaces.GPU(duration=300)
def generate_text(input_text, temperature, top_p, top_k):
    global model, tokenizer
    model = model.to("cuda")  # the GPU is only available inside this decorated function
    inputs = tokenizer(input_text, return_tensors="pt").to("cuda")
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=50,
            do_sample=True,  # sampling must be enabled for temperature/top_p/top_k to apply
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            output_attentions=True,
            output_scores=True,  # required so that outputs.scores is populated below
            return_dict_in_generate=True,
        )
    generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
    # Extract the attentions and the logits of the last generation step.
    # outputs.attentions is a tuple (steps) of tuples (layers) of tensors of shape
    # (batch, heads, tgt_len, src_len); take the last step's last layer, batch 0,
    # averaged over heads. With the KV cache this yields a 1 x seq_len map showing
    # how the newest token attends to the full context.
    attentions = outputs.attentions[-1][-1][0].mean(dim=0).cpu().numpy()
    logits = outputs.scores[-1][0].cpu()
    # Plot the attention map
    plt.figure(figsize=(10, 10))
    plt.imshow(attentions, cmap='viridis')
    plt.title("Attention map")
    attention_plot = plt.gcf()
    plt.close()
    # Get the most probable next tokens from the final step's logits
    probs = torch.nn.functional.softmax(logits, dim=-1)
    top_probs, top_indices = torch.topk(probs, k=5)
    top_words = [tokenizer.decode([idx]) for idx in top_indices]
    return generated_text, attention_plot, top_words
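# Reset every UI component to its default value; the tuple order matches the
# outputs list wired to reset_button below.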
def reset():
    return "", 1.0, 1.0, 50, None, None, None
with gr.Blocks() as demo:
    gr.Markdown("# Text generator with attention visualization")
    with gr.Accordion("Model selection"):
        model_dropdown = gr.Dropdown(choices=models, label="Choose a model")
        load_button = gr.Button("Load model")
        load_output = gr.Textbox(label="Loading status")
    with gr.Row():
        temperature = gr.Slider(0.1, 2.0, value=1.0, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=1.0, label="Top-p")
        top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
    input_text = gr.Textbox(label="Input text")
    generate_button = gr.Button("Generate")
    output_text = gr.Textbox(label="Generated text")
    with gr.Row():
        attention_plot = gr.Plot(label="Attention visualization")
        top_words = gr.JSON(label="Most probable tokens")
    reset_button = gr.Button("Reset")
    load_button.click(load_model, inputs=[model_dropdown], outputs=[load_output])
    generate_button.click(generate_text,
                          inputs=[input_text, temperature, top_p, top_k],
                          outputs=[output_text, attention_plot, top_words])
    reset_button.click(reset,
                       outputs=[input_text, temperature, top_p, top_k, output_text, attention_plot, top_words])
demo.launch()