import gc
import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer

# Authenticate against the Hugging Face Hub when a token is provided; gated
# repositories (Llama, Gemma, ...) cannot be downloaded without it.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
else:
    print("HF_TOKEN non défini : les modèles à accès restreint ne pourront pas être chargés.")

# Models offered in the dropdown.
# NOTE: the Llama 2 repos without the "-hf" suffix contain Meta's original
# checkpoints, not transformers-format weights, so they cannot be loaded with
# AutoModelForCausalLM; the "-hf" variants are used instead.
model_list = [
    "meta-llama/Llama-2-13b-hf",
    "meta-llama/Llama-2-7b-hf",
    "meta-llama/Llama-2-70b-hf",
    "meta-llama/Meta-Llama-3-8B",
    "meta-llama/Llama-3.2-3B",
    "meta-llama/Llama-3.1-8B",
    "mistralai/Mistral-7B-v0.1",
    "mistralai/Mixtral-8x7B-v0.1",
    "mistralai/Mistral-7B-v0.3",
    "google/gemma-2-2b",
    "google/gemma-2-9b",
    "google/gemma-2-27b",
    "croissantllm/CroissantLLMBase",
]

# Currently loaded model and tokenizer; None until load_model() succeeds.
model = None
tokenizer = None


def load_model(model_name):
    """Load *model_name* and its tokenizer into the module-level globals.

    Any previously loaded model is released first so two checkpoints are not
    held in memory at the same time. Errors are reported through the returned
    status string (this function is a Gradio callback) and leave both globals
    set to None.

    Args:
        model_name: Hugging Face Hub repository id, e.g. one of ``model_list``.

    Returns:
        A status message for the "Statut du modèle" textbox.
    """
    global model, tokenizer
    # Drop references to the old model before allocating the new one.
    model = None
    tokenizer = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    print(f"Chargement du modèle {model_name}...")
    try:
        loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Llama/Mistral tokenizers ship without a pad token; reuse EOS so any
        # padding request does not raise.
        if loaded_tokenizer.pad_token is None:
            loaded_tokenizer.pad_token = loaded_tokenizer.eos_token
        loaded_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            torch_dtype=torch.float16,
            # Eager attention so that generate(output_attentions=True) can
            # return the attention weights used for the heatmap.
            attn_implementation="eager",
        )
    except Exception as err:  # UI boundary: surface the failure in the status box
        print(f"Échec du chargement de {model_name}: {err}")
        return f"Échec du chargement du modèle {model_name} : {err}"

    # Publish both globals only once everything loaded successfully.
    tokenizer = loaded_tokenizer
    model = loaded_model
    print("Modèle chargé avec succès.")
    return f"Modèle {model_name} chargé."
def plot_attention(attention_data):
    """Render an attention matrix as a heatmap PNG and return the file path.

    Args:
        attention_data: dict with ``'attention'`` (2-D array; rows are query
            positions, columns key positions) and ``'tokens'`` (one label per
            row/column).

    Returns:
        The path of the written image, ``'attention_plot.png'``.
    """
    tokens = attention_data['tokens']
    attention = attention_data['attention']
    fig, ax = plt.subplots(figsize=(10, 10))
    cax = ax.matshow(attention, cmap='viridis')
    fig.colorbar(cax)
    # Pin exactly one tick per token. The old `[''] + tokens` trick depended on
    # the default tick locator and misaligned labels in modern Matplotlib.
    positions = range(len(tokens))
    ax.set_xticks(positions)
    ax.set_xticklabels(tokens, rotation=90)
    ax.set_yticks(positions)
    ax.set_yticklabels(tokens)
    ax.set_xlabel("Tokens")
    ax.set_ylabel("Tokens")
    ax.set_title("Attention Heatmap")
    fig.tight_layout()
    fig.savefig('attention_plot.png')
    # pyplot keeps every figure alive until closed; close it to avoid a leak.
    plt.close(fig)
    return 'attention_plot.png'


def plot_probabilities(prob_data):
    """Draw a horizontal bar chart of ``{token_text: probability}``.

    Returns:
        The path of the written image, ``'probabilities_plot.png'``.
    """
    words, probs = zip(*prob_data.items())
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.barh(words, probs, color='skyblue')
    ax.set_xlabel('Probabilities')
    ax.set_title('Top Probable Words')
    fig.tight_layout()
    fig.savefig('probabilities_plot.png')
    plt.close(fig)
    return 'probabilities_plot.png'


def _attention_matrix(step_attentions, prompt_len):
    """Stitch the per-step last-layer attentions from ``generate`` into one square matrix.

    ``step_attentions`` is ``outputs.attentions``: one tuple of per-layer
    tensors per generation step, each shaped (batch, heads, queries, keys).
    The first step covers the prompt. Later steps typically add a single query
    row (KV cache), so their shapes differ and cannot simply be concatenated.
    Row ``r`` of the result is the attention of position ``r`` over earlier
    positions (last layer, first batch element, mean over heads).
    """
    size = prompt_len + len(step_attentions) - 1
    matrix = np.zeros((size, size), dtype=np.float32)
    for step, layers in enumerate(step_attentions):
        weights = layers[-1][0].mean(dim=0).float().cpu().numpy()  # (queries, keys)
        end = prompt_len + step  # number of positions fed to the model at this step
        n_queries = min(weights.shape[0], end)
        # Clip keys to real positions (some caches report a longer key axis).
        n_keys = min(weights.shape[1], end)
        matrix[end - n_queries:end, :n_keys] = weights[-n_queries:, :n_keys]
    return matrix


def generate_text(input_text, temperature, top_p, top_k):
    """Generate up to 50 tokens and build the attention and top-5 plots.

    A temperature of 0 selects greedy decoding; otherwise sampling is enabled
    with the given temperature, top-p and top-k.

    Returns:
        ``(generated_text, attention_plot_path, probabilities_plot_path)``;
        the two plot paths are None when no model is loaded.
    """
    global model, tokenizer
    if model is None or tokenizer is None:
        return "Veuillez d'abord charger un modèle.", None, None

    # A single prompt needs no padding (and many causal-LM tokenizers have no
    # pad token, which makes padding=True raise).
    inputs = tokenizer(
        input_text, return_tensors="pt", truncation=True, max_length=512
    ).to(model.device)
    prompt_len = inputs["input_ids"].shape[1]

    temperature = float(temperature)
    if temperature > 0:
        # Sliders may deliver floats; top-k must be an int and top-p a float.
        sampling_kwargs = {
            "do_sample": True,
            "temperature": temperature,
            "top_p": float(top_p),
            "top_k": int(top_k),
        }
    else:
        # Sampling requires a strictly positive temperature: fall back to greedy.
        sampling_kwargs = {"do_sample": False}

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=50,
            output_scores=True,
            output_attentions=True,
            return_dict_in_generate=True,
            return_legacy_cache=True,
            **sampling_kwargs,
        )
    generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)

    # Distribution over the last generated token (scores after the logits processors).
    last_token_logits = outputs.scores[-1][0].float()
    probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)

    # Top-5 candidates. Distinct ids can decode to identical text; keep the
    # most probable one instead of letting a lower value overwrite it.
    top_probs, top_indices = torch.topk(probabilities, 5)
    prob_data = {}
    for idx, prob in zip(top_indices.tolist(), top_probs.tolist()):
        prob_data.setdefault(tokenizer.decode([idx]), prob)

    attention = _attention_matrix(outputs.attentions, prompt_len)
    attention_data = {
        'attention': attention,
        # One label per matrix row: every token except the last generated one.
        'tokens': tokenizer.convert_ids_to_tokens(
            outputs.sequences[0][: attention.shape[0]].tolist()
        ),
    }
    return generated_text, plot_attention(attention_data), plot_probabilities(prob_data)


def reset_app():
    """Unload the model and tokenizer, release their memory, and return a status message."""
    import gc

    global model, tokenizer
    model = None
    tokenizer = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return "Application réinitialisée."


# Gradio user interface
with gr.Blocks() as demo:
    with gr.Row():
        model_selection = gr.Accordion("Sélection du modèle", open=True)
        with model_selection:
            model_name = gr.Dropdown(choices=model_list, label="Choisir un modèle", value=model_list[0])
            load_model_button = gr.Button("Charger le modèle")
            load_status = gr.Textbox(label="Statut du modèle", interactive=False)
    with gr.Row():
        temperature = gr.Slider(0.0, 1.0, value=0.7, label="Température")
        top_p = gr.Slider(0.0, 1.0, value=0.9, label="Top-p")
        # Integer step: top-k is a token count.
        top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
    with gr.Row():
        input_text = gr.Textbox(label="Entrer le texte")
        generate_button = gr.Button("Générer")
    with gr.Row():
        output_text = gr.Textbox(label="Texte généré", interactive=False)
    with gr.Row():
        attention_plot = gr.Image(label="Carte de chaleur des attentions")
        prob_plot = gr.Image(label="Probabilités des mots les plus probables")
    with gr.Row():
        reset_button = gr.Button("Réinitialiser l'application")

    load_model_button.click(load_model, inputs=[model_name], outputs=[load_status])
    generate_button.click(
        generate_text,
        inputs=[input_text, temperature, top_p, top_k],
        outputs=[output_text, attention_plot, prob_plot],
    )
    # Route reset_app's status message to the status box (it was discarded before).
    reset_button.click(reset_app, outputs=[load_status])

demo.launch()