Spaces:
Running
Running
Upload 7 files
Browse files- app.py +10 -10
- model.py +1 -0
- pretrained_resnet50_feature_extractor_drappcompressed.pth +2 -2
app.py
CHANGED
@@ -25,6 +25,7 @@ class_names = ['CRVO',
|
|
25 |
resnet50, resnet50_transforms = create_resnet50_model(
|
26 |
num_classes=len(class_names), # actual value would also work
|
27 |
)
|
|
|
28 |
|
29 |
# Load saved weights
|
30 |
resnet50.load_state_dict(
|
@@ -34,6 +35,7 @@ resnet50.load_state_dict(
|
|
34 |
)
|
35 |
)
|
36 |
|
|
|
37 |
### 3. Predict function ###
|
38 |
|
39 |
# Create predict function
|
@@ -64,9 +66,9 @@ def predict(img) -> Tuple[Dict, float]:
|
|
64 |
### 4. Gradio app ###
|
65 |
|
66 |
# Create title, description and article strings
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
|
71 |
# Create examples list from "examples/" directory
|
72 |
example_list = [["examples/" + example] for example in os.listdir("examples")]
|
@@ -77,12 +79,10 @@ demo = gr.Interface(fn=predict, # mapping function from input to output
|
|
77 |
outputs=[gr.Label(num_top_classes=3, label="Predictions"), # what are the outputs?
|
78 |
gr.Number(label="Prediction time (s)")], # our fn has two outputs, therefore we have two outputs
|
79 |
# Create examples list from "examples/" directory
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
.mx-auto {background-color: #0B0F19}
|
85 |
-
""")
|
86 |
|
87 |
# Launch the demo!
|
88 |
-
demo.launch()
|
|
|
25 |
resnet50, resnet50_transforms = create_resnet50_model(
|
26 |
num_classes=len(class_names), # actual value would also work
|
27 |
)
|
28 |
+
resnet50.fc = nn.Linear(2048, 10)
|
29 |
|
30 |
# Load saved weights
|
31 |
resnet50.load_state_dict(
|
|
|
35 |
)
|
36 |
)
|
37 |
|
38 |
+
|
39 |
### 3. Predict function ###
|
40 |
|
41 |
# Create predict function
|
|
|
66 |
### 4. Gradio app ###
|
67 |
|
68 |
# Create title, description and article strings
|
69 |
+
title = "DeepFundus 👀"
|
70 |
+
description = "A ResNet50 feature extractor computer vision model to classify funduscopic images."
|
71 |
+
article = "Created with the help from [09. PyTorch Model Deployment](https://www.learnpytorch.io/09_pytorch_model_deployment/)."
|
72 |
|
73 |
# Create examples list from "examples/" directory
|
74 |
example_list = [["examples/" + example] for example in os.listdir("examples")]
|
|
|
79 |
outputs=[gr.Label(num_top_classes=3, label="Predictions"), # what are the outputs?
|
80 |
gr.Number(label="Prediction time (s)")], # our fn has two outputs, therefore we have two outputs
|
81 |
# Create examples list from "examples/" directory
|
82 |
+
examples=example_list,
|
83 |
+
title=title,
|
84 |
+
description=description,
|
85 |
+
article=article)
|
|
|
|
|
86 |
|
87 |
# Launch the demo!
|
88 |
+
demo.launch()
|
model.py
CHANGED
@@ -20,6 +20,7 @@ def create_resnet50_model(num_classes:int=10, # 4
|
|
20 |
weights = torchvision.models.ResNet50_Weights.DEFAULT
|
21 |
transforms = weights.transforms()
|
22 |
model = torchvision.models.resnet50(weights=weights)
|
|
|
23 |
|
24 |
# 4. Freeze all layers in base model
|
25 |
for param in model.parameters():
|
|
|
20 |
weights = torchvision.models.ResNet50_Weights.DEFAULT
|
21 |
transforms = weights.transforms()
|
22 |
model = torchvision.models.resnet50(weights=weights)
|
23 |
+
model.fc = nn.Linear(2048, 10)
|
24 |
|
25 |
# 4. Freeze all layers in base model
|
26 |
for param in model.parameters():
|
pretrained_resnet50_feature_extractor_drappcompressed.pth
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d945c30b1571afcba5b07e2bca41a6c744730d5d655ec8735dfab39dac5622c9
|
3 |
+
size 94528949
|