# LLMnBiasV2 / app.py
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
import os
import matplotlib.pyplot as plt
import numpy as np
# Authenticate with the Hugging Face Hub (HF_TOKEN is expected in the environment, e.g. as a Space secret)
login(token=os.environ["HF_TOKEN"])
# Models offered in the dropdown
models = [
    "meta-llama/Llama-2-13b-hf",
    "meta-llama/Llama-2-7b-hf",
    "meta-llama/Llama-2-70b-hf",
    "meta-llama/Meta-Llama-3-8B",
    "meta-llama/Llama-3.2-3B",
    "meta-llama/Llama-3.1-8B",
    "mistralai/Mistral-7B-v0.1",
    "mistralai/Mixtral-8x7B-v0.1",
    "mistralai/Mistral-7B-v0.3",
    "google/gemma-2-2b",
    "google/gemma-2-9b",
    "google/gemma-2-27b",
    "croissantllm/CroissantLLMBase",
]
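# Several of these repos (the meta-llama and google/gemma families in
# particular) are gated on the Hub: the HF_TOKEN account must have accepted
# each model's license. The largest checkpoints (70B, 8x7B, 27B) also need
# correspondingly large GPU memory.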
# Global state shared across the Gradio callbacks
model = None
tokenizer = None
def load_model(model_name):
    global model, tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.bfloat16, device_map="auto"
    )
    # Some checkpoints ship without a pad token; reuse EOS so padding works
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return f"Model {model_name} loaded successfully."
def generate_text(input_text, temperature, top_p, top_k):
    global model, tokenizer
    inputs = tokenizer(
        input_text, return_tensors="pt", padding=True, truncation=True, max_length=512
    ).to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=50,
            do_sample=True,  # required for temperature/top_p/top_k to take effect
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            output_attentions=True,
            output_scores=True,  # needed so outputs.scores is populated below
            return_dict_in_generate=True,
        )
    generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
    # Logits for the last generated token
    last_token_logits = outputs.scores[-1][0]
    # Softmax turns the logits into probabilities
    probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
    # Top 5 most probable tokens (distinct name so the top_k sampling
    # parameter is not shadowed)
    n_best = 5
    top_probs, top_indices = torch.topk(probabilities, n_best)
    top_words = [tokenizer.decode([idx.item()]) for idx in top_indices]
    # Data for the probability bar chart
    prob_data = {word: prob.item() for word, prob in zip(top_words, top_probs)}
    # outputs.attentions is a tuple (one entry per generated token) of tuples
    # (one per layer) of tensors shaped (batch, heads, seq, seq). Take the
    # first step, which covers the prompt, and average over layers and heads.
    step_attentions = torch.stack(outputs.attentions[0])  # (layers, batch, heads, seq, seq)
    attentions = step_attentions.mean(dim=(0, 2))[0].float().cpu().numpy()
    return (
        generated_text,
        plot_attention(attentions, tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])),
        plot_probabilities(prob_data),
    )
def plot_attention(attention, tokens):
    fig, ax = plt.subplots(figsize=(10, 10))
    im = ax.imshow(attention, cmap="viridis")
    ax.set_xticks(range(len(tokens)))
    ax.set_yticks(range(len(tokens)))
    ax.set_xticklabels(tokens, rotation=90)
    ax.set_yticklabels(tokens)
    plt.colorbar(im)
    plt.title("Attention map")
    plt.tight_layout()
    return fig
def plot_probabilities(prob_data):
    words = list(prob_data.keys())
    probs = list(prob_data.values())
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.bar(words, probs)
    ax.set_title("Most probable next tokens")
    ax.set_xlabel("Tokens")
    ax.set_ylabel("Probability")
    plt.xticks(rotation=45)
    plt.tight_layout()
    return fig
def reset():
    # Restore the defaults: empty text, default sliders, cleared plots
    return "", 1.0, 1.0, 50, None, None, None
with gr.Blocks() as demo:
    gr.Markdown("# Text generator with attention visualization")
    with gr.Accordion("Model selection"):
        model_dropdown = gr.Dropdown(choices=models, label="Choose a model")
        load_button = gr.Button("Load model")
        load_output = gr.Textbox(label="Loading status")
    with gr.Row():
        temperature = gr.Slider(0.1, 2.0, value=1.0, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=1.0, label="Top-p")
        top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
    input_text = gr.Textbox(label="Input text")
    generate_button = gr.Button("Generate")
    output_text = gr.Textbox(label="Generated text")
    with gr.Row():
        attention_plot = gr.Plot(label="Attention visualization")
        prob_plot = gr.Plot(label="Next-token probabilities")
    reset_button = gr.Button("Reset")

    load_button.click(load_model, inputs=[model_dropdown], outputs=[load_output])
    generate_button.click(
        generate_text,
        inputs=[input_text, temperature, top_p, top_k],
        outputs=[output_text, attention_plot, prob_plot],
    )
    reset_button.click(
        reset,
        outputs=[input_text, temperature, top_p, top_k, output_text, attention_plot, prob_plot],
    )

demo.launch()
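# A minimal way to run this app locally (a sketch; exact package versions are
# not pinned in this file; accelerate is needed for device_map="auto"):
#   pip install gradio torch transformers accelerate matplotlib
#   export HF_TOKEN=hf_...   # a token whose account has accepted the gated licenses
#   python app.py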