jcarnero commited on
Commit
3f9971c
1 Parent(s): 0f43701

Updated deployment

Browse files
Files changed (3) hide show
  1. README.md +2 -2
  2. app.py +10 -2
  3. vit_saved.pth +1 -1
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
  title: Birds classifier
3
  emoji: 🐦
4
- colorFrom: red
5
- colorTo: pink
6
  sdk: gradio
7
  sdk_version: 3.20.1
8
  app_file: app.py
 
1
  ---
2
  title: Birds classifier
3
  emoji: 🐦
4
+ colorFrom: white
5
+ colorTo: blue
6
  sdk: gradio
7
  sdk_version: 3.20.1
8
  app_file: app.py
app.py CHANGED
@@ -14,6 +14,15 @@ to_tensor = ToTensor()
14
  norm = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
15
 
16
 
 
 
 
 
 
 
 
 
 
17
  def classify_image(inp):
18
  inp = Image.fromarray(inp)
19
  transformed_input = resized_crop_pad(inp, (460, 460))
@@ -23,8 +32,7 @@ def classify_image(inp):
23
  model.eval()
24
  with torch.no_grad():
25
  pred = model(transformed_input)
26
- pred = torch.argmax(pred, dim=1)
27
- return vocab[pred]
28
 
29
 
30
  iface = gr.Interface(
 
14
  norm = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
15
 
16
 
17
+ def decode_pred(pred: torch.Tensor) -> str:
18
+ indices = pred > 0.95
19
+ if indices.any():
20
+ # return first match
21
+ return vocab[indices.nonzero()[0]]
22
+ else:
23
+ return "I don't know what this is, ¡páharo!"
24
+
25
+
26
  def classify_image(inp):
27
  inp = Image.fromarray(inp)
28
  transformed_input = resized_crop_pad(inp, (460, 460))
 
32
  model.eval()
33
  with torch.no_grad():
34
  pred = model(transformed_input)
35
+ return decode_pred(torch.sigmoid(pred).squeeze(dim=0))
 
36
 
37
 
38
  iface = gr.Interface(
vit_saved.pth CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:09e9c264c4b31db320125563076207358c54b1d460470f38f2ee2c6c196eafe1
3
  size 22974149
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e8169a84843ea366cae9295d3a6b7870884fdca2169fe7b419d9552990a26942
3
  size 22974149