fhatje commited on
Commit
00f809d
·
1 Parent(s): 7ffa4f6

Minor text changes.

Browse files
Files changed (1) hide show
  1. app.py +66 -29
app.py CHANGED
@@ -1,9 +1,27 @@
1
  # AUTOGENERATED! DO NOT EDIT! File to edit: ../main.ipynb.
2
 
3
  # %% auto 0
4
- __all__ = ['ORGAN', 'IMAGE_SIZE', 'MODEL_NAME', 'THRESHOLD', 'CODES', 'learn', 'title', 'description', 'examples',
5
- 'interpretation', 'demo', 'x_getter', 'y_getter', 'splitter', 'make3D', 'predict', 'infer',
6
- 'remove_small_segs', 'to_oberlay_image']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  # %% ../main.ipynb 1
9
  import numpy as np
@@ -18,77 +36,96 @@ import gradio as gr
18
  ORGAN = "kidney"
19
  IMAGE_SIZE = 512
20
  MODEL_NAME = "unetpp_b4_th60_d9414.pkl"
21
- THRESHOLD = float(MODEL_NAME.split("_")[2][2:]) / 100.
22
- CODES = ["Background", "FTU"] # FTU = functional tissue unit
 
23
 
24
  # %% ../main.ipynb 3
25
- def x_getter(r): return r["fnames"]
26
- def y_getter(r):
 
 
 
27
  rle = r["rle"]
28
  shape = (int(r["img_height"]), int(r["img_width"]))
29
  return rle_decode(rle, shape).T
30
- def splitter(model):
 
 
31
  enc_params = L(model.encoder.parameters())
32
  dec_params = L(model.decoder.parameters())
33
  sg_params = L(model.segmentation_head.parameters())
34
  untrained_params = L([*dec_params, *sg_params])
35
  return L([enc_params, untrained_params])
36
 
 
37
  # %% ../main.ipynb 4
38
  learn = load_learner(MODEL_NAME)
39
 
 
40
  # %% ../main.ipynb 5
41
  def make3D(t: np.array) -> np.array:
42
  t = np.expand_dims(t, axis=2)
43
- t = np.concatenate((t,t,t), axis=2)
44
  return t
45
 
 
46
  def predict(fn, cutoff_area=200):
47
  data = infer(fn)
48
  data = remove_small_segs(data, cutoff_area=cutoff_area)
49
  return to_oberlay_image(data), data["df"]
50
 
 
51
  def infer(fn):
52
  img = PILImage.create(fn)
53
- tf_img,_,_,preds = learn.predict(img, with_input=True)
54
- mask = (F.softmax(preds.float(), dim=0)>THRESHOLD).int()[1]
55
  mask = np.array(mask, dtype=np.uint8)
56
- resized_image = Image.fromarray(tf_img.numpy().transpose(1, 2, 0).astype(np.uint8)).resize(img.shape)
 
 
57
  resized_image = np.array(resized_image)
58
  return {
59
  "tf_image": tf_img.numpy().transpose(1, 2, 0).astype(np.uint8),
60
- "tf_mask": mask
61
  }
62
 
 
63
  def remove_small_segs(data, cutoff_area=250):
64
  labeled_mask = skimage.measure.label(data["tf_mask"])
65
  props = skimage.measure.regionprops(labeled_mask)
66
- df = {"Glomerulus":[], "Area (in px)":[]}
67
  for i, prop in enumerate(props):
68
- if prop.area < cutoff_area:
69
- labeled_mask[labeled_mask==i+1] = 0
70
  continue
71
- df["Glomerulus"].append(len(df["Glomerulus"]) + 1)
72
  df["Area (in px)"].append(prop.area)
73
- labeled_mask[labeled_mask>0] = 1
74
  data["tf_mask"] = labeled_mask.astype(np.uint8)
75
  data["df"] = pd.DataFrame(df)
76
  return data
77
 
 
78
  def to_oberlay_image(data):
79
  img, msk = data["tf_image"], data["tf_mask"]
80
  msk_im = np.zeros_like(img)
81
  # rgb code: 255, 80, 80
82
- msk_im[:,:,0] = 255
83
- msk_im[:,:,1] = 80
84
- msk_im[:,:,2] = 80
85
  img = Image.fromarray(img).convert("RGBA")
86
  msk_im = Image.fromarray(msk_im).convert("RGBA")
87
- msk = Image.fromarray((msk*255*0.5).astype(np.uint8))
88
 
89
- img.paste(msk_im, (0, 0), msk, )
 
 
 
 
90
  return img
91
 
 
92
  # %% ../main.ipynb 6
93
  title = "Glomerulus Segmentation"
94
  description = """
@@ -98,11 +135,11 @@ The model deployed here is a [UNet++](https://arxiv.org/abs/1807.10165) with an
98
 
99
  The provided example images are random subset of kidney slices from the [Human Protein Atlas](https://www.proteinatlas.org/). These have been collected separately from model training and have neither been part of the training, validation nor test set.
100
 
101
- See corresponding [blog post](https://fhatje.github.io/posts/glomseg/train_model.html).
102
  """
103
- #article="<p style='text-align: center'><a href='Blog post URL' target='_blank'>Blog post</a></p>"
104
  examples = [str(p) for p in get_image_files("example_images")]
105
- interpretation='default'
106
 
107
  # %% ../main.ipynb 7
108
  demo = gr.Interface(
@@ -113,10 +150,10 @@ demo = gr.Interface(
113
  description=description,
114
  examples=examples,
115
  interpretation=interpretation,
116
- # Fixes error when set to True:
117
- # https://github.com/gradio-app/gradio/pull/1949
118
  # but generated file names are too long
119
- _api_mode=False
120
  )
121
 
122
  # %% ../main.ipynb 9
 
1
  # AUTOGENERATED! DO NOT EDIT! File to edit: ../main.ipynb.
2
 
3
  # %% auto 0
4
+ __all__ = [
5
+ "ORGAN",
6
+ "IMAGE_SIZE",
7
+ "MODEL_NAME",
8
+ "THRESHOLD",
9
+ "CODES",
10
+ "learn",
11
+ "title",
12
+ "description",
13
+ "examples",
14
+ "interpretation",
15
+ "demo",
16
+ "x_getter",
17
+ "y_getter",
18
+ "splitter",
19
+ "make3D",
20
+ "predict",
21
+ "infer",
22
+ "remove_small_segs",
23
+ "to_oberlay_image",
24
+ ]
25
 
26
  # %% ../main.ipynb 1
27
  import numpy as np
 
36
  ORGAN = "kidney"
37
  IMAGE_SIZE = 512
38
  MODEL_NAME = "unetpp_b4_th60_d9414.pkl"
39
+ THRESHOLD = float(MODEL_NAME.split("_")[2][2:]) / 100.0
40
+ CODES = ["Background", "FTU"] # FTU = functional tissue unit
41
+
42
 
43
  # %% ../main.ipynb 3
44
+ def x_getter(r):
45
+ return r["fnames"]
46
+
47
+
48
+ def y_getter(r):
49
  rle = r["rle"]
50
  shape = (int(r["img_height"]), int(r["img_width"]))
51
  return rle_decode(rle, shape).T
52
+
53
+
54
+ def splitter(model):
55
  enc_params = L(model.encoder.parameters())
56
  dec_params = L(model.decoder.parameters())
57
  sg_params = L(model.segmentation_head.parameters())
58
  untrained_params = L([*dec_params, *sg_params])
59
  return L([enc_params, untrained_params])
60
 
61
+
62
  # %% ../main.ipynb 4
63
  learn = load_learner(MODEL_NAME)
64
 
65
+
66
  # %% ../main.ipynb 5
67
  def make3D(t: np.array) -> np.array:
68
  t = np.expand_dims(t, axis=2)
69
+ t = np.concatenate((t, t, t), axis=2)
70
  return t
71
 
72
+
73
  def predict(fn, cutoff_area=200):
74
  data = infer(fn)
75
  data = remove_small_segs(data, cutoff_area=cutoff_area)
76
  return to_oberlay_image(data), data["df"]
77
 
78
+
79
  def infer(fn):
80
  img = PILImage.create(fn)
81
+ tf_img, _, _, preds = learn.predict(img, with_input=True)
82
+ mask = (F.softmax(preds.float(), dim=0) > THRESHOLD).int()[1]
83
  mask = np.array(mask, dtype=np.uint8)
84
+ resized_image = Image.fromarray(
85
+ tf_img.numpy().transpose(1, 2, 0).astype(np.uint8)
86
+ ).resize(img.shape)
87
  resized_image = np.array(resized_image)
88
  return {
89
  "tf_image": tf_img.numpy().transpose(1, 2, 0).astype(np.uint8),
90
+ "tf_mask": mask,
91
  }
92
 
93
+
94
  def remove_small_segs(data, cutoff_area=250):
95
  labeled_mask = skimage.measure.label(data["tf_mask"])
96
  props = skimage.measure.regionprops(labeled_mask)
97
+ df = {"Glomerulus": [], "Area (in px)": []}
98
  for i, prop in enumerate(props):
99
+ if prop.area < cutoff_area:
100
+ labeled_mask[labeled_mask == i + 1] = 0
101
  continue
102
+ df["Glomerulus"].append(len(df["Glomerulus"]) + 1)
103
  df["Area (in px)"].append(prop.area)
104
+ labeled_mask[labeled_mask > 0] = 1
105
  data["tf_mask"] = labeled_mask.astype(np.uint8)
106
  data["df"] = pd.DataFrame(df)
107
  return data
108
 
109
+
110
  def to_oberlay_image(data):
111
  img, msk = data["tf_image"], data["tf_mask"]
112
  msk_im = np.zeros_like(img)
113
  # rgb code: 255, 80, 80
114
+ msk_im[:, :, 0] = 255
115
+ msk_im[:, :, 1] = 80
116
+ msk_im[:, :, 2] = 80
117
  img = Image.fromarray(img).convert("RGBA")
118
  msk_im = Image.fromarray(msk_im).convert("RGBA")
119
+ msk = Image.fromarray((msk * 255 * 0.5).astype(np.uint8))
120
 
121
+ img.paste(
122
+ msk_im,
123
+ (0, 0),
124
+ msk,
125
+ )
126
  return img
127
 
128
+
129
  # %% ../main.ipynb 6
130
  title = "Glomerulus Segmentation"
131
  description = """
 
135
 
136
  The provided example images are random subset of kidney slices from the [Human Protein Atlas](https://www.proteinatlas.org/). These have been collected separately from model training and have neither been part of the training, validation nor test set.
137
 
138
+ Here is my corresponding [blog post](https://fhatje.github.io/posts/glomseg/train_model.html).
139
  """
140
+ # article="<p style='text-align: center'><a href='Blog post URL' target='_blank'>Blog post</a></p>"
141
  examples = [str(p) for p in get_image_files("example_images")]
142
+ interpretation = "default"
143
 
144
  # %% ../main.ipynb 7
145
  demo = gr.Interface(
 
150
  description=description,
151
  examples=examples,
152
  interpretation=interpretation,
153
+ # Fixes error when set to True:
154
+ # https://github.com/gradio-app/gradio/pull/1949
155
  # but generated file names are too long
156
+ _api_mode=False,
157
  )
158
 
159
  # %% ../main.ipynb 9