import torch
import wandb
import yaml
from transformers import Trainer, TrainingArguments

from data.datasets import load_and_tokenize_data
from models.full_finetune_model import get_full_finetune_model
from models.student_model import get_student_model

# NOTE: measure_resources (called at the bottom of this script) is not defined here;
# it is expected to be imported from the project's utility module.

# Load the configuration
with open('config/config.yaml', 'r') as f:
    config = yaml.safe_load(f)

# Initialize wandb
wandb.init(project=config['wandb']['project'], entity=config['wandb']['entity'])

# Load the data
train_dataset, test_dataset = load_and_tokenize_data(config)

# Load the teacher and student models
teacher_model = get_full_finetune_model()
student_model = get_student_model(config)

# The teacher is only used for inference during distillation
teacher_model.eval()

# Define the training arguments for distillation
training_args = TrainingArguments(
    output_dir='./results_student',
    num_train_epochs=config['training']['num_epochs'],
    per_device_train_batch_size=config['training']['batch_size'],
    per_device_eval_batch_size=config['training']['batch_size'],
    evaluation_strategy='epoch',
    save_steps=10_000,
    save_total_limit=2,
    logging_dir='./logs',
    logging_steps=10,
)

# Define the distillation trainer
class DistillationTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs absorbs extra arguments (e.g. num_items_in_batch) passed by newer
        # versions of transformers.

        # Teacher forward pass: no gradients, and keep the teacher on the same
        # device as the student.
        device = next(model.parameters()).device
        teacher_model.to(device)
        with torch.no_grad():
            teacher_outputs = teacher_model(**inputs)

        # Student forward pass
        student_outputs = model(**inputs)

        # Distillation loss: KL divergence between the student's log-probabilities
        # and the teacher's probabilities over the vocabulary.
        loss = torch.nn.functional.kl_div(
            torch.nn.functional.log_softmax(student_outputs.logits, dim=-1),
            torch.nn.functional.softmax(teacher_outputs.logits, dim=-1),
            reduction='batchmean',
        )
        return (loss, student_outputs) if return_outputs else loss

# Create the Trainer for distillation
trainer = DistillationTrainer(
    model=student_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
)

# Measure resources and train the student model
measure_resources(trainer, "Distillation")
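# ------------------------------------------------------------------------------
# For reference only: a minimal sketch of what measure_resources could look like
# if the project does not already provide it (an assumption, not verified project
# code). It wraps trainer.train(), timing the run and logging peak GPU memory to
# wandb under the given label.
#
# import time
#
# def measure_resources(trainer, label):
#     if torch.cuda.is_available():
#         torch.cuda.reset_peak_memory_stats()
#     start = time.time()
#     trainer.train()
#     stats = {f"{label}/train_time_s": time.time() - start}
#     if torch.cuda.is_available():
#         stats[f"{label}/peak_gpu_mem_mb"] = torch.cuda.max_memory_allocated() / 1e6
#     wandb.log(stats)
# ------------------------------------------------------------------------------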