File size: 2,369 Bytes
29730dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
"""Module to define utility functions for the project."""
import os

import torch


def get_num_workers(model_run_location):
    """Return the number of data-loading workers for a given run location.

    Args:
        model_run_location: Where the model is being run. Only ``"colab"``
            enables worker processes; every other value yields 0 workers.

    Returns:
        int: For ``"colab"``, ``cpu_count - 1`` when more than 3 CPUs are
        reported, otherwise 2. For any other location, 0.
    """
    # Any location other than Colab gets no worker processes.
    if model_run_location != "colab":
        return 0

    # os.cpu_count() returns None when the CPU count cannot be determined;
    # fall back to the small-machine default instead of raising TypeError.
    cpu_count = os.cpu_count()
    if cpu_count is None or cpu_count <= 3:
        return 2

    # Presumably leaves one core free for the main process.
    return cpu_count - 1


# Function to save the model
# https://debuggercafe.com/saving-and-loading-the-best-model-in-pytorch/
def save_model(epoch, model, optimizer, scheduler, batch_size, criterion, file_name):
    """
    Write a training checkpoint to ``file_name`` using ``torch.save``.

    The checkpoint is a dict with these keys, in order:
    ``"epoch"``, ``"model_state_dict"``, ``"optimizer_state_dict"``,
    ``"scheduler_state_dict"``, ``"batch_size"`` and ``"loss"``.
    Under ``"loss"`` the ``criterion`` argument is stored as-is.
    """
    checkpoint = {}
    checkpoint["epoch"] = epoch
    checkpoint["model_state_dict"] = model.state_dict()
    checkpoint["optimizer_state_dict"] = optimizer.state_dict()
    checkpoint["scheduler_state_dict"] = scheduler.state_dict()
    checkpoint["batch_size"] = batch_size
    # Note: this stores the criterion object itself, not a numeric loss value.
    checkpoint["loss"] = criterion

    torch.save(checkpoint, file_name)


# Given a list of train_losses, train_accuracies, test_losses,
# test_accuracies, loop through epoch and print the metrics
def pretty_print_metrics(num_epochs, results):
    """
    Print one summary line of train/test loss and accuracy per epoch.

    Args:
        num_epochs: Number of epochs to print, starting from the first entry.
        results: Mapping with list-like values under ``"train_loss"``,
            ``"train_acc"``, ``"test_loss"`` and ``"test_acc"``; each must
            have at least ``num_epochs`` entries (IndexError otherwise).
    """
    train_loss_hist = results["train_loss"]
    train_acc_hist = results["train_acc"]
    test_loss_hist = results["test_loss"]
    test_acc_hist = results["test_acc"]

    # Epochs are displayed 1-based; lists are indexed 0-based.
    for epoch in range(1, num_epochs + 1):
        idx = epoch - 1
        message = (
            f"Epoch: {epoch:02d}, Train Loss: {train_loss_hist[idx]:.4f}, "
            f"Test Loss: {test_loss_hist[idx]:.4f}, Train Accuracy: {train_acc_hist[idx]:.4f}, "
            f"Test Accuracy: {test_acc_hist[idx]:.4f}"
        )
        print(message)


# Given a file path, extract the folder path and create folder recursively if it does not already exist
def create_folder_if_not_exists(file_path):
    """
    Ensure the parent directory of ``file_path`` exists, creating it if needed.

    Intermediate directories are created as well (``os.makedirs``). A path
    with no directory component (e.g. ``"model.pth"``) refers to the current
    directory, so nothing is created.

    Args:
        file_path: Path to a file whose containing folder should exist.
    """
    folder_path = os.path.dirname(file_path)
    print(f"Folder path: {folder_path}")

    # dirname() of a bare filename is ""; os.makedirs("") would raise
    # FileNotFoundError, and the current directory already exists anyway.
    if not folder_path:
        return

    if not os.path.exists(folder_path):
        # exist_ok=True tolerates the folder appearing between the check
        # above and this call.
        os.makedirs(folder_path, exist_ok=True)
        print(f"Created folder: {folder_path}")