|
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
|
|
|
|
|
def pred_label(model, img):
    """Predict the class index for a single (unbatched) input.

    The model is moved to the CPU (as before) and ``img`` is moved there as
    well so both live on the same device. Inference runs in eval mode and
    under ``torch.no_grad()``; the model's previous train/eval mode is
    restored before returning.

    Args:
        model: a ``torch.nn.Module`` producing logits whose class dimension
            is dim 1 (assumed (batch, classes) — TODO confirm with callers).
        img: one input tensor without a batch dimension; a batch dim of 1 is
            added here.

    Returns:
        A 1-element integer tensor holding the predicted class index
        (use ``.item()`` for a Python int).
    """
    model = model.to('cpu')
    was_training = model.training
    # Dropout/BatchNorm behave differently in train mode and would make the
    # prediction stochastic or batch-dependent.
    model.eval()
    try:
        with torch.no_grad():
            logits = model(img.to('cpu').unsqueeze(0))
    finally:
        # Leave the caller's model in the mode it was handed to us in.
        model.train(was_training)
    # Softmax is monotonic, so taking argmax over raw logits selects the
    # same class without the extra normalization pass.
    return logits.argmax(dim=1)
|
|
|
def save_image(img, title, count, out_dir='./viz'):
    """Render ``img`` with no axes decoration and save it as ``<out_dir>/img<count>``.

    Args:
        img: anything ``Axes.imshow`` accepts (presumably an (H, W, C)
            tensor or array — TODO confirm channel layout with callers).
        title: text drawn above the image.
        count: number appended to the file name (``img1``, ``img2``, ...).
        out_dir: destination directory, created if missing. Defaults to the
            previously hard-coded ``./viz``.

    The file name has no extension, so matplotlib appends the one for
    ``rcParams["savefig.format"]`` (PNG by default).
    """
    os.makedirs(out_dir, exist_ok=True)
    fig, ax = plt.subplots()
    try:
        ax.imshow(img, interpolation='bicubic')
        # Hide the frame and ticks so only the image and title remain.
        for side in ('top', 'left', 'bottom', 'right'):
            ax.spines[side].set_visible(False)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(title)
        fig.savefig(os.path.join(out_dir, 'img' + str(count)))
    finally:
        # Without closing, every call leaks a figure (matplotlib warns once
        # more than 20 are open).
        plt.close(fig)
|
|
|
|
|
def _to_scalar(value):
    """Return a Python scalar for a single-element tensor, else ``value`` as-is."""
    if isinstance(value, torch.Tensor) and value.numel() == 1:
        return value.item()
    return value


def observations(model, testloader, num_images=5):
    """Save "prediction vs. label" images for the first batch of ``testloader``.

    Takes only the first ``(imgs, labels)`` batch, predicts each of its first
    ``num_images`` items with ``pred_label`` and writes them with
    ``save_image`` using the title ``"Pred: <p>, Actual: <a>"``. Files are
    numbered from 1. An empty loader is a no-op.

    Args:
        model: classifier forwarded to ``pred_label``.
        testloader: iterable yielding ``(imgs, labels)`` batches; each image
            is assumed to be a (C, H, W) tensor — TODO confirm.
        num_images: maximum number of images to save (default 5, the
            previously hard-coded count). Batches shorter than this are
            handled without raising.
    """
    first_batch = next(iter(testloader), None)
    if first_batch is None:
        return
    imgs, labels = first_batch

    # Slicing clamps to the batch length, so short batches don't IndexError.
    pairs = zip(imgs[:num_images], labels[:num_images])
    for count, (img, label) in enumerate(pairs, start=1):
        # Detach and move to CPU so both the model call and matplotlib's
        # array conversion work for GPU tensors.
        img_cpu = img.detach().cpu()
        pred = pred_label(model, img_cpu).item()
        # Unwrap tensor labels so the title reads "3", not "tensor(3)".
        title = "Pred: {}, Actual: {}".format(pred, _to_scalar(label))
        # (C, H, W) -> (H, W, C), the layout imshow expects.
        save_image(img_cpu.permute(1, 2, 0), title, count)