Spaces:
Sleeping
Sleeping
import os | |
import torch | |
def get_dataset_labels():
    """Return a fresh list of the ten class names, ordered by class index.

    The names and their order match the CIFAR-10 classes; element ``i`` is
    the label for class index ``i``. A new list is built on every call, so
    callers may mutate the result freely.
    """
    labels = [
        'airplane', 'automobile', 'bird', 'cat', 'deer',
        'dog', 'frog', 'horse', 'ship', 'truck',
    ]
    return labels
def get_data_label_name(idx):
    """Map a class index to its label name.

    Args:
        idx: class index into ``get_dataset_labels()``.

    Returns:
        The label string, or ``''`` for a negative index (used as the
        "no class" sentinel, mirroring ``get_data_idx_from_name``'s -1).
        An index past the end of the label list raises ``IndexError``.
    """
    return '' if idx < 0 else get_dataset_labels()[idx]
def get_data_idx_from_name(name):
    """Return the class index for a label name, matched case-insensitively.

    Args:
        name: label name such as ``'Cat'``; may be empty or ``None``.

    Returns:
        The index of ``name.lower()`` in ``get_dataset_labels()``, or -1 when
        ``name`` is empty/``None`` or not a known label.
    """
    if not name:
        return -1
    # Index into the *list* returned by the call; the label list is fetched
    # and the name lowered only once.
    labels = get_dataset_labels()
    try:
        return labels.index(name.lower())
    except ValueError:
        # Unknown label name.
        return -1
def load_model_from_checkpoint(device, file_name='ckpt.pth'):
    """Load a checkpoint file with ``torch.load`` and return its contents.

    Despite the name, this returns the deserialized checkpoint object (whatever
    was passed to ``torch.save``), not a constructed model.

    Args:
        device: passed to ``torch.load`` as ``map_location`` so tensors are
            placed on this device (e.g. ``'cpu'`` or ``torch.device('cuda')``).
        file_name: path of the checkpoint to read. The default is
            ``'ckpt.pth'`` because that is the file this helper has always
            loaded; callers that omit the argument keep the same behavior.

    Returns:
        The object stored in the checkpoint file.
    """
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints from trusted sources.
    checkpoint = torch.load(file_name, map_location=device)
    return checkpoint
def denormalize(img, mean, std):
    """Undo per-channel normalization, then min-max rescale to [0, 1].

    Args:
        img: normalized image tensor whose channel axis is third from the end
            (e.g. ``(C, H, W)`` or ``(N, C, H, W)``), given the
            ``[:, None, None]`` broadcasting below. Assumes
            ``len(mean) == len(std) == C``.
        mean: per-channel means that were subtracted during normalization.
        std: per-channel standard deviations that were divided out.

    Returns:
        Tensor with the same shape as ``img``. It is rescaled with the global
        min/max of the whole tensor (not per channel). A constant image yields
        all zeros instead of NaN.
    """
    # Build the statistics on img's device so CUDA inputs don't hit a
    # CPU/GPU device mismatch.
    mean_t = torch.as_tensor(mean, device=img.device)
    std_t = torch.as_tensor(std, device=img.device)
    img = img * std_t[:, None, None] + mean_t[:, None, None]
    i_min = img.min().item()
    i_max = img.max().item()
    if i_max == i_min:
        # Zero dynamic range: avoid 0/0, which would produce NaNs.
        return torch.zeros_like(img)
    img_bar = (img - i_min) / (i_max - i_min)
    return img_bar
# Module-level history kept for plotting the accuracy and loss curves.
# Each name is bound to its own independent, initially empty list.
train_losses, test_losses = [], []
train_acc, test_acc = [], []
# Per-sample test results appended by add_predictions(), split by whether the
# prediction matched the ground truth. Each key gets its own empty list.
test_incorrect_pred = {key: [] for key in ('images', 'ground_truths', 'predicted_vals')}
test_correct_pred = {key: [] for key in ('images', 'ground_truths', 'predicted_vals')}
def get_correct_pred_count(pPrediction, pLabels):
    """Count samples whose highest-scoring class equals the label.

    Args:
        pPrediction: score tensor; the predicted class is the argmax over dim 1.
        pLabels: ground-truth class indices, compared element-wise.

    Returns:
        The number of matches as a Python int.
    """
    predicted_classes = pPrediction.argmax(dim=1)
    matches = predicted_classes == pLabels
    return matches.sum().item()
def add_predictions(data, pred, target):
    """Record each sample of a batch as a correct or incorrect prediction.

    For every sample, appends its input, ground truth and predicted class to
    the module-level ``test_correct_pred`` dict when the prediction matches
    ``target``, otherwise to ``test_incorrect_pred``.

    Args:
        data: batch of inputs, indexed along dim 0.
        pred: model outputs; the predicted class is the argmax over dim 1.
        target: ground-truth class indices, one per sample.
    """
    # Compute the argmax once per batch instead of re-running it for every
    # sample inside the loop.
    predicted = pred.argmax(dim=1)
    for idx, is_correct in enumerate(predicted.eq(target)):
        bucket = test_correct_pred if is_correct.item() else test_incorrect_pred
        bucket['images'].append(data[idx])
        bucket['ground_truths'].append(target[idx])
        bucket['predicted_vals'].append(predicted[idx])