import copy
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms.functional as F
from torch.optim import SGD, AdamW
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection import keypointrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.keypoint_rcnn import KeypointRCNNPredictor
from tqdm import tqdm

from modules.eval import main_evaluation
from modules.utils import write_results

# Loss components tracked per batch, in the order they are reported and passed to write_results.
_LOSS_KEYS = ('loss_classifier', 'loss_box_reg', 'loss_objectness', 'loss_rpn_box_reg', 'loss_keypoint')

# Extra weight applied to the classifier loss during training when a loss_config is given.
_CLASSIFIER_LOSS_WEIGHT = 3

# Multiplicative learning-rate decay used by training_model's schedule.
_LR_DECAY = 0.7


def get_arrow_model(num_classes, num_keypoints=2):
    """
    Configures and returns a Keypoint R-CNN model (ResNet-50 + FPN) with a box predictor
    for `num_classes` classes and a keypoint predictor for `num_keypoints` keypoints.

    Parameters:
    - num_classes (int): Number of output classes of the box predictor, including the
      background class (this is what FastRCNNPredictor expects).
    - num_keypoints (int): Number of keypoints to predict for each detected object.

    Returns:
    - model (torch.nn.Module): The modified Keypoint R-CNN model.
    """
    # weights=None: no COCO-pretrained detection weights are loaded.
    model = keypointrcnn_resnet50_fpn(weights=None)

    # Replace the box predictor with one sized for our classes.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Replace the keypoint predictor, reusing the channel count of the existing keypoint head
    # instead of hard-coding it.
    keypoint_in_channels = model.roi_heads.keypoint_predictor.kps_score_lowres.in_channels
    model.roi_heads.keypoint_predictor = KeypointRCNNPredictor(keypoint_in_channels, num_keypoints)
    return model


def get_faster_rcnn_model(num_classes):
    """
    Configures and returns a Faster R-CNN model (ResNet-50 + FPN) with a box predictor
    for `num_classes` classes.

    Parameters:
    - num_classes (int): Number of classes for the model to detect, including the background class.

    Returns:
    - model (torch.nn.Module): The modified Faster R-CNN model.
    """
    # weights=None: no COCO-pretrained detection weights are loaded.
    model = fasterrcnn_resnet50_fpn(weights=None)

    # Replace the box predictor with one sized for our classes (num_classes includes background).
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model


def prepare_model(dict, opti, learning_rate=0.0003, model_to_load=None, model_type='object'):
    """
    Prepares the model and optimizer for training.

    Parameters:
    - dict (dict): Dictionary of classes; only its length is used, as the number of classes.
      (The name shadows the builtin but is kept for backward compatibility with keyword callers.)
    - opti (str): Optimizer type, 'SGD' or 'Adam' (the latter builds an AdamW optimizer).
    - learning_rate (float): Learning rate for the optimizer.
    - model_to_load (str, optional): Path of a state dict to load, without the '.pth' suffix.
    - model_type (str): Type of model to prepare, 'object' (Faster R-CNN) or 'arrow' (Keypoint R-CNN).

    Returns:
    - model (torch.nn.Module): The prepared model, moved to `device`.
    - optimizer (torch.optim.Optimizer): The configured optimizer.
    - device (torch.device): CUDA if available, otherwise CPU.

    Raises:
    - ValueError: if `opti` or `model_type` is not one of the supported values.
    """
    # Validate the cheap string arguments before building (and possibly loading) a large model.
    if opti not in ('SGD', 'Adam'):
        raise ValueError(f"Optimizer not found: {opti!r} (expected 'SGD' or 'Adam')")

    num_classes = len(dict)
    if model_type == 'object':
        model = get_faster_rcnn_model(num_classes)
    elif model_type == 'arrow':
        model = get_arrow_model(num_classes, 2)
    else:
        raise ValueError(f"Unknown model_type {model_type!r} (expected 'object' or 'arrow')")

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    if model_to_load:
        # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
        model.load_state_dict(torch.load(model_to_load + '.pth', map_location=device))
        print(f"Model '{model_to_load}' loaded")

    model.to(device)

    if opti == 'SGD':
        optimizer = SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=0.0001)
    else:
        optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0.001, eps=1e-08, betas=(0.9, 0.999))

    return model, optimizer, device


def _to_device(images, targets_im, device):
    """Move a batch of images and target dicts to `device` (targets are detached copies)."""
    images = [image.to(device) for image in images]
    targets = [{k: v.clone().detach().to(device) for k, v in t.items()} for t in targets_im]
    return images, targets


def _combine_losses(loss_dict, loss_config=None, classifier_weight=1):
    """
    Sum the losses in `loss_dict` that are enabled in `loss_config` (all of them when it is None).

    `classifier_weight` multiplies 'loss_classifier' only when a `loss_config` is given.
    The multiplication is out-of-place, so `loss_dict` values stay unweighted for logging.
    Returns the int 0 if no loss is selected.
    """
    if loss_config is None:
        return sum(loss_dict.values())
    total = 0
    for key, loss in loss_dict.items():
        if loss_config.get(key, False):
            if key == 'loss_classifier':
                loss = loss * classifier_weight
            total = total + loss
    return total


def _loss_components(loss_dict):
    """Return the scalar value of each entry of _LOSS_KEYS in `loss_dict` (0 when absent)."""
    return [loss_dict[key].item() if loss_dict.get(key) is not None else 0 for key in _LOSS_KEYS]


def _mean_components(rows):
    """Column-wise mean of the per-batch rows produced by _loss_components (rows must be non-empty)."""
    return list(np.mean(np.asarray(rows, dtype=float), axis=0))


def evaluate_loss(model, data_loader, device, loss_config=None, print_losses=False):
    """
    Evaluate the loss of the model on a validation dataset.

    Parameters:
    - model (torch.nn.Module): The model to evaluate.
    - data_loader (torch.utils.data.DataLoader): DataLoader yielding (images, targets) batches.
    - device (torch.device): Device to perform evaluation on.
    - loss_config (dict, optional): Maps loss names to bool; only enabled losses are summed.
      When None, all losses are summed.
    - print_losses (bool): Whether to print the averaged individual loss components.

    Returns:
    - float: Average (selected) loss per batch over the validation dataset.

    Raises:
    - ValueError: if `data_loader` yields no batches.
    """
    # torchvision detection models only return their loss dict in training mode, so train() is
    # required here. Gradients are still disabled by no_grad() below.
    model.train()
    total_loss = 0.0
    component_rows = []
    num_batches = 0

    with torch.no_grad():
        for images, targets_im in tqdm(data_loader, desc="Evaluating"):
            images, targets = _to_device(images, targets_im, device)
            loss_dict = model(images, targets)
            # float() also accepts the int 0 returned when loss_config selects no loss.
            total_loss += float(_combine_losses(loss_dict, loss_config))
            component_rows.append(_loss_components(loss_dict))
            num_batches += 1

    if num_batches == 0:
        raise ValueError("data_loader yielded no batches; cannot compute an average loss")

    avg_loss = total_loss / num_batches
    (avg_loss_classifier, avg_loss_box_reg, avg_loss_objectness,
     avg_loss_rpn_box_reg, avg_loss_keypoints) = _mean_components(component_rows)

    if print_losses:
        print(f"Average Loss: {avg_loss:.4f}")
        print(f"Average Classifier Loss: {avg_loss_classifier:.4f}")
        print(f"Average Box Regression Loss: {avg_loss_box_reg:.4f}")
        print(f"Average Objectness Loss: {avg_loss_objectness:.4f}")
        print(f"Average RPN Box Regression Loss: {avg_loss_rpn_box_reg:.4f}")
        print(f"Average Keypoints Loss: {avg_loss_keypoints:.4f}")

    return avg_loss


def _train_one_epoch(model, data_loader, optimizer, device, loss_config, epoch_number):
    """Run one optimisation pass over `data_loader`; return (mean total loss, mean loss components)."""
    model.train()
    total_loss = 0.0
    component_rows = []
    num_batches = 0

    progress_bar = tqdm(data_loader, desc=f'Epoch {epoch_number}')
    for images, targets_im in progress_bar:
        images, targets = _to_device(images, targets_im, device)
        optimizer.zero_grad()
        loss_dict = model(images, targets)

        losses = _combine_losses(loss_dict, loss_config, classifier_weight=_CLASSIFIER_LOSS_WEIGHT)
        if not torch.is_tensor(losses):
            raise ValueError("loss_config enables none of the losses returned by the model; nothing to optimise")

        component_rows.append(_loss_components(loss_dict))

        losses.backward()
        optimizer.step()

        batch_loss = losses.item()
        total_loss += batch_loss
        num_batches += 1
        progress_bar.set_description(f'Epoch {epoch_number}, Loss: {batch_loss:.4f}')

    if num_batches == 0:
        raise ValueError("data_loader yielded no batches; cannot train")

    return total_loss / num_batches, _mean_components(component_rows)


def _build_model_name(optimizer, epoch_number, parameters, blur_prob, information_training):
    """Build the checkpoint/results name encoding the optimizer, epoch and augmentation settings."""
    prefix = f"model_{type(optimizer).__name__}_{epoch_number}ep_"
    if parameters is None:
        # No augmentation settings were supplied, so only the blur probability is encoded.
        return f"{prefix}trainval_blur0{int(blur_prob*10)}_{information_training}"
    batch_size, crop_prob, _, h_flip_prob, _, _, rotate_proba, _ = parameters.values()
    return (f"{prefix}{batch_size}batch_trainval_blur0{int(blur_prob*10)}_crop0{int(crop_prob*10)}"
            f"_flip0{int(h_flip_prob*10)}_rotate0{int(rotate_proba*10)}_{information_training}")


def _save_checkpoint(state, name_model, metrics_list, start_epoch):
    """Save `state` (if any) to ./models/<name_model>.pth and write the metric history."""
    if state is not None:
        torch.save(state, './models/' + name_model + '.pth')
    write_results(name_model, metrics_list, start_epoch)


def training_model(num_epochs, model, data_loader, subset_test_loader, optimizer, model_to_load=None,
                   change_learning_rate=100, start_key=100, save_every=5, parameters=None, blur_prob=0.02,
                   score_threshold=0.7, iou_threshold=0.5, early_stop_f1_score=0.97,
                   information_training='training', start_epoch=0, loss_config=None, model_type='object',
                   eval_metric='f1_score',
                   device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')):
    """
    Train `model` for up to `num_epochs` epochs, keeping the best state dict according to `eval_metric`.

    Each epoch optimises on `data_loader`, then evaluates on `subset_test_loader`. The test loss is
    always computed with evaluate_loss. Detection metrics come from main_evaluation unless
    eval_metric == 'loss'. A deep copy of the best state dict is kept and periodically saved under
    ./models/.

    Learning-rate schedule: every `change_learning_rate` epochs, or after two test-loss increases
    since the last change, every param group's LR is multiplied by 0.7. At the same time the weights
    are rolled back to the best state seen so far and the optimizer's running state is reset. The
    optimizer's type and other hyper-parameters are left unchanged.

    Parameters:
    - num_epochs (int): Maximum number of epochs to run.
    - model (torch.nn.Module): Model returning a loss dict in train mode.
    - data_loader: Training loader yielding (images, targets) batches.
    - subset_test_loader: Evaluation loader yielding (images, targets) batches.
    - optimizer (torch.optim.Optimizer): Optimizer for `model`'s parameters.
    - model_to_load: Unused; kept for backward compatibility.
    - change_learning_rate (int): Period in epochs of the scheduled LR decay (<= 0 disables it).
    - start_key (int): 0-based epoch (> 0) at which parameters whose name contains 'keypoint' get
      requires_grad=True, and 'loss_keypoint' is enabled in loss_config.
    - save_every (int): Checkpoint period in absolute epochs (<= 0 disables periodic saving).
    - parameters (dict, optional): Augmentation settings whose values are, in insertion order:
      batch_size, crop_prob, rotate_90_proba, h_flip_prob, v_flip_prob, max_rotate_deg,
      rotate_proba, keep_ratio. Used only for logging and the checkpoint name.
    - blur_prob (float): Used only in the checkpoint name.
    - score_threshold, iou_threshold (float): Forwarded to main_evaluation.
    - early_stop_f1_score (float): Training stops after the first epoch (other than the first one
      of this call) whose F1 score exceeds this value.
    - information_training (str): Suffix of the checkpoint name.
    - start_epoch (int): Offset added to epoch numbers in logs and checkpoint names.
    - loss_config (dict, optional): Maps loss names to bool. When None, all losses are used.
      The caller's dict is copied, not mutated.
    - model_type (str): Forwarded to main_evaluation.
    - eval_metric (str): 'f1_score', 'precision' or 'recall'. Any other value selects the negated
      test loss, and 'loss' additionally skips main_evaluation.
    - device (torch.device): Device to train on.

    Returns:
    - model (torch.nn.Module): The model, with the best weights loaded when available.

    Raises:
    - ValueError: if a loader yields no batches, or if loss_config enables none of the losses
      returned by the model.
    """
    model.train()

    if loss_config is None:
        print('No loss config found, all losses will be used.')
    else:
        # Work on a copy so the keypoint switch below does not mutate the caller's dict.
        loss_config = dict(loss_config)
        print('The following losses will be used: ', end='')
        for key, value in loss_config.items():
            if value:
                print(key, end=", ")
        print()

    # Per-epoch histories; metrics_list holds references to these same lists, in write_results order.
    epoch_avg_losses = []
    epoch_avg_loss_classifier = []
    epoch_avg_loss_box_reg = []
    epoch_avg_loss_objectness = []
    epoch_avg_loss_rpn_box_reg = []
    epoch_avg_loss_keypoints = []
    epoch_precision = []
    epoch_recall = []
    epoch_f1_score = []
    epoch_test_loss = []
    component_histories = [epoch_avg_loss_classifier, epoch_avg_loss_box_reg, epoch_avg_loss_objectness,
                           epoch_avg_loss_rpn_box_reg, epoch_avg_loss_keypoints]
    metrics_list = [epoch_avg_losses, *component_histories,
                    epoch_precision, epoch_recall, epoch_f1_score, epoch_test_loss]

    start_tot = time.time()
    best_metric_value = float('-inf')
    best_epoch = 0
    best_model_state = None
    epochs_with_high_f1 = 0
    learning_rate = optimizer.param_groups[0]['lr']
    bad_test_loss_epochs = 0
    previous_test_loss = float('inf')
    name_model = None

    print(f"Let's go training {model_type} model with {num_epochs} epochs!")
    if parameters is not None:
        (batch_size, crop_prob, rotate_90_proba, h_flip_prob, v_flip_prob,
         max_rotate_deg, rotate_proba, keep_ratio) = parameters.values()
        print(f"Learning rate: {learning_rate}, Batch size: {batch_size}, Crop prob: {crop_prob}, "
              f"H flip prob: {h_flip_prob}, V flip prob: {v_flip_prob}, Max rotate deg: {max_rotate_deg}, "
              f"Rotate proba: {rotate_proba}, Rotate 90 proba: {rotate_90_proba}, Keep ratio: {keep_ratio}")

    for epoch in range(num_epochs):
        epoch_number = epoch + 1 + start_epoch

        scheduled_decay = change_learning_rate > 0 and epoch > 0 and epoch % change_learning_rate == 0
        if scheduled_decay or bad_test_loss_epochs >= 2:
            # Roll back to the best weights (in-place copy, so the optimizer still tracks the same
            # tensors). Then decay the LR and drop stale momentum/moment estimates.
            if best_model_state is not None:
                model.load_state_dict(best_model_state)
            for group in optimizer.param_groups:
                group['lr'] *= _LR_DECAY
            optimizer.state.clear()
            learning_rate = optimizer.param_groups[0]['lr']
            print(f'Learning rate changed to {learning_rate:.4} and the best epoch for now is {best_epoch}')
            bad_test_loss_epochs = 0

        if epoch > 0 and epoch == start_key:
            print("Now it's training Keypoints also")
            if loss_config is not None:
                loss_config['loss_keypoint'] = True
            for name, param in model.named_parameters():
                if 'keypoint' in name:
                    param.requires_grad = True

        start = time.time()
        avg_loss, component_means = _train_one_epoch(model, data_loader, optimizer, device, loss_config,
                                                     epoch_number)
        epoch_avg_losses.append(avg_loss)
        for history, value in zip(component_histories, component_means):
            history.append(value)

        # Evaluate the model on the test set.
        if eval_metric == 'loss':
            labels_precision, precision, recall, f1_score = 0, 0, 0, 0
            avg_test_loss = evaluate_loss(model, subset_test_loader, device, loss_config)
            print(f"Epoch {epoch_number}, Average Training Loss: {avg_loss:.4f}, "
                  f"Average Test Loss: {avg_test_loss:.4f}", end=", ")
        else:
            labels_precision, precision, recall, f1_score, _, _ = main_evaluation(
                model, subset_test_loader, score_threshold=score_threshold, iou_threshold=iou_threshold,
                distance_threshold=10, key_correction=False, model_type=model_type)
            print(f"Epoch {epoch_number}, Average Loss: {avg_loss:.4f}, Labels_precision: {labels_precision:.4f}, "
                  f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1_score:.4f} ", end=", ")
            avg_test_loss = evaluate_loss(model, subset_test_loader, device, loss_config)
            print(f"Epoch {epoch_number}, Average Test Loss: {avg_test_loss:.4f}", end=", ")
        print(f"Time: {time.time() - start:.2f} [s]")

        metric_by_name = {'f1_score': f1_score, 'precision': precision, 'recall': recall}
        metric_used = metric_by_name.get(eval_metric, -avg_test_loss)

        if metric_used > best_metric_value:
            best_metric_value = metric_used
            best_epoch = epoch_number
            best_model_state = copy.deepcopy(model.state_dict())

        if epoch > 0 and f1_score > early_stop_f1_score:
            epochs_with_high_f1 += 1

        epoch_precision.append(precision)
        epoch_recall.append(recall)
        epoch_f1_score.append(f1_score)
        epoch_test_loss.append(avg_test_loss)

        name_model = _build_model_name(optimizer, epoch_number, parameters, blur_prob, information_training)

        if epochs_with_high_f1 >= 1:
            # Early stop; the best checkpoint is saved once after the loop.
            break

        if save_every > 0 and epoch_number % save_every == 0:
            _save_checkpoint(best_model_state, name_model, metrics_list, start_epoch)
            if best_model_state is not None:
                model.load_state_dict(best_model_state)

        if avg_test_loss > previous_test_loss:
            bad_test_loss_epochs += 1
        previous_test_loss = avg_test_loss

    print(f"\nTotal time: {(time.time() - start_tot) / 60:.2f} minutes, Best Epoch is {best_epoch} "
          f"with an {eval_metric} of {best_metric_value:.4f}")

    if name_model is not None:
        # At least one epoch ran: persist the best state and the full metric history.
        _save_checkpoint(best_model_state, name_model, metrics_list, start_epoch)
        if best_model_state is not None:
            model.load_state_dict(best_model_state)
        print(f"Name of the best model: {name_model}")

    return model