LLMnBiasV2 / app.py
Woziii's picture
Update app.py
984dc97 verified
raw
history blame
8.63 kB
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
import os
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import time
from langdetect import detect
# Authenticate with the Hugging Face Hub using the HF_TOKEN environment variable.
_hf_token = os.environ.get("HF_TOKEN")
if not _hf_token:
    # Fail at startup with an actionable message instead of a bare KeyError.
    raise RuntimeError(
        "HF_TOKEN environment variable is not set; it is required to "
        "authenticate with the Hugging Face Hub."
    )
login(token=_hf_token)
# Languages accepted for each selectable checkpoint. Values are compared
# against langdetect's detect() output before analysis or generation; any
# other detected language is rejected.
model_languages = {
    "meta-llama/Llama-2-13b-hf": ["en"],
    "meta-llama/Llama-2-7b-hf": ["en"],
    "meta-llama/Llama-2-70b-hf": ["en"],
    "meta-llama/Meta-Llama-3-8B": ["en"],
    "meta-llama/Llama-3.2-3B": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
    "meta-llama/Llama-3.1-8B": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
    "mistralai/Mistral-7B-v0.1": ["en"],
    "mistralai/Mixtral-8x7B-v0.1": ["en", "fr", "it", "de", "es"],
    "mistralai/Mistral-7B-v0.3": ["en"],
    "google/gemma-2-2b": ["en"],
    "google/gemma-2-9b": ["en"],
    "google/gemma-2-27b": ["en"],
    "croissantllm/CroissantLLMBase": ["en", "fr"],
}

# Model ids offered in the dropdown, in the dict's insertion order. Derived
# from model_languages so a model can never be listed without a language entry.
models = list(model_languages)
# Module-level state shared by the Gradio callbacks; None means no model is
# currently loaded.
model = tokenizer = None
def load_model(model_name, progress=gr.Progress()):
    """Load ``model_name`` from the Hugging Face Hub into the module globals.

    On success, the module-level ``model`` and ``tokenizer`` hold the new
    checkpoint. The weights are in float16 and placed with ``device_map="auto"``.

    Args:
        model_name: repository id chosen in the dropdown (may be empty/None
            when nothing has been selected).
        progress: Gradio progress tracker used to report loading steps.

    Returns:
        A status message (in French) for the load-status textbox. Failures
        are reported in that message rather than raised.
    """
    global model, tokenizer
    import gc

    if not model_name:
        return "Veuillez d'abord choisir un modèle."

    try:
        progress(0, desc="Chargement du tokenizer")
        new_tokenizer = AutoTokenizer.from_pretrained(model_name)
        # When the tokenizer defines no pad token, reuse EOS so the callers'
        # tokenizer(..., padding=True) calls work.
        if new_tokenizer.pad_token is None:
            new_tokenizer.pad_token = new_tokenizer.eos_token
    except Exception as e:
        # Nothing has been replaced yet: any previously loaded pair still works.
        return f"Erreur lors du chargement du modèle : {str(e)}"

    # Release the previous checkpoint before loading the next one. Keeping
    # both resident may exhaust GPU memory. Clearing both globals also means a
    # failed load can never pair this tokenizer with the old model's weights.
    model = None
    tokenizer = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    try:
        progress(0.5, desc="Chargement du modèle")
        new_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            low_cpu_mem_usage=True,
        )
    except Exception as e:
        return f"Erreur lors du chargement du modèle : {str(e)}"

    # Publish the tokenizer and model together, only once both are ready.
    model, tokenizer = new_model, new_tokenizer
    progress(1.0, desc="Modèle chargé")
    return f"Modèle {model_name} chargé avec succès."
def ensure_token_display(token):
    """Return a displayable text form of ``token``.

    Args:
        token: either an integer token id, which is decoded with the loaded
            tokenizer, or text that has already been decoded, which is
            returned unchanged.

    Returns:
        The token's text.
    """
    # Only genuine integer ids are decoded. Treating every all-digit string
    # as an id would turn an already-decoded token such as "7" into a
    # different token, and "-1" would reach decode([-1]) and fail.
    if isinstance(token, int):
        return tokenizer.decode([token])
    return token
def analyze_next_token(input_text, temperature, top_p, top_k):
    """Report the most probable next tokens after ``input_text``.

    Args:
        input_text: prompt typed by the user.
        temperature: divisor applied to the logits before the softmax.
        top_p: not used by this analysis. It is accepted so the signature
            matches ``generate_text`` and the shared UI inputs.
        top_k: number of candidate tokens to list, clamped to
            ``[1, vocabulary size]``.

    Returns:
        ``(prob_text, attention_plot, prob_plot)``. The first item is a message
        and both plots are ``None`` in three cases: no model is loaded, the
        detected language is not listed for the model, or the analysis fails.
    """
    from langdetect.lang_detect_exception import LangDetectException

    if model is None or tokenizer is None:
        return "Veuillez d'abord charger un modèle.", None, None

    # detect() can raise (e.g. for empty input). Report it like any other
    # analysis error instead of letting it escape to the UI.
    try:
        detected_lang = detect(input_text)
    except LangDetectException as e:
        return f"Erreur lors de l'analyse : {str(e)}", None, None
    if detected_lang not in model_languages.get(model.config._name_or_path, []):
        return f"Langue détectée ({detected_lang}) non supportée par ce modèle.", None, None

    try:
        inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
        with torch.no_grad():
            outputs = model(**inputs)
        last_token_logits = outputs.logits[0, -1, :]
        # The model is loaded in float16; take the softmax in float32 for
        # stable probabilities.
        probabilities = torch.nn.functional.softmax(last_token_logits.float() / temperature, dim=-1)
        # Slider values may arrive as floats, but torch.topk needs an int k.
        k = max(1, min(int(top_k), probabilities.size(-1)))
        top_probs, top_indices = torch.topk(probabilities, k)

        prob_data = {}
        for token_id, prob in zip(top_indices.tolist(), top_probs.tolist()):
            label = tokenizer.decode([token_id])
            # Distinct ids can decode to the same text. Suffix the id rather
            # than overwrite the earlier (more probable) entry.
            if label in prob_data:
                label = f"{label} [{token_id}]"
            prob_data[label] = prob

        prob_text = "Prochains tokens les plus probables :\n\n" + "".join(
            f"{word}: {prob:.2%}\n" for word, prob in prob_data.items()
        )
        prob_plot = plot_probabilities(prob_data)
        attention_plot = plot_attention(inputs["input_ids"][0], last_token_logits)
        return prob_text, attention_plot, prob_plot
    except Exception as e:
        return f"Erreur lors de l'analyse : {str(e)}", None, None
def generate_text(input_text, temperature, top_p, top_k):
    """Sample up to 50 new tokens continuing ``input_text``.

    Args:
        input_text: prompt typed by the user.
        temperature: sampling temperature forwarded to ``model.generate``.
        top_p: nucleus-sampling threshold forwarded to ``model.generate``.
        top_k: top-k cutoff forwarded to ``model.generate`` (cast to int).

    Returns:
        The decoded first output sequence with special tokens skipped, or a
        message when no model is loaded, the language is not supported, or
        generation fails.
    """
    from langdetect.lang_detect_exception import LangDetectException

    if model is None or tokenizer is None:
        return "Veuillez d'abord charger un modèle."

    # detect() can raise (e.g. for empty input); report it instead of crashing.
    try:
        detected_lang = detect(input_text)
    except LangDetectException as e:
        return f"Erreur lors de la génération : {str(e)}"
    if detected_lang not in model_languages.get(model.config._name_or_path, []):
        return f"Langue détectée ({detected_lang}) non supportée par ce modèle."

    try:
        inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=50,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            # generate() rejects a non-int top_k; slider values may be floats.
            top_k=int(top_k),
            # Pass the pad id explicitly rather than relying on generate()'s
            # fallback to EOS.
            pad_token_id=tokenizer.pad_token_id,
        )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        return f"Erreur lors de la génération : {str(e)}"
def plot_probabilities(prob_data):
    """Draw a bar chart of candidate next tokens and their probabilities.

    Args:
        prob_data: mapping ``token label -> probability`` (fractions, shown as
            percentages). Bars are drawn in the mapping's iteration order.

    Returns:
        The matplotlib ``Figure`` to display in a ``gr.Plot`` output.
    """
    words = list(prob_data.keys())
    probs = list(prob_data.values())
    fig, ax = plt.subplots(figsize=(12, 6))
    bars = ax.bar(range(len(words)), probs, color='lightgreen')
    ax.set_title("Probabilités des tokens suivants les plus probables")
    ax.set_xlabel("Tokens")
    ax.set_ylabel("Probabilité")
    ax.set_xticks(range(len(words)))
    ax.set_xticklabels(words, rotation=45, ha='right')
    # Percentage label centred above each bar.
    for i, bar in enumerate(bars):
        height = bar.get_height()
        ax.text(i, height, f'{height:.2%}', ha='center', va='bottom', rotation=0)
    # Lay out this figure explicitly. plt.tight_layout() acts on pyplot's
    # "current" figure, which another request could have changed.
    fig.tight_layout()
    # Remove the figure from pyplot's registry so figures do not accumulate
    # across requests. The returned Figure object can still be rendered.
    plt.close(fig)
    return fig
def plot_attention(input_ids, last_token_logits):
    """Draw the one-row heatmap shown as the attention visualisation.

    NOTE(review): no attention weights are involved. The cell values are the
    ``top_k`` largest entries of ``softmax(last_token_logits)``, i.e. the
    next-token probabilities without temperature, in decreasing order. The
    column labels are the last ``top_k`` prompt tokens, so a value is not
    related to the token printed beneath it. Showing real attention would
    require the forward pass to request ``output_attentions=True``.

    Args:
        input_ids: prompt token ids (an iterable of ints or 0-d tensors).
        last_token_logits: logits for the next position; the softmax is taken
            over the last dimension.

    Returns:
        The matplotlib ``Figure`` to display in a ``gr.Plot`` output.
    """
    token_ids = [int(t) for t in input_ids]
    # At most 10 columns, never more than there are prompt tokens.
    top_k = min(len(token_ids), 10)
    # Decode each id on its own so a column shows exactly one token. Only the
    # labelled tail of the prompt is decoded.
    labels = [tokenizer.decode([t]) for t in token_ids[len(token_ids) - top_k:]]

    # float32 softmax: logits may come from a float16 model.
    scores = torch.nn.functional.softmax(last_token_logits.float(), dim=-1)
    top_scores, _ = torch.topk(scores, top_k)

    fig, ax = plt.subplots(figsize=(14, 7))
    sns.heatmap(top_scores.unsqueeze(0).cpu().numpy(), annot=True, cmap="YlOrRd", cbar=True, ax=ax, fmt='.2%')
    ax.set_xticklabels(labels, rotation=45, ha="right", fontsize=10)
    ax.set_yticklabels(["Attention"], rotation=0, fontsize=10)
    ax.set_title("Scores d'attention pour les derniers tokens", fontsize=16)
    cbar = ax.collections[0].colorbar
    cbar.set_label("Score d'attention", fontsize=12)
    cbar.ax.tick_params(labelsize=10)
    # Lay out this figure (not pyplot's "current" one), then unregister it
    # from pyplot so figures do not accumulate across requests.
    fig.tight_layout()
    plt.close(fig)
    return fig
def reset():
    """Unload the current model and return default values for the UI fields.

    Returns:
        An 8-tuple: empty prompt, temperature 1.0, top-p 1.0, top-k 50, then
        ``None`` for the probability text, both plots and the generated text.
    """
    global model, tokenizer
    model, tokenizer = None, None
    defaults = ("", 1.0, 1.0, 50)
    cleared = (None,) * 4
    return defaults + cleared
# Gradio interface. Each callback wired below returns its values in the same
# order as the `outputs=` list it is connected to.
with gr.Blocks() as demo:
    gr.Markdown("# Analyse et génération de texte avec LLM")
    # Model selection: the dropdown lists `models`; the button calls load_model.
    with gr.Accordion("Sélection du modèle"):
        model_dropdown = gr.Dropdown(choices=models, label="Choisissez un modèle")
        load_button = gr.Button("Charger le modèle")
        load_output = gr.Textbox(label="Statut du chargement")
    # Sampling parameters, passed to both analysis and generation.
    # NOTE(review): analyze_next_token receives top_p but does not use it.
    with gr.Row():
        temperature = gr.Slider(0.1, 2.0, value=1.0, label="Température")
        top_p = gr.Slider(0.1, 1.0, value=1.0, label="Top-p")
        top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
    input_text = gr.Textbox(label="Texte d'entrée", lines=3)
    # Next-token analysis: probability text plus two plots.
    analyze_button = gr.Button("Analyser le prochain token")
    next_token_probs = gr.Textbox(label="Probabilités du prochain token")
    with gr.Row():
        attention_plot = gr.Plot(label="Visualisation de l'attention")
        prob_plot = gr.Plot(label="Probabilités des tokens suivants")
    # Sampled continuation of the prompt.
    generate_button = gr.Button("Générer la suite du texte")
    generated_text = gr.Textbox(label="Texte généré")
    reset_button = gr.Button("Réinitialiser")

    # Event wiring.
    load_button.click(load_model, inputs=[model_dropdown], outputs=[load_output])
    analyze_button.click(analyze_next_token,
                         inputs=[input_text, temperature, top_p, top_k],
                         outputs=[next_token_probs, attention_plot, prob_plot])
    generate_button.click(generate_text,
                          inputs=[input_text, temperature, top_p, top_k],
                          outputs=[generated_text])
    # reset() returns 8 values matching these 8 outputs.
    # NOTE(review): load_output is not in this list, so the last load status
    # stays visible after reset() unloads the model.
    reset_button.click(reset,
                       outputs=[input_text, temperature, top_p, top_k, next_token_probs, attention_plot, prob_plot, generated_text])

# Start the Gradio server when run as a script.
if __name__ == "__main__":
    demo.launch()