lfolle commited on
Commit
761b08f
1 Parent(s): dd159eb

Added napsi sum, small refactoring.

Browse files
Files changed (4) hide show
  1. DummyModel.py +13 -4
  2. Model.py +13 -0
  3. app.py +1 -0
  4. backend.py +8 -13
DummyModel.py CHANGED
@@ -2,12 +2,21 @@ import torch
2
  import torch.nn
3
 
4
 
 
 
 
 
 
 
 
 
 
5
  class DummyModel(torch.nn.Module):
6
  def __init__(self):
7
  super().__init__()
8
 
9
- def forward(self, x):
10
- return torch.softmax(torch.rand(5), 0)
11
 
12
- def __call__(self, x):
13
- return self.forward(x)
 
2
  import torch.nn
3
 
4
 
5
+ def load_dummy_model(DEBUG):
6
+ model = DummyModel()
7
+ if not DEBUG:
8
+ file_path = hf_hub_download("lfolle/DeepNAPSIModel", "dummy_model.pth",
9
+ use_auth_token=os.environ['DeepNAPSIModel'])
10
+ model.load_state_dict(torch.load(file_path))
11
+ return model
12
+
13
+
14
  class DummyModel(torch.nn.Module):
15
  def __init__(self):
16
  super().__init__()
17
 
18
+ def forward(self, x:list):
19
+ return torch.softmax(torch.rand(len(x), 5), 1)
20
 
21
+ def __call__(self, x:list):
22
+ return self.forward(x)
Model.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from nail_classification.inference import Inference
2
+
3
+
4
+ class Model:
5
+ def __init__(self):
6
+ file_paths = [hf_hub_download("lfolle/DeepNAPSIModel", f"version_{v}.ckpt",
7
+ use_auth_token=os.environ['DeepNAPSIModel']) for v in [10, 11, 12, 13, 14]]
8
+ self.inference = Inference(file_paths)
9
+
10
+
11
+ def predict(self, x):
12
+ y_hat, uncertainty = self.inference.predict(x)
13
+ return y_hat, uncertainty
app.py CHANGED
@@ -15,6 +15,7 @@ with gr.Blocks(analytics_enabled=False, title="DeepNAPSI Prediction") as demo:
15
  image_button = gr.Button("Predict NAPSI")
16
  outputs = []
17
  with gr.Row():
 
18
  with gr.Column():
19
  outputs.append(gr.Image())
20
  outputs.append(gr.Number(label="DeepNAPSI Thumb"))
 
15
  image_button = gr.Button("Predict NAPSI")
16
  outputs = []
17
  with gr.Row():
18
+ outputs.append(gr.Number(label="DeepNAPSI Sum"))
19
  with gr.Column():
20
  outputs.append(gr.Image())
21
  outputs.append(gr.Number(label="DeepNAPSI Thumb"))
backend.py CHANGED
@@ -4,31 +4,26 @@ import numpy as np
4
  from huggingface_hub import hf_hub_download
5
  from nail_detection.main import get_nails
6
 
7
- from DummyModel import DummyModel
8
-
9
-
10
- def load_model(DEBUG):
11
- model = DummyModel()
12
- if not DEBUG:
13
- file_path = hf_hub_download("lfolle/DeepNAPSIModel", "dummy_model.pth",
14
- use_auth_token=os.environ['DeepNAPSIModel'])
15
- model.load_state_dict(torch.load(file_path))
16
- return model
17
 
18
 
19
  class Infer():
20
  def __init__(self, DEBUG):
21
- self.model = load_model(DEBUG)
22
 
23
  def predict(self, data):
24
  nails = get_nails(cv2.cvtColor(data, cv2.COLOR_RGB2BGR))
25
  predictions = []
26
  if nails is None:
 
27
  for _ in range(5):
28
  predictions.append(np.zeros((64, 64, 3)))
29
  predictions.append(-1)
30
  else:
31
- for nail in nails:
 
 
 
32
  predictions.append(nail)
33
- predictions.append(int(torch.argmax(self.model(nail))))
34
  return predictions
 
4
  from huggingface_hub import hf_hub_download
5
  from nail_detection.main import get_nails
6
 
7
+ from DummyModel import load_dummy_model
 
 
 
 
 
 
 
 
 
8
 
9
 
10
  class Infer():
11
  def __init__(self, DEBUG):
12
+ self.model = load_dummy_model(DEBUG)
13
 
14
  def predict(self, data):
15
  nails = get_nails(cv2.cvtColor(data, cv2.COLOR_RGB2BGR))
16
  predictions = []
17
  if nails is None:
18
+ predictions.append(-1)
19
  for _ in range(5):
20
  predictions.append(np.zeros((64, 64, 3)))
21
  predictions.append(-1)
22
  else:
23
+ napsi_predictions = torch.argmax(self.model(nails), 1)
24
+ napsi_sum = int(napsi_predictions.sum().detach().cpu())
25
+ predictions.append(napsi_sum)
26
+ for napsi_prediction, nail in zip(napsi_predictions, nails):
27
  predictions.append(nail)
28
+ predictions.append(napsi_prediction)
29
  return predictions