Object Detection
vision
File size: 1,037 Bytes
780c589
 
df8cf63
780c589
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
###########################################################################
# Computer vision - Embedded person tracking demo software by HyperbeeAI. #
# Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. hello@hyperbee.ai #
###########################################################################
import torch

def compute_batch_accuracy(pred, label):
    correct = (pred == label).sum()
    return correct,label.size(0)

def compute_set_accuracy(model, test_loader):
    device  = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    correct = 0
    total   = 0
    with torch.no_grad():
        for data in test_loader:
            inputs, labels = data
            
            inputs    = inputs.to(device)
            labels    = labels.to(device)
            outputs   = model(inputs)
            
            correct_batch, total_batch = compute_batch_accuracy(torch.argmax(outputs, dim=1), labels)
            correct += correct_batch
            total   += total_batch
            
    return correct/total