amazinghaha commited on
Commit
88cd70c
1 Parent(s): 3797c8e

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +243 -0
app.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import os
4
+ import numpy as np
5
+ import SimpleITK as sitk
6
+ from scipy.ndimage import zoom
7
+ from resnet_gn import resnet50
8
+ import pickle
9
+ def load_from_pkl(load_path):
10
+ data_input = open(load_path, 'rb')
11
+ read_data = pickle.load(data_input)
12
+ data_input.close()
13
+ return read_data
14
+
15
+ Image_3D = None
16
+ Current_name = None
17
+ ALL_message = load_from_pkl(r'./label0601.pkl')
18
+
19
+ Model_Paht = r'./model_epoch62.pth.tar'
20
+ checkpoint = torch.load(Model_Paht,map_location='cpu')
21
+
22
+ a = 5
23
+ classnet = resnet50(
24
+ num_classes=1,
25
+ sample_size=128,
26
+ sample_duration=8)
27
+ classnet.load_state_dict(checkpoint['model_dict'])
28
+
29
+
30
+ def resize3D(img, aimsize, order = 3):
31
+ """
32
+ :param img: 3D array
33
+ :param aimsize: list, one or three elements, like [256], or [256,56,56]
34
+ :return:
35
+ """
36
+ _shape =img.shape
37
+ if len(aimsize)==1:
38
+ aimsize = [aimsize[0] for _ in range(3)]
39
+ if aimsize[0] is None:
40
+ return zoom(img, (1, aimsize[1] / _shape[1], aimsize[2] / _shape[2]),order=order) # resample for cube_size
41
+ if aimsize[1] is None:
42
+ return zoom(img, (aimsize[0] / _shape[0], 1, aimsize[2] / _shape[2]),order=order) # resample for cube_size
43
+ if aimsize[2] is None:
44
+ return zoom(img, (aimsize[0] / _shape[0], aimsize[1] / _shape[1], 1),order=order) # resample for cube_size
45
+ return zoom(img, (aimsize[0] / _shape[0], aimsize[1] / _shape[1], aimsize[2] / _shape[2]), order=order) # resample for cube_size
46
+
47
+ def inference():
48
+ model = classnet
49
+ data = Image_3D
50
+
51
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
52
+ model.eval()
53
+ all_loss = 0
54
+ length = 0
55
+ with torch.no_grad():
56
+ data = torch.from_numpy(data)
57
+ image = torch.unsqueeze(data, 0)
58
+ patch_data = torch.unsqueeze(image, 0).to(device).float() # (N, C_{in}, D_{in}, H_{in}, W_{in})
59
+
60
+ # Pre : Prediction Result
61
+ pre_probs = model(patch_data)
62
+
63
+ # pre_probs = F.sigmoid(pre_probs)#todo
64
+ pre_flat = pre_probs.view(-1)
65
+ a = 5
66
+ np.round(pre_flat.numpy()[0], decimals=2)
67
+ #(1-pre_flat.numpy()[0]).astype(np.float32)
68
+ #pre_flat.numpy()[0].astype(np.float32)
69
+ p = float(np.round(pre_flat.numpy()[0], decimals=2))
70
+ n = float(np.round(1-p, decimals=2))
71
+ return {'急性期': n, '亚急性期': p}
72
+
73
+ #
74
+ #
75
+ # def image_classifier(inp):
76
+ # #return {'cat': 0.3, 'dog': 0.7}
77
+ # return inp
78
+ #
79
+ # def image_read(inp):
80
+ # image = sitk.GetArrayFromImage(sitk.ReadImage(inp))
81
+ # ss = np.sum(image)
82
+ # return str(ss)
83
+ #
84
+ #
85
+ # def upload_file(files):
86
+ # file_paths = [file.name for file in files]
87
+ # return file_paths
88
+ #
89
+ # with gr.Blocks() as demo:
90
+ # file_output = gr.File()
91
+ # upload_button = gr.UploadButton("Click to Upload a File", file_types=["image", "video"], file_count="multiple")
92
+ # upload_button.upload(upload_file, upload_button, gr.Code(''))
93
+ # demo.launch()
94
+
95
+ import gradio as gr
96
+ import numpy as np
97
+ import nibabel as nib
98
+ import os
99
+ import tempfile
100
+ # 创建一个函数,接收3D数据并返回预测结果
101
+ def predict_3d(data):
102
+ # 在这里编写您的3D数据处理和预测逻辑
103
+ # 对于示例目的,这里只返回输入数据的最大值作为预测结果
104
+ result = np.max(data)
105
+ return result
106
+
107
+ # 创建一个用于读取和展示NIfTI数据的Gradio接口函数
108
+ def interface():
109
+ # 创建一个自定义输入组件,用于读取NIfTI数据
110
+ input_component = gr.inputs.File(label="Upload NIfTI file")
111
+
112
+ # 创建一个输出组件,用于展示预测结果
113
+ output_component = gr.outputs.Textbox()
114
+
115
+ # 定义预测函数,接收输入数据并调用predict_3d函数进行预测
116
+ def predict(input_file):
117
+ # 加载NIfTI数据
118
+ # temp_dir = tempfile.mkdtemp()
119
+ # temp_file = os.path.join(temp_dir, "temp_file")
120
+ # shutil.copyfile(file.name, temp_file)
121
+ nifti_data = nib.load(input_file.name)
122
+ # 将NIfTI数据转换为NumPy数组
123
+ data = np.array(nifti_data.dataobj)
124
+ # 在这里进行必要的数据预处理,例如缩放、归一化等
125
+ # 调用predict_3d函数进行预测
126
+ result = predict_3d(data)
127
+ # 将预测结果转换为字符串并返回
128
+ return str(result),str(result)
129
+
130
+ # 创建Gradio接口,将输入组件和输出组件传递给Interface函数
131
+ with gr.Box():
132
+ gr.Textbox(label="First")
133
+ gr.Textbox(label="Last")
134
+ iface_1 = gr.Interface(fn=predict, inputs=gr.inputs.File(label="Upload NIfTI file"), outputs=gr.Box)
135
+
136
+ return iface
137
+
138
+
139
+ def get_Image_reslice(input_file):
140
+ '''得到图像 返回随即层'''
141
+ global Image_3D
142
+ global Current_name
143
+ Image_3D = sitk.GetArrayFromImage(sitk.ReadImage(input_file.name))
144
+ Current_name = input_file.name.split(os.sep)[-1].split('.')[0].rsplit('_',1)[0]
145
+ Image_3D = (np.max(Image_3D)-Image_3D)/(np.max(Image_3D)-np.min(Image_3D))
146
+ random_z = np.random.randint(0, Image_3D.shape[0])
147
+ image_slice_z = Image_3D[random_z,:,:]
148
+
149
+ random_y = np.random.randint(0, Image_3D.shape[1])
150
+ image_slice_y = Image_3D[:, random_y, :]
151
+
152
+ random_x = np.random.randint(0, Image_3D.shape[2])
153
+ image_slice_x = Image_3D[:, :, random_x]
154
+ # return zoom(image_slice_z, (10 / image_slice_z.shape[0], 10 / image_slice_z.shape[1]), order=3) , \
155
+ # zoom(image_slice_y, (10 / image_slice_y.shape[0], 10 / image_slice_y.shape[1]), order=3), \
156
+ # zoom(image_slice_x, (10 / image_slice_x.shape[0], 10 / image_slice_x.shape[1]), order=3)
157
+ return image_slice_z, \
158
+ image_slice_y, \
159
+ image_slice_x, random_z,random_y,random_x
160
+
161
+
162
+ def change_image_slice_x(slice):
163
+
164
+ image_slice = Image_3D[:, :, slice-1]
165
+ return image_slice
166
+
167
+ def change_image_slice_y(slice):
168
+ image_slice = Image_3D[:, slice-1, :]
169
+ return image_slice
170
+
171
+ def change_image_slice_z(slice):
172
+ image_slice = Image_3D[slice-1,:,:]
173
+ return image_slice
174
+
175
+ def get_medical_message():
176
+ global Current_name
177
+ if Current_name==None:
178
+ return '请先加载数据',' '
179
+ else:
180
+ past = ALL_message[Current_name]['past']
181
+ now = ALL_message[Current_name]['now']
182
+ return past, now
183
+
184
+
185
+ class App:
186
+ def __init__(self):
187
+ self.demo = None
188
+ self.main()
189
+ def main(self):
190
+ # get_name = gr.Interface(lambda name: name, inputs="textbox", outputs="textbox")
191
+ # prepend_hello = gr.Interface(lambda name: f"Hello {name}!", inputs="textbox", outputs="textbox")
192
+ # append_nice = gr.Interface(lambda greeting: f"{greeting} Nice to meet you!",
193
+ # inputs="textbox", outputs=gr.Textbox(label="Greeting"))
194
+
195
+ #iface_1 = gr.Interface(fn=get_Image_reslice, inputs=gr.inputs.File(label="Upload NIfTI file"), outputs=[,gr.Image(shape=(5, 5)),gr.Image(shape=(5, 5))])
196
+
197
+ with gr.Blocks() as demo:
198
+ inp = gr.inputs.File(label="Upload NIfTI file")
199
+ btn1 = gr.Button("Upload Data")
200
+ with gr.Tab("Image"):
201
+ with gr.Row():
202
+ with gr.Column(scale=1):
203
+ out1 = gr.Image(shape=(10, 10))
204
+ slider1 = gr.Slider(1, 128, label='z轴层数', step=1, interactive=True)
205
+ with gr.Column(scale=1):
206
+ out2 = gr.Image(shape=(10, 10))
207
+ slider2 = gr.Slider(1, 256, label='y轴层数', step=1, interactive=True)
208
+ with gr.Column(scale=1):
209
+ out3 = gr.Image(shape=(10, 10))
210
+ slider3 = gr.Slider(1, 128, label='x轴层数', step=1, interactive=True)
211
+ btn1.click(get_Image_reslice, inp, [out1, out2, out3,slider1,slider2,slider3])
212
+ slider3.change(change_image_slice_x,inputs=slider3,outputs=out3)
213
+ slider2.change(change_image_slice_y, inputs=slider2, outputs=out2)
214
+ slider1.change(change_image_slice_z, inputs=slider1, outputs=out1)
215
+
216
+
217
+ with gr.Tab("Medical Information"):
218
+ with gr.Row():
219
+ with gr.Column(scale=1):
220
+ btn2 = gr.Button(label="临床信息")
221
+ out4 = gr.Textbox(label="患病史")
222
+ out6 = gr.Textbox(label="现病史")
223
+
224
+ with gr.Column(scale=1):
225
+ btn3 = gr.Button("分期结果")
226
+ out5 = gr.Label(num_top_classes=2,label='分期结果')
227
+
228
+ btn3.click(inference, inputs=None, outputs=out5)
229
+ btn2.click(get_medical_message, inputs=None, outputs=[out4,out6])
230
+ #demo = gr.Series(get_name, prepend_hello, append_nice)
231
+
232
+ demo.launch(share=True)
233
+ app = App()
234
+ # with gr.Blocks() as demo:
235
+ # with gr.Row():
236
+ # with gr.Column(scale=1):
237
+ # text1 = gr.Textbox()
238
+ # text2 = gr.Textbox()
239
+ # with gr.Column(scale=4):
240
+ # btn1 = gr.Button("Button 1")
241
+ # btn2 = gr.Button("Button 2")
242
+ # demo.launch()
243
+