Spaces:
Sleeping
Sleeping
import os | |
import torch | |
from sklearn.metrics import confusion_matrix, f1_score, accuracy_score | |
from torch.utils.data import DataLoader | |
import matplotlib.pyplot as plt | |
from atoms_detection.training_model import model_pipeline, get_args | |
from atoms_detection.dataset import CropsDataset | |
from atoms_detection.training import test_epoch | |
from utils.cf_matrix import make_confusion_matrix | |
from utils.paths import MODELS_PATH, CM_VIS_PATH | |
def _select_device():
    """Return the best available torch device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def _load_model(args, n_labels, device):
    """Build ``model_pipeline[args.model]`` and load the experiment's weights.

    Args:
        args: Parsed CLI namespace; reads ``experiment_name`` and ``model``.
        n_labels: Number of output classes passed as ``num_classes``.
        device: Device the checkpoint is mapped to and the model is moved to.

    Returns:
        The model with weights from ``checkpoint['state_dict']``. It is wrapped
        in ``torch.nn.DataParallel`` when running on more than one CUDA GPU.
    """
    ckpt_filename = os.path.join(MODELS_PATH, f'{args.experiment_name}.ckpt')
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    checkpoint = torch.load(ckpt_filename, map_location=device)
    model = model_pipeline[args.model](num_classes=n_labels).to(device)
    model.load_state_dict(checkpoint['state_dict'])
    # DataParallel scatters work across CUDA devices and requires the module
    # to live on cuda:0. Wrapping a CPU/MPS model would fail in forward(),
    # so only wrap when we are actually running on CUDA.
    if device.type == "cuda" and torch.cuda.device_count() > 1:
        print("Using {} GPUs!".format(torch.cuda.device_count()))
        model = torch.nn.DataParallel(model)
    return model


def main(args):
    """Evaluate a trained checkpoint on the test crops and save CM + metrics.

    Loads ``<MODELS_PATH>/<args.experiment_name>.ckpt`` into the architecture
    selected by ``args.model`` and runs it over ``CropsDataset.test_dataset()``.
    It then writes:

    * ``<CM_VIS_PATH>/cm_<experiment_name>.jpg`` -- the confusion-matrix plot;
    * ``<CM_VIS_PATH>/metrics_<experiment_name>.txt`` -- F1 score and accuracy.

    The F1 score and accuracy are also printed to stdout.

    Args:
        args: Parsed command-line namespace (as returned by ``get_args()``).
            Must provide ``experiment_name`` and ``model`` (a key of
            ``model_pipeline``).
    """
    device = _select_device()

    test_dataset = CropsDataset.test_dataset()
    test_dataloader = DataLoader(test_dataset, batch_size=64)

    model = _load_model(args, test_dataset.get_n_labels(), device)

    loss_function = torch.nn.CrossEntropyLoss(reduction='mean').to(device)
    y_true, y_pred = test_epoch(test_dataloader, model, loss_function, device)

    cm = confusion_matrix(y_true, y_pred)
    # These four names match sklearn's binary matrix [[TN, FP], [FN, TP]]
    # flattened row-major. They assume a two-class problem.
    labels = ["True Neg", "False Pos", "False Neg", "True Pos"]
    make_confusion_matrix(cm, group_names=labels, cbar_range=(0, 110))
    os.makedirs(CM_VIS_PATH, exist_ok=True)
    plt.savefig(os.path.join(CM_VIS_PATH, f"cm_{args.experiment_name}.jpg"))
    # Release the current figure so repeated evaluations don't accumulate
    # open figures.
    plt.close()

    # f1_score defaults to average='binary'. This, too, assumes two classes.
    f1 = f1_score(y_true, y_pred)
    acc = accuracy_score(y_true, y_pred)
    metrics_path = os.path.join(CM_VIS_PATH, f"metrics_{args.experiment_name}.txt")
    with open(metrics_path, 'w', encoding='utf-8') as _log:
        _log.write(f"F1_score: {f1}\nACCURACY: {acc}\n")
    print(f"F1_score: {f1}")
    print(f"ACCURACY: {acc}")
if __name__ == "__main__":
    # Script entry point: parse CLI options, then run the evaluation.
    cli_args = get_args()
    main(cli_args)