# File size: 1,009 Bytes
# 14e27af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import os

import torch
import cv2
import numpy as np
from huggingface_hub import hf_hub_download
from nail_detection.main import get_nails

from DummyModel import DummyModel


def load_model(DEBUG):
    """Build the nail classifier, loading trained weights unless debugging.

    Args:
        DEBUG: when truthy, return a freshly constructed ``DummyModel`` without
            downloading any weights, so no access token is required.

    Returns:
        The ``DummyModel`` instance, switched to evaluation mode.

    Raises:
        KeyError: if ``DEBUG`` is falsy and the ``DeepNAPSIModel`` environment
            variable (used as the Hugging Face access token) is not set.
    """
    model = DummyModel()
    if not DEBUG:
        # An auth token is passed, presumably because the model repo is private.
        file_path = hf_hub_download("lfolle/DeepNAPSIModel", "dummy_model.pth",
                                    use_auth_token=os.environ['DeepNAPSIModel'])
        # map_location="cpu" lets a checkpoint saved from GPU load on CPU-only
        # hosts; load_state_dict then copies the values into the model's own
        # parameters, so the model's device is unaffected.
        model.load_state_dict(torch.load(file_path, map_location="cpu"))
    # The model is only used for inference: put any dropout/batch-norm layers
    # into evaluation mode so predictions are deterministic.
    model.eval()
    return model


class Infer():
    """Detect nails in an image and classify each one with the loaded model."""

    def __init__(self, DEBUG):
        """Load the classifier via ``load_model``.

        Args:
            DEBUG: forwarded to ``load_model``; when truthy, no trained weights
                are downloaded.
        """
        self.model = load_model(DEBUG)

    def predict(self, data):
        """Return an interleaved list of nail crops and predicted class indices.

        Args:
            data: an RGB image; it is converted to BGR (OpenCV channel order)
                before being passed to ``get_nails``.

        Returns:
            A flat list ``[nail_0, class_0, nail_1, class_1, ...]``. Each
            ``class_i`` is the argmax index of the model output as an ``int``.
            If ``get_nails`` returns ``None``, the list holds five placeholder
            pairs ``(np.zeros((64, 64, 3)), -1)`` instead.
        """
        nails = get_nails(cv2.cvtColor(data, cv2.COLOR_RGB2BGR))
        predictions = []
        if nails is None:
            # NOTE(review): presumably five pairs so that a consumer expecting a
            # fixed number of outputs still receives one — confirm with caller.
            for _ in range(5):
                predictions.append(np.zeros((64, 64, 3)))
                predictions.append(-1)
        else:
            # Pure inference: disable autograd so no computation graph is built
            # and intermediate activations are freed immediately.
            with torch.no_grad():
                for nail in nails:
                    predictions.append(nail)
                    # NOTE(review): assumes the model accepts whatever get_nails
                    # yields per nail (e.g. an array/tensor crop) — confirm.
                    predictions.append(int(torch.argmax(self.model(nail))))
        return predictions