ipd commited on
Commit
36a45f4
·
verified ·
1 Parent(s): 626f2e2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -10
app.py CHANGED
@@ -7,12 +7,10 @@ from PIL import Image
7
 
8
  from torch.utils.mobile_optimizer import optimize_for_mobile
9
 
10
- model = timm.create_model('vit_base_patch16_224', pretrained=True)
11
- model.head = torch.nn.Linear(in_features=model.head.in_features, out_features=5)
12
-
13
- #path = "opt_model.pt"
14
-
15
- #model = model.jit.load(path)
16
 
17
  model.eval()
18
 
@@ -21,9 +19,9 @@ def transform_image(img_sample):
21
  transforms.Resize((224, 224)), # Resize to 224x224
22
  transforms.ToTensor(), # Convert PIL image to tensor
23
  transforms.ColorJitter(contrast=0.5), # Contrast
24
- transforms.RandomAdjustSharpness(sharpness_factor=0.5),
25
- transforms.RandomSolarize(threshold=0.75),
26
- transforms.RandomAutocontrast(p=1),
27
  ])
28
  transformed_img = transform(img_sample)
29
  return transformed_img
@@ -40,7 +38,7 @@ def predict(Image):
40
 
41
  with torch.no_grad():
42
  grade = torch.softmax(model(img.float()), dim=1)[0]
43
- category = ["Normal", "Mild", "Moderate", "Severe", "Proliferative"]
44
  output_dict = {}
45
  for cat, value in zip(category, grade):
46
  output_dict[cat] = value.item()
 
7
 
8
  from torch.utils.mobile_optimizer import optimize_for_mobile
9
 
10
+ model = timm.create_model('resnet50', pretrained=True)
11
+ model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=5)
12
+ path = "epoch_4_Resnet50-0.5contrast.pth"
13
+ model.load_state_dict(torch.load(path))
 
 
14
 
15
  model.eval()
16
 
 
19
  transforms.Resize((224, 224)), # Resize to 224x224
20
  transforms.ToTensor(), # Convert PIL image to tensor
21
  transforms.ColorJitter(contrast=0.5), # Contrast
22
+ #transforms.RandomAdjustSharpness(sharpness_factor=0.5),
23
+ #transforms.RandomSolarize(threshold=0.75),
24
+ #transforms.RandomAutocontrast(p=1),
25
  ])
26
  transformed_img = transform(img_sample)
27
  return transformed_img
 
38
 
39
  with torch.no_grad():
40
  grade = torch.softmax(model(img.float()), dim=1)[0]
41
+ category = ["0 - Normal", "1 - Mild", "2 - Moderate", "3 - Severe", "4 - Proliferative"]
42
  output_dict = {}
43
  for cat, value in zip(category, grade):
44
  output_dict[cat] = value.item()