soutrik's picture
added: model and code and app
29730dd
"""Module to define utility functions for the project."""
import os
import torch
def get_num_workers(model_run_location):
    """Return the number of workers to be used for data loading.

    Args:
        model_run_location: Name of the environment the model runs in.
            Only the exact string ``"colab"`` enables worker processes.

    Returns:
        ``0`` for any location other than ``"colab"``. On ``"colab"``,
        ``cpu_count - 1`` when more than 3 CPUs are reported, otherwise ``2``
        (also used when the CPU count cannot be determined).
    """
    # Every non-colab location loads data without worker processes, so the
    # CPU count is irrelevant there.
    if model_run_location != "colab":
        return 0

    cpu_total = os.cpu_count()
    # os.cpu_count() returns None when the count is undeterminable; comparing
    # None with an int would raise TypeError, so fall back to the minimum.
    if cpu_total is None or cpu_total <= 3:
        return 2

    # Leave one CPU free for the main process.
    return cpu_total - 1
# Function to save the model
# https://debuggercafe.com/saving-and-loading-the-best-model-in-pytorch/
def save_model(epoch, model, optimizer, scheduler, batch_size, criterion, file_name):
    """Save the trained model along with other training information to disk.

    The checkpoint is a dict written with ``torch.save`` containing the keys
    ``epoch``, ``model_state_dict``, ``optimizer_state_dict``,
    ``scheduler_state_dict``, ``batch_size`` and ``loss``.

    Args:
        epoch: Epoch value stored under ``"epoch"``.
        model: Object whose ``state_dict()`` is stored.
        optimizer: Object whose ``state_dict()`` is stored.
        scheduler: Object whose ``state_dict()`` is stored, or ``None`` when
            training runs without a scheduler (``None`` is stored instead).
        batch_size: Value stored under ``"batch_size"``.
        criterion: Value stored as-is under ``"loss"``.
        file_name: Destination handed directly to ``torch.save``.
    """
    # A scheduler is optional in a training setup; calling .state_dict() on
    # None would raise AttributeError, so record None for that case.
    scheduler_state = scheduler.state_dict() if scheduler is not None else None

    checkpoint = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler_state,
        "batch_size": batch_size,
        "loss": criterion,
    }
    torch.save(checkpoint, file_name)
# Given a results dict holding per-epoch train/test losses and accuracies,
# loop through the epochs and print the metrics
def pretty_print_metrics(num_epochs, results):
    """
    Print one line of train/test loss and accuracy for each epoch.

    ``results`` is indexed with the keys "train_loss", "train_acc",
    "test_loss" and "test_acc"; each value is indexed by epoch position
    for the first ``num_epochs`` epochs.
    """
    # Pull the four per-epoch series out of the results mapping
    loss_tr = results["train_loss"]
    acc_tr = results["train_acc"]
    loss_te = results["test_loss"]
    acc_te = results["test_acc"]

    for idx in range(num_epochs):
        # Epochs are reported 1-based; every metric uses 4 decimal places
        fields = (
            f"Epoch: {idx+1:02d}",
            f"Train Loss: {loss_tr[idx]:.4f}",
            f"Test Loss: {loss_te[idx]:.4f}",
            f"Train Accuracy: {acc_tr[idx]:.4f}",
            f"Test Accuracy: {acc_te[idx]:.4f}",
        )
        print(", ".join(fields))
# Given a file path, extract the folder path and create folder recursively if it does not already exist
def create_folder_if_not_exists(file_path):
    """
    Create the parent folder of ``file_path`` (recursively) if it is missing.

    Args:
        file_path: Path to a file; only its directory part is created.

    Returns:
        None. Prints the folder path, and a confirmation when it is created.
    """
    # Extract the folder path
    folder_path = os.path.dirname(file_path)
    print(f"Folder path: {folder_path}")

    # A bare file name (e.g. "model.pth") has no directory component;
    # os.makedirs("") would raise FileNotFoundError, and there is nothing
    # to create because the file lives in the current working directory.
    if not folder_path:
        return

    # Create the folder if it does not exist; exist_ok guards against the
    # folder appearing between the check and the creation.
    if not os.path.exists(folder_path):
        os.makedirs(folder_path, exist_ok=True)
        print(f"Created folder: {folder_path}")