StevenLimcorn committed
Commit 44b4267
1 Parent(s): 146b45b
Initial Commit

Files changed:
- app.py +40 -0
- model.pth +3 -0
- requirements.txt +5 -0
app.py
ADDED
@@ -0,0 +1,40 @@
import gradio as gr
import torch
import torch.nn.functional as F
from torchvision import transforms

# Load the full serialized model on CPU. The weights file model.pth is committed
# to this repo via Git LFS, so it is loaded from the Space's working directory
# rather than a Colab drive path.
model = torch.load("model.pth", map_location=torch.device("cpu"))
model.eval()  # inference mode: disable dropout / batch-norm updates

IMG_SIZE = 224
MASK_LABEL = [
    "Mask worn properly.",
    "Mask not worn properly: nose out",
    "Mask not worn properly: chin and nose out",
    "Didn't wear mask.",
]

# Same preprocessing as at training time: resize, convert to tensor,
# normalize with ImageNet statistics.
transforms_test = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
)


def predict_image(image):
    # image is a PIL image from the Gradio input; add a batch dimension.
    transformed_tensor = torch.unsqueeze(transforms_test(image), 0)
    logits = model(transformed_tensor)
    # Softmax over the four classes, flattened to a 1-D numpy array.
    probability = torch.flatten(F.softmax(logits, dim=1)).detach().cpu().numpy()
    # Map each label to its probability and sort from most to least likely,
    # the format expected by gr.Label.
    labels = {label: prob.item() for label, prob in zip(MASK_LABEL, probability)}
    sorted_labels = dict(sorted(labels.items(), key=lambda item: item[1], reverse=True))
    return sorted_labels


title = "ViT Mask Detection"
description = "Gradio demo for ViT-16 Mask Image Classification created by <a href='https://github.com/stevenlimcorn'>Steven Limcorn</a>"
article = "An application made by stevenlimcorn. Notebook access at: <a href='https://github.com/stevenlimcorn/Mask-Classification'>stevenlimcorn/Mask-Classification</a>"

demo = gr.Interface(
    predict_image,
    inputs=gr.Image(label="Input Image", type="pil", source="webcam"),
    outputs=gr.Label(),
    title=title,
    description=description,
    article=article,
)

demo.launch()
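The Space only drives predict_image through the webcam interface. For a quick sanity check of the committed weights outside Gradio, a script along the following lines should work; it is a sketch and not part of the commit. It assumes model.pth has already been pulled via Git LFS, and test.jpg stands in for any local photo (the file name is only an example).

```python
# Minimal local sanity check for the committed model, outside the Gradio UI.
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

MASK_LABEL = [
    "Mask worn properly.",
    "Mask not worn properly: nose out",
    "Mask not worn properly: chin and nose out",
    "Didn't wear mask.",
]

preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

model = torch.load("model.pth", map_location="cpu")
model.eval()

# test.jpg is a placeholder name for any local face photo.
image = Image.open("test.jpg").convert("RGB")
with torch.no_grad():
    probs = F.softmax(model(preprocess(image).unsqueeze(0)), dim=1).flatten()

for label, p in sorted(zip(MASK_LABEL, probs.tolist()), key=lambda x: x[1], reverse=True):
    print(f"{label}: {p:.3f}")
```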
model.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c2cee68dc4777f9133fe97ccca2414e66d20f628fdef4efcef99bfac9408259b
size 343285383
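The three lines above are a Git LFS pointer, not the weights themselves; the actual ~343 MB file is fetched when the repo is cloned with LFS enabled. If it is unclear whether a local model.pth is still the pointer or the real weights, a small check against the oid and size recorded in the pointer can tell them apart. This is only a verification sketch, not part of the commit.

```python
# Check whether a local model.pth matches the Git LFS pointer above.
import hashlib
import os

EXPECTED_OID = "c2cee68dc4777f9133fe97ccca2414e66d20f628fdef4efcef99bfac9408259b"
EXPECTED_SIZE = 343285383

path = "model.pth"
size = os.path.getsize(path)

sha256 = hashlib.sha256()
with open(path, "rb") as f:
    # Hash in 1 MiB chunks to avoid loading the whole file into memory.
    for chunk in iter(lambda: f.read(1 << 20), b""):
        sha256.update(chunk)

print("size matches:", size == EXPECTED_SIZE)
print("sha256 matches:", sha256.hexdigest() == EXPECTED_OID)
```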
requirements.txt
ADDED
@@ -0,0 +1,5 @@
timm==0.4.12
torch==1.10.1
gradio
numpy
torchvision