# Hugging Face Space: LLM playground — text generation with attention and
# next-token probability visualizations.
import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer
# Authenticate with the Hugging Face Hub; the gated Llama/Mistral/Gemma
# checkpoints below require an access token. Fail fast with a clear message
# instead of an opaque KeyError when HF_TOKEN is missing.
_hf_token = os.environ.get("HF_TOKEN")
if not _hf_token:
    raise RuntimeError("HF_TOKEN environment variable is not set; it is required to download gated models.")
login(token=_hf_token)
# Model ids selectable in the UI dropdown, grouped by provider.
model_list = [
    # Meta
    "meta-llama/Llama-2-13b",
    "meta-llama/Llama-2-7b",
    "meta-llama/Llama-2-70b",
    "meta-llama/Meta-Llama-3-8B",
    "meta-llama/Llama-3.2-3B",
    "meta-llama/Llama-3.1-8B",
    # Mistral AI
    "mistralai/Mistral-7B-v0.1",
    "mistralai/Mixtral-8x7B-v0.1",
    "mistralai/Mistral-7B-v0.3",
    # Google
    "google/gemma-2-2b",
    "google/gemma-2-9b",
    "google/gemma-2-27b",
    # CroissantLLM
    "croissantllm/CroissantLLMBase",
]
# Module-level model/tokenizer handles: populated by load_model(),
# cleared back to None by reset_app().
model = None
tokenizer = None
def load_model(model_name):
    """Load the selected checkpoint and its tokenizer into the module globals.

    Args:
        model_name: Hugging Face model id picked from the dropdown.

    Returns:
        A status string shown in the Gradio status textbox.
    """
    global model, tokenizer
    print(f"Chargement du modèle {model_name}...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch.float16,
        # "eager" attention is required so generate() can return attention
        # weights for the heatmap.
        attn_implementation="eager",
    )
    print("Modèle chargé avec succès.")
    return f"Modèle {model_name} chargé."
def plot_attention(attention_data):
    """Render a token-to-token attention heatmap and save it as a PNG.

    Args:
        attention_data: dict with keys 'tokens' (list[str]) and 'attention'
            (2-D array-like of attention weights; assumed square over the
            token sequence — TODO confirm against generate_text's extraction).

    Returns:
        Path of the saved PNG ('attention_plot.png') for gr.Image.
    """
    tokens = attention_data['tokens']
    attention = attention_data['attention']
    fig, ax = plt.subplots(figsize=(10, 10))
    cax = ax.matshow(attention, cmap='viridis')
    fig.colorbar(cax)
    # Fix: pin tick positions before labelling. The old `[''] + tokens` trick
    # relied on matshow's default locator, mislabelled ticks, and raises/warns
    # ("FixedFormatter without FixedLocator") on recent matplotlib.
    ax.set_xticks(np.arange(len(tokens)))
    ax.set_yticks(np.arange(len(tokens)))
    ax.set_xticklabels(tokens, rotation=90)
    ax.set_yticklabels(tokens)
    ax.set_xlabel("Tokens")
    ax.set_ylabel("Tokens")
    ax.set_title("Attention Heatmap")
    fig.tight_layout()
    fig.savefig('attention_plot.png')
    # Fix: close the figure — each Gradio call previously leaked a 10x10 figure.
    plt.close(fig)
    return 'attention_plot.png'
def plot_probabilities(prob_data):
    """Render a horizontal bar chart of next-token probabilities as a PNG.

    Args:
        prob_data: mapping of decoded token string -> probability (non-empty).

    Returns:
        Path of the saved PNG ('probabilities_plot.png') for gr.Image.
    """
    words, probs = zip(*prob_data.items())
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.barh(words, probs, color='skyblue')
    ax.set_xlabel('Probabilities')
    ax.set_title('Top Probable Words')
    fig.tight_layout()
    fig.savefig('probabilities_plot.png')
    # Fix: close the figure — repeated calls previously accumulated open
    # pyplot figures (memory leak and matplotlib warnings).
    plt.close(fig)
    return 'probabilities_plot.png'
def generate_text(input_text, temperature, top_p, top_k):
    """Generate a continuation and produce attention/probability plots.

    Args:
        input_text: prompt string from the UI.
        temperature: sampling temperature (0-1 slider).
        top_p: nucleus-sampling cutoff (0-1 slider).
        top_k: top-k sampling cutoff (1-100 slider).

    Returns:
        (generated_text, attention_plot_path, probabilities_plot_path).

    Raises:
        gr.Error: if no model has been loaded yet.
    """
    global model, tokenizer
    # Fix: guard against generating before load_model() was clicked
    # (previously an opaque AttributeError on None).
    if model is None or tokenizer is None:
        raise gr.Error("Aucun modèle chargé. Veuillez d'abord charger un modèle.")
    # Fix: Llama/Mistral tokenizers ship without a pad token; padding=True
    # would raise. Reuse EOS as pad, the standard workaround for causal LMs.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    inputs = tokenizer(
        input_text, return_tensors="pt", padding=True, truncation=True, max_length=512
    ).to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=50,
            # Fix: without do_sample=True, generate() decodes greedily and the
            # temperature/top_p/top_k sliders were silently ignored.
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            output_scores=True,
            output_attentions=True,
            return_dict_in_generate=True,
            return_legacy_cache=True
        )
    generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
    # Distribution over the vocabulary for the last generated step.
    last_token_logits = outputs.scores[-1][0]
    probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
    # Top-5 most probable next tokens, decoded individually.
    top_probs, top_indices = torch.topk(probabilities, 5)
    top_words = [tokenizer.decode([idx.item()]) for idx in top_indices]
    prob_data = {word: prob.item() for word, prob in zip(top_words, top_probs)}
    # Per generated step, take the last layer's attention averaged over heads,
    # then stack steps; presumably intended as a token-x-token map — the
    # resulting matrix need not be square (NOTE(review): verify plot shape).
    attentions = torch.cat(
        [att[-1].mean(dim=1) for att in outputs.attentions], dim=0
    ).cpu().numpy()
    attention_data = {
        'attention': attentions,
        'tokens': tokenizer.convert_ids_to_tokens(outputs.sequences[0])
    }
    return generated_text, plot_attention(attention_data), plot_probabilities(prob_data)
def reset_app():
    """Drop the loaded model and tokenizer so a fresh one can be selected.

    Returns:
        A status string confirming the reset.
    """
    global model, tokenizer
    tokenizer = None
    model = None
    return "Application réinitialisée."
# Gradio UI: model picker, sampling sliders, prompt box, and the three outputs
# (generated text, attention heatmap, probability bar chart).
with gr.Blocks() as demo:
    with gr.Row():
        model_selection = gr.Accordion("Sélection du modèle", open=True)
        with model_selection:
            model_name = gr.Dropdown(choices=model_list, label="Choisir un modèle", value=model_list[0])
            load_model_button = gr.Button("Charger le modèle")
            load_status = gr.Textbox(label="Statut du modèle", interactive=False)
    with gr.Row():
        temperature = gr.Slider(0.0, 1.0, value=0.7, label="Température")
        top_p = gr.Slider(0.0, 1.0, value=0.9, label="Top-p")
        top_k = gr.Slider(1, 100, value=50, label="Top-k")
    with gr.Row():
        input_text = gr.Textbox(label="Entrer le texte")
        generate_button = gr.Button("Générer")
    with gr.Row():
        output_text = gr.Textbox(label="Texte généré", interactive=False)
    with gr.Row():
        attention_plot = gr.Image(label="Carte de chaleur des attentions")
        prob_plot = gr.Image(label="Probabilités des mots les plus probables")
    with gr.Row():
        reset_button = gr.Button("Réinitialiser l'application")

    load_model_button.click(load_model, inputs=[model_name], outputs=[load_status])
    generate_button.click(
        generate_text,
        inputs=[input_text, temperature, top_p, top_k],
        outputs=[output_text, attention_plot, prob_plot],
    )
    # Fix: reset_app's confirmation string was previously discarded; surface it
    # in the status textbox so the user sees the reset happened.
    reset_button.click(reset_app, outputs=[load_status])

demo.launch()