Jfink09 commited on
Commit
e3e851d
1 Parent(s): 6584f0f

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +1 -2
  2. model.py +0 -1
app.py CHANGED
@@ -16,7 +16,7 @@ class_names = ['CRVO',
16
  'Macular Hole',
17
  'Myelinated Nerve Fiber',
18
  'Normal',
19
- 'Pathological Mypoia',
20
  'Retinitis Pigmentosa']
21
 
22
  ### 2. Model and transforms preparation ###
@@ -25,7 +25,6 @@ class_names = ['CRVO',
25
  resnet50, resnet50_transforms = create_resnet50_model(
26
  num_classes=len(class_names), # actual value would also work
27
  )
28
- resnet50.fc = nn.Linear(2048, 10)
29
 
30
  # Load saved weights
31
  resnet50.load_state_dict(
 
16
  'Macular Hole',
17
  'Myelinated Nerve Fiber',
18
  'Normal',
19
+ 'Pathological Myopia',
20
  'Retinitis Pigmentosa']
21
 
22
  ### 2. Model and transforms preparation ###
 
25
  resnet50, resnet50_transforms = create_resnet50_model(
26
  num_classes=len(class_names), # actual value would also work
27
  )
 
28
 
29
  # Load saved weights
30
  resnet50.load_state_dict(
model.py CHANGED
@@ -20,7 +20,6 @@ def create_resnet50_model(num_classes:int=10, # 4
20
  weights = torchvision.models.ResNet50_Weights.DEFAULT
21
  transforms = weights.transforms()
22
  model = torchvision.models.resnet50(weights=weights)
23
- model.fc = nn.Linear(2048, 10)
24
 
25
  # 4. Freeze all layers in base model
26
  for param in model.parameters():
 
20
  weights = torchvision.models.ResNet50_Weights.DEFAULT
21
  transforms = weights.transforms()
22
  model = torchvision.models.resnet50(weights=weights)
 
23
 
24
  # 4. Freeze all layers in base model
25
  for param in model.parameters():