ianpan commited on
Commit
9f889ce
Β·
1 Parent(s): 7ce2ffd
app.py CHANGED
@@ -22,19 +22,23 @@ class Net2D(nn.Module):
22
  return x[:, 0] if x.size(1) == 1 else x
23
 
24
 
 
 
 
 
 
 
 
25
  weights = torch.load("model0.ckpt", map_location=torch.device("cpu"))["state_dict"]
26
- weights = {k.replace("model.", "") : v for k, v in weights.items()}
27
- model = Net2D(weights)
28
 
29
 
30
  def predict(Image):
31
  img = torch.from_numpy(Image)
32
- img = img[:, :, [2, 1, 0]]
33
  img = img.permute(2, 0, 1)
34
  img = img.unsqueeze(0)
35
- img = img / img.max()
36
- img = img - 0.5
37
- img = img * 2.0
38
  with torch.no_grad():
39
  grade = torch.softmax(model(img.float()), dim=1)[0]
40
  cats = ["None", "Mild", "Moderate", "Severe", "Proliferative"]
@@ -44,6 +48,7 @@ def predict(Image):
44
  return output_dict
45
 
46
 
 
47
  image = gr.Image(shape=(512, 512), image_mode="RGB")
48
  label = gr.Label(label="Grade")
49
 
@@ -51,12 +56,11 @@ demo = gr.Interface(
51
  fn=predict,
52
  inputs=image,
53
  outputs=label,
54
- examples=["examples/none.png", "examples/mild.png", "examples/moderate.png", "examples/severe.png",
55
- "examples/proliferative.png"]
56
  )
57
 
58
 
59
  if __name__ == "__main__":
60
- demo.launch(debug=True)
61
 
62
 
 
22
  return x[:, 0] if x.size(1) == 1 else x
23
 
24
 
25
+ def rescale(x):
26
+ x = x / 255.0
27
+ x = x - 0.5
28
+ x = x * 2.0
29
+ return x
30
+
31
+
32
  weights = torch.load("model0.ckpt", map_location=torch.device("cpu"))["state_dict"]
33
+ weights = {k.replace("model.", ""): v for k, v in weights.items()}
34
+ model = Net2D(weights).eval()
35
 
36
 
37
  def predict(Image):
38
  img = torch.from_numpy(Image)
 
39
  img = img.permute(2, 0, 1)
40
  img = img.unsqueeze(0)
41
+ img = rescale(img)
 
 
42
  with torch.no_grad():
43
  grade = torch.softmax(model(img.float()), dim=1)[0]
44
  cats = ["None", "Mild", "Moderate", "Severe", "Proliferative"]
 
48
  return output_dict
49
 
50
 
51
+
52
  image = gr.Image(shape=(512, 512), image_mode="RGB")
53
  label = gr.Label(label="Grade")
54
 
 
56
  fn=predict,
57
  inputs=image,
58
  outputs=label,
59
+ examples=["examples/0.png", "examples/1.png", "examples/2.png", "examples/3.png", "examples/4.png"]
 
60
  )
61
 
62
 
63
  if __name__ == "__main__":
64
+ demo.launch(debug=True, share=True)
65
 
66
 
examples/{none.png β†’ 0.png} RENAMED
File without changes
examples/{mild.png β†’ 1.png} RENAMED
File without changes
examples/{moderate.png β†’ 2.png} RENAMED
File without changes
examples/{proliferative.png β†’ 3.png} RENAMED
File without changes
examples/4.png ADDED

Git LFS Details

  • SHA256: 2eda1f75fa2f9b5bcfa096f5e9b90b9a519da0350e6d1a52602c3d7cf2e25abd
  • Pointer size: 132 Bytes
  • Size of remote file: 5.51 MB
examples/severe.png DELETED

Git LFS Details

  • SHA256: 96dfa76f3ba7e3a61855d21a42a9c4dc2228fd1c8e0e1622c6be1ef12e6b6d7e
  • Pointer size: 132 Bytes
  • Size of remote file: 8.67 MB