Jfink09 commited on
Commit
6584f0f
1 Parent(s): 5beb3d6

Upload 7 files

Browse files
app.py CHANGED
@@ -25,6 +25,7 @@ class_names = ['CRVO',
25
  resnet50, resnet50_transforms = create_resnet50_model(
26
  num_classes=len(class_names), # actual value would also work
27
  )
 
28
 
29
  # Load saved weights
30
  resnet50.load_state_dict(
@@ -34,6 +35,7 @@ resnet50.load_state_dict(
34
  )
35
  )
36
 
 
37
  ### 3. Predict function ###
38
 
39
  # Create predict function
@@ -64,9 +66,9 @@ def predict(img) -> Tuple[Dict, float]:
64
  ### 4. Gradio app ###
65
 
66
  # Create title, description and article strings
67
- #title = "Retinal Disease Detection"
68
- #description = "A ResNet50 feature extractor computer vision model to classify funduscopic images."
69
- #article = "Created with the help from [09. PyTorch Model Deployment](https://www.learnpytorch.io/09_pytorch_model_deployment/)."
70
 
71
  # Create examples list from "examples/" directory
72
  example_list = [["examples/" + example] for example in os.listdir("examples")]
@@ -77,12 +79,10 @@ demo = gr.Interface(fn=predict, # mapping function from input to output
77
  outputs=[gr.Label(num_top_classes=3, label="Predictions"), # what are the outputs?
78
  gr.Number(label="Prediction time (s)")], # our fn has two outputs, therefore we have two outputs
79
  # Create examples list from "examples/" directory
80
- #title=title,
81
- examples=example_list,
82
- css="""
83
- .gradio-container {background-color: #0B0F19}
84
- .mx-auto {background-color: #0B0F19}
85
- """)
86
 
87
  # Launch the demo!
88
- demo.launch()
 
25
  resnet50, resnet50_transforms = create_resnet50_model(
26
  num_classes=len(class_names), # actual value would also work
27
  )
28
+ resnet50.fc = nn.Linear(2048, 10)
29
 
30
  # Load saved weights
31
  resnet50.load_state_dict(
 
35
  )
36
  )
37
 
38
+
39
  ### 3. Predict function ###
40
 
41
  # Create predict function
 
66
  ### 4. Gradio app ###
67
 
68
  # Create title, description and article strings
69
+ title = "DeepFundus 👀"
70
+ description = "A ResNet50 feature extractor computer vision model to classify funduscopic images."
71
+ article = "Created with the help from [09. PyTorch Model Deployment](https://www.learnpytorch.io/09_pytorch_model_deployment/)."
72
 
73
  # Create examples list from "examples/" directory
74
  example_list = [["examples/" + example] for example in os.listdir("examples")]
 
79
  outputs=[gr.Label(num_top_classes=3, label="Predictions"), # what are the outputs?
80
  gr.Number(label="Prediction time (s)")], # our fn has two outputs, therefore we have two outputs
81
  # Create examples list from "examples/" directory
82
+ examples=example_list,
83
+ title=title,
84
+ description=description,
85
+ article=article)
 
 
86
 
87
  # Launch the demo!
88
+ demo.launch()
model.py CHANGED
@@ -20,6 +20,7 @@ def create_resnet50_model(num_classes:int=10, # 4
20
  weights = torchvision.models.ResNet50_Weights.DEFAULT
21
  transforms = weights.transforms()
22
  model = torchvision.models.resnet50(weights=weights)
 
23
 
24
  # 4. Freeze all layers in base model
25
  for param in model.parameters():
 
20
  weights = torchvision.models.ResNet50_Weights.DEFAULT
21
  transforms = weights.transforms()
22
  model = torchvision.models.resnet50(weights=weights)
23
+ model.fc = nn.Linear(2048, 10)
24
 
25
  # 4. Freeze all layers in base model
26
  for param in model.parameters():
pretrained_resnet50_feature_extractor_drappcompressed.pth CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:e7c57c711600cdb53a356c4c8de7b53c2e3dd12e8a44b8080d60fca2b412ba80
3
- size 102643061
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d945c30b1571afcba5b07e2bca41a6c744730d5d655ec8735dfab39dac5622c9
3
+ size 94528949