|
|
|
|
|
|
|
|
|
import torch |
|
|
|
def compute_batch_accuracy(pred, label):
    """Count correct predictions in a single batch.

    Args:
        pred: Tensor of predicted class indices.
        label: Tensor of ground-truth class indices. Expected to have the
            same shape as ``pred``; broadcasting between mismatched shapes
            is not checked here.

    Returns:
        A ``(correct, total)`` tuple. ``correct`` is a 0-dim tensor holding
        the number of positions where ``pred == label``. ``total`` is the
        number of elements in ``label``, as a Python ``int``.
    """
    correct = (pred == label).sum()
    # numel() rather than size(0): the comparison above counts every element,
    # so the denominator must count every element too. For 1-D labels the two
    # agree. For multi-dimensional labels (e.g. per-pixel targets), size(0)
    # would count only samples and produce an "accuracy" above 1.
    total = label.numel()
    return correct, total
|
|
|
def compute_set_accuracy(model, test_loader, device=None):
    """Return the top-1 accuracy of ``model`` over every batch in ``test_loader``.

    When ``model`` is a ``torch.nn.Module``, it is switched to evaluation mode
    for the duration of the pass, so dropout and batch-norm use inference
    behaviour. Each submodule's previous mode is restored afterwards, even if
    an exception is raised.

    Args:
        model: Callable mapping an input batch to scores whose class
            dimension is ``dim=1``. Predictions are ``argmax`` over that
            dimension.
        test_loader: Iterable yielding ``(inputs, labels)`` pairs of tensors.
        device: Device to move each batch to. When ``None``, the device of
            the model's first parameter is used. If the model exposes no
            parameters, CUDA is used when available, otherwise CPU.

    Returns:
        A 0-dim tensor holding ``correct / total``. If the loader yields no
        samples, a 0-dim tensor holding ``0.0`` is returned instead of raising
        ``ZeroDivisionError``.
    """
    is_module = isinstance(model, torch.nn.Module)

    if device is None:
        # Run where the model lives. Hard-coding "cuda if available" breaks a
        # CPU-resident model on a GPU machine, and any model on a non-default
        # GPU. NOTE(review): a model split across devices uses its first
        # parameter's device here; confirm that suits such models.
        first_param = next(model.parameters(), None) if is_module else None
        if first_param is not None:
            device = first_param.device
        else:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Remember each submodule's own mode so mixed configurations (e.g. layers
    # deliberately kept in eval mode while the parent trains) are restored
    # exactly, rather than forcing one flag onto the whole tree.
    saved_modes = [(m, m.training) for m in model.modules()] if is_module else []
    if is_module:
        model.eval()

    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = model(inputs)

                correct_batch, total_batch = compute_batch_accuracy(
                    torch.argmax(outputs, dim=1), labels
                )
                correct += correct_batch
                total += total_batch
    finally:
        for module, was_training in saved_modes:
            module.training = was_training

    if total == 0:
        # Empty loader: accuracy is undefined; report 0.0 instead of crashing.
        return torch.tensor(0.0)
    return correct / total