# s12erav1 / utils.py
# Author: piyushgrover — "Uploaded app code" (commit f7915f2)
import os
import torch
def get_dataset_labels():
    """Return the ordered list of class names (they match the CIFAR-10 classes).

    A new list is built on every call, so callers may mutate the result freely.
    The position of each name is its class index.
    """
    return [
        'airplane', 'automobile', 'bird', 'cat', 'deer',
        'dog', 'frog', 'horse', 'ship', 'truck',
    ]
def get_data_label_name(idx):
    """Map a class index to its label name.

    A negative index is treated as "no class" and yields an empty string.
    Any other index is looked up in ``get_dataset_labels()``, so an index past
    the end raises IndexError.
    """
    return '' if idx < 0 else get_dataset_labels()[idx]
def get_data_idx_from_name(name):
    """Map a label name to its class index (case-insensitive).

    Args:
        name: class name such as ``'Cat'``; compared after ``str.lower()``.

    Returns:
        The index of the name in ``get_dataset_labels()``, or -1 when
        ``name`` is empty/None or is not a known label.
    """
    if not name:
        return -1
    labels = get_dataset_labels()
    key = name.lower()
    # `.index` must be called on the returned list, not on the function object.
    return labels.index(key) if key in labels else -1
def load_model_from_checkpoint(device, file_name='ckpt.pth'):
    """Load and return the object stored in a checkpoint file.

    Args:
        device: forwarded to ``torch.load`` as ``map_location``, so stored
            tensors are loaded onto this device.
        file_name: path of the checkpoint to read. Defaults to ``'ckpt.pth'``,
            the file this function has always loaded when called without it.

    Returns:
        The deserialized checkpoint object, unmodified.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    return torch.load(file_name, map_location=device)
def denormalize(img, mean, std):
    """Undo per-channel normalization, then min-max scale the result to [0, 1].

    Computes ``img * std + mean`` channel-wise. The stats are broadcast against
    the last three axes, i.e. ``(C, H, W)`` or ``(N, C, H, W)``. The whole
    tensor is then rescaled so that its global minimum maps to 0 and its
    global maximum maps to 1.

    Args:
        img: normalized image tensor with channels third-from-last.
        mean: per-channel means that were subtracted during normalization.
        std: per-channel standard deviations that were divided out.

    Returns:
        A tensor with the same shape as ``img`` and values in [0, 1]. If the
        image is constant (zero value range) it is returned as all zeros
        rather than NaN.
    """
    # Create the stats on img's device so CUDA inputs don't hit a device mismatch.
    MEAN = torch.as_tensor(mean, device=img.device)
    STD = torch.as_tensor(std, device=img.device)
    img = img * STD[:, None, None] + MEAN[:, None, None]
    i_min = img.min().item()
    i_max = img.max().item()
    value_range = i_max - i_min
    if value_range == 0:
        # Min-max scaling is undefined for a constant image; avoid 0/0 -> NaN.
        return torch.zeros_like(img)
    return (img - i_min) / value_range
# Module-level history buffers used to plot accuracy and loss graphs.
# NOTE(review): this is shared mutable state; nothing visible here clears it,
# so values accumulate for the lifetime of the process.
train_losses, test_losses = [], []
train_acc, test_acc = [], []
# Per-sample test results, appended to by add_predictions().
test_incorrect_pred = dict(images=[], ground_truths=[], predicted_vals=[])
test_correct_pred = dict(images=[], ground_truths=[], predicted_vals=[])
def get_correct_pred_count(pPrediction, pLabels):
    """Count how many rows of ``pPrediction`` pick the class given in ``pLabels``.

    The predicted class of each row is the argmax over dim 1. The return value
    is a plain Python number.
    """
    predicted_classes = pPrediction.argmax(dim=1)
    hits = predicted_classes == pLabels
    return hits.sum().item()
def add_predictions(data, pred, target):
    """Sort one batch's samples into the module-level correct/incorrect buckets.

    A sample's predicted class is ``pred.argmax(dim=1)[idx]``. If it equals
    ``target[idx]``, the sample goes to ``test_correct_pred``; otherwise it
    goes to ``test_incorrect_pred``. In both cases ``data[idx]``,
    ``target[idx]`` and the predicted class are appended to the ``'images'``,
    ``'ground_truths'`` and ``'predicted_vals'`` lists respectively.

    Args:
        data: batch of inputs, indexed along the first axis.
        pred: model outputs with class scores along dim 1.
        target: ground-truth class indices, one per sample.
    """
    # Compute argmax once per batch rather than twice per appended sample.
    predicted = pred.argmax(dim=1)
    # One host transfer for the whole batch instead of a .item() per element.
    matches = predicted.eq(target).tolist()
    for idx, is_correct in enumerate(matches):
        bucket = test_correct_pred if is_correct else test_incorrect_pred
        bucket['images'].append(data[idx])
        bucket['ground_truths'].append(target[idx])
        bucket['predicted_vals'].append(predicted[idx])