# atom-detection / atoms_detection/testing_model.py
# Author: Romain Graux
# Commit b2ffc9b: "Initial commit with ml code and webapp"
import os
import torch
from sklearn.metrics import confusion_matrix, f1_score, accuracy_score
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from atoms_detection.training_model import model_pipeline, get_args
from atoms_detection.dataset import CropsDataset
from atoms_detection.training import test_epoch
from utils.cf_matrix import make_confusion_matrix
from utils.paths import MODELS_PATH, CM_VIS_PATH
def main(args):
# CUDA for PyTorch
#use_cuda = torch.cuda.is_available()
use_cuda = torch.backends.mps.is_available()
device = torch.device("mps" if use_cuda else "cpu")
test_dataset = CropsDataset.test_dataset()
test_dataloader = DataLoader(test_dataset, batch_size=64)
ckpt_filename = os.path.join(MODELS_PATH, f'{args.experiment_name}.ckpt')
checkpoint = torch.load(ckpt_filename, map_location=device)
model = model_pipeline[args.model](num_classes=test_dataset.get_n_labels()).to(device)
model.load_state_dict(checkpoint['state_dict'])
if torch.cuda.device_count() > 1:
print("Using {} GPUs!".format(torch.cuda.device_count()))
model = torch.nn.DataParallel(model)
loss_function = torch.nn.CrossEntropyLoss(reduction='mean').to(device)
y_true, y_pred = test_epoch(test_dataloader, model, loss_function, device)
cm = confusion_matrix(y_true, y_pred)
labels = ["True Neg", "False Pos", "False Neg", "True Pos"]
make_confusion_matrix(cm, group_names=labels, cbar_range=(0, 110))
if not os.path.exists(CM_VIS_PATH):
os.makedirs(CM_VIS_PATH)
plt.savefig(os.path.join(CM_VIS_PATH, f"cm_{args.experiment_name}.jpg"))
f1 = f1_score(y_true, y_pred)
acc = accuracy_score(y_true, y_pred)
with open(os.path.join(CM_VIS_PATH, f"metrics_{args.experiment_name}.txt"), 'w') as _log:
_log.write(f"F1_score: {f1}\nACCURACY: {acc}\n")
print(f"F1_score: {f1}")
print(f"ACCURACY: {acc}")
if __name__ == "__main__":
    # Parse the CLI options (the same parser as the training script) and evaluate.
    cli_args = get_args()
    main(cli_args)