Spaces:
Sleeping
Sleeping
amazinghaha
commited on
Commit
•
88cd70c
1
Parent(s):
3797c8e
Upload app.py
Browse files
app.py
ADDED
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import os
|
4 |
+
import numpy as np
|
5 |
+
import SimpleITK as sitk
|
6 |
+
from scipy.ndimage import zoom
|
7 |
+
from resnet_gn import resnet50
|
8 |
+
import pickle
|
9 |
+
def load_from_pkl(load_path):
|
10 |
+
data_input = open(load_path, 'rb')
|
11 |
+
read_data = pickle.load(data_input)
|
12 |
+
data_input.close()
|
13 |
+
return read_data
|
14 |
+
|
15 |
+
Image_3D = None
|
16 |
+
Current_name = None
|
17 |
+
ALL_message = load_from_pkl(r'./label0601.pkl')
|
18 |
+
|
19 |
+
Model_Paht = r'./model_epoch62.pth.tar'
|
20 |
+
checkpoint = torch.load(Model_Paht,map_location='cpu')
|
21 |
+
|
22 |
+
a = 5
|
23 |
+
classnet = resnet50(
|
24 |
+
num_classes=1,
|
25 |
+
sample_size=128,
|
26 |
+
sample_duration=8)
|
27 |
+
classnet.load_state_dict(checkpoint['model_dict'])
|
28 |
+
|
29 |
+
|
30 |
+
def resize3D(img, aimsize, order = 3):
|
31 |
+
"""
|
32 |
+
:param img: 3D array
|
33 |
+
:param aimsize: list, one or three elements, like [256], or [256,56,56]
|
34 |
+
:return:
|
35 |
+
"""
|
36 |
+
_shape =img.shape
|
37 |
+
if len(aimsize)==1:
|
38 |
+
aimsize = [aimsize[0] for _ in range(3)]
|
39 |
+
if aimsize[0] is None:
|
40 |
+
return zoom(img, (1, aimsize[1] / _shape[1], aimsize[2] / _shape[2]),order=order) # resample for cube_size
|
41 |
+
if aimsize[1] is None:
|
42 |
+
return zoom(img, (aimsize[0] / _shape[0], 1, aimsize[2] / _shape[2]),order=order) # resample for cube_size
|
43 |
+
if aimsize[2] is None:
|
44 |
+
return zoom(img, (aimsize[0] / _shape[0], aimsize[1] / _shape[1], 1),order=order) # resample for cube_size
|
45 |
+
return zoom(img, (aimsize[0] / _shape[0], aimsize[1] / _shape[1], aimsize[2] / _shape[2]), order=order) # resample for cube_size
|
46 |
+
|
47 |
+
def inference():
|
48 |
+
model = classnet
|
49 |
+
data = Image_3D
|
50 |
+
|
51 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
52 |
+
model.eval()
|
53 |
+
all_loss = 0
|
54 |
+
length = 0
|
55 |
+
with torch.no_grad():
|
56 |
+
data = torch.from_numpy(data)
|
57 |
+
image = torch.unsqueeze(data, 0)
|
58 |
+
patch_data = torch.unsqueeze(image, 0).to(device).float() # (N, C_{in}, D_{in}, H_{in}, W_{in})
|
59 |
+
|
60 |
+
# Pre : Prediction Result
|
61 |
+
pre_probs = model(patch_data)
|
62 |
+
|
63 |
+
# pre_probs = F.sigmoid(pre_probs)#todo
|
64 |
+
pre_flat = pre_probs.view(-1)
|
65 |
+
a = 5
|
66 |
+
np.round(pre_flat.numpy()[0], decimals=2)
|
67 |
+
#(1-pre_flat.numpy()[0]).astype(np.float32)
|
68 |
+
#pre_flat.numpy()[0].astype(np.float32)
|
69 |
+
p = float(np.round(pre_flat.numpy()[0], decimals=2))
|
70 |
+
n = float(np.round(1-p, decimals=2))
|
71 |
+
return {'急性期': n, '亚急性期': p}
|
72 |
+
|
73 |
+
#
|
74 |
+
#
|
75 |
+
# def image_classifier(inp):
|
76 |
+
# #return {'cat': 0.3, 'dog': 0.7}
|
77 |
+
# return inp
|
78 |
+
#
|
79 |
+
# def image_read(inp):
|
80 |
+
# image = sitk.GetArrayFromImage(sitk.ReadImage(inp))
|
81 |
+
# ss = np.sum(image)
|
82 |
+
# return str(ss)
|
83 |
+
#
|
84 |
+
#
|
85 |
+
# def upload_file(files):
|
86 |
+
# file_paths = [file.name for file in files]
|
87 |
+
# return file_paths
|
88 |
+
#
|
89 |
+
# with gr.Blocks() as demo:
|
90 |
+
# file_output = gr.File()
|
91 |
+
# upload_button = gr.UploadButton("Click to Upload a File", file_types=["image", "video"], file_count="multiple")
|
92 |
+
# upload_button.upload(upload_file, upload_button, gr.Code(''))
|
93 |
+
# demo.launch()
|
94 |
+
|
95 |
+
import gradio as gr
|
96 |
+
import numpy as np
|
97 |
+
import nibabel as nib
|
98 |
+
import os
|
99 |
+
import tempfile
|
100 |
+
# 创建一个函数,接收3D数据并返回预测结果
|
101 |
+
def predict_3d(data):
|
102 |
+
# 在这里编写您的3D数据处理和预测逻辑
|
103 |
+
# 对于示例目的,这里只返回输入数据的最大值作为预测结果
|
104 |
+
result = np.max(data)
|
105 |
+
return result
|
106 |
+
|
107 |
+
# 创建一个用于读取和展示NIfTI数据的Gradio接口函数
|
108 |
+
def interface():
|
109 |
+
# 创建一个自定义输入组件,用于读取NIfTI数据
|
110 |
+
input_component = gr.inputs.File(label="Upload NIfTI file")
|
111 |
+
|
112 |
+
# 创建一个输出组件,用于展示预测结果
|
113 |
+
output_component = gr.outputs.Textbox()
|
114 |
+
|
115 |
+
# 定义预测函数,接收输入数据并调用predict_3d函数进行预测
|
116 |
+
def predict(input_file):
|
117 |
+
# 加载NIfTI数据
|
118 |
+
# temp_dir = tempfile.mkdtemp()
|
119 |
+
# temp_file = os.path.join(temp_dir, "temp_file")
|
120 |
+
# shutil.copyfile(file.name, temp_file)
|
121 |
+
nifti_data = nib.load(input_file.name)
|
122 |
+
# 将NIfTI数据转换为NumPy数组
|
123 |
+
data = np.array(nifti_data.dataobj)
|
124 |
+
# 在这里进行必要的数据预处理,例如缩放、归一化等
|
125 |
+
# 调用predict_3d函数进行预测
|
126 |
+
result = predict_3d(data)
|
127 |
+
# 将预测结果转换为字符串并返回
|
128 |
+
return str(result),str(result)
|
129 |
+
|
130 |
+
# 创建Gradio接口,将输入组件和输出组件传递给Interface函数
|
131 |
+
with gr.Box():
|
132 |
+
gr.Textbox(label="First")
|
133 |
+
gr.Textbox(label="Last")
|
134 |
+
iface_1 = gr.Interface(fn=predict, inputs=gr.inputs.File(label="Upload NIfTI file"), outputs=gr.Box)
|
135 |
+
|
136 |
+
return iface
|
137 |
+
|
138 |
+
|
139 |
+
def get_Image_reslice(input_file):
|
140 |
+
'''得到图像 返回随即层'''
|
141 |
+
global Image_3D
|
142 |
+
global Current_name
|
143 |
+
Image_3D = sitk.GetArrayFromImage(sitk.ReadImage(input_file.name))
|
144 |
+
Current_name = input_file.name.split(os.sep)[-1].split('.')[0].rsplit('_',1)[0]
|
145 |
+
Image_3D = (np.max(Image_3D)-Image_3D)/(np.max(Image_3D)-np.min(Image_3D))
|
146 |
+
random_z = np.random.randint(0, Image_3D.shape[0])
|
147 |
+
image_slice_z = Image_3D[random_z,:,:]
|
148 |
+
|
149 |
+
random_y = np.random.randint(0, Image_3D.shape[1])
|
150 |
+
image_slice_y = Image_3D[:, random_y, :]
|
151 |
+
|
152 |
+
random_x = np.random.randint(0, Image_3D.shape[2])
|
153 |
+
image_slice_x = Image_3D[:, :, random_x]
|
154 |
+
# return zoom(image_slice_z, (10 / image_slice_z.shape[0], 10 / image_slice_z.shape[1]), order=3) , \
|
155 |
+
# zoom(image_slice_y, (10 / image_slice_y.shape[0], 10 / image_slice_y.shape[1]), order=3), \
|
156 |
+
# zoom(image_slice_x, (10 / image_slice_x.shape[0], 10 / image_slice_x.shape[1]), order=3)
|
157 |
+
return image_slice_z, \
|
158 |
+
image_slice_y, \
|
159 |
+
image_slice_x, random_z,random_y,random_x
|
160 |
+
|
161 |
+
|
162 |
+
def change_image_slice_x(slice):
|
163 |
+
|
164 |
+
image_slice = Image_3D[:, :, slice-1]
|
165 |
+
return image_slice
|
166 |
+
|
167 |
+
def change_image_slice_y(slice):
|
168 |
+
image_slice = Image_3D[:, slice-1, :]
|
169 |
+
return image_slice
|
170 |
+
|
171 |
+
def change_image_slice_z(slice):
|
172 |
+
image_slice = Image_3D[slice-1,:,:]
|
173 |
+
return image_slice
|
174 |
+
|
175 |
+
def get_medical_message():
|
176 |
+
global Current_name
|
177 |
+
if Current_name==None:
|
178 |
+
return '请先加载数据',' '
|
179 |
+
else:
|
180 |
+
past = ALL_message[Current_name]['past']
|
181 |
+
now = ALL_message[Current_name]['now']
|
182 |
+
return past, now
|
183 |
+
|
184 |
+
|
185 |
+
class App:
|
186 |
+
def __init__(self):
|
187 |
+
self.demo = None
|
188 |
+
self.main()
|
189 |
+
def main(self):
|
190 |
+
# get_name = gr.Interface(lambda name: name, inputs="textbox", outputs="textbox")
|
191 |
+
# prepend_hello = gr.Interface(lambda name: f"Hello {name}!", inputs="textbox", outputs="textbox")
|
192 |
+
# append_nice = gr.Interface(lambda greeting: f"{greeting} Nice to meet you!",
|
193 |
+
# inputs="textbox", outputs=gr.Textbox(label="Greeting"))
|
194 |
+
|
195 |
+
#iface_1 = gr.Interface(fn=get_Image_reslice, inputs=gr.inputs.File(label="Upload NIfTI file"), outputs=[,gr.Image(shape=(5, 5)),gr.Image(shape=(5, 5))])
|
196 |
+
|
197 |
+
with gr.Blocks() as demo:
|
198 |
+
inp = gr.inputs.File(label="Upload NIfTI file")
|
199 |
+
btn1 = gr.Button("Upload Data")
|
200 |
+
with gr.Tab("Image"):
|
201 |
+
with gr.Row():
|
202 |
+
with gr.Column(scale=1):
|
203 |
+
out1 = gr.Image(shape=(10, 10))
|
204 |
+
slider1 = gr.Slider(1, 128, label='z轴层数', step=1, interactive=True)
|
205 |
+
with gr.Column(scale=1):
|
206 |
+
out2 = gr.Image(shape=(10, 10))
|
207 |
+
slider2 = gr.Slider(1, 256, label='y轴层数', step=1, interactive=True)
|
208 |
+
with gr.Column(scale=1):
|
209 |
+
out3 = gr.Image(shape=(10, 10))
|
210 |
+
slider3 = gr.Slider(1, 128, label='x轴层数', step=1, interactive=True)
|
211 |
+
btn1.click(get_Image_reslice, inp, [out1, out2, out3,slider1,slider2,slider3])
|
212 |
+
slider3.change(change_image_slice_x,inputs=slider3,outputs=out3)
|
213 |
+
slider2.change(change_image_slice_y, inputs=slider2, outputs=out2)
|
214 |
+
slider1.change(change_image_slice_z, inputs=slider1, outputs=out1)
|
215 |
+
|
216 |
+
|
217 |
+
with gr.Tab("Medical Information"):
|
218 |
+
with gr.Row():
|
219 |
+
with gr.Column(scale=1):
|
220 |
+
btn2 = gr.Button(label="临床信息")
|
221 |
+
out4 = gr.Textbox(label="患病史")
|
222 |
+
out6 = gr.Textbox(label="现病史")
|
223 |
+
|
224 |
+
with gr.Column(scale=1):
|
225 |
+
btn3 = gr.Button("分期结果")
|
226 |
+
out5 = gr.Label(num_top_classes=2,label='分期结果')
|
227 |
+
|
228 |
+
btn3.click(inference, inputs=None, outputs=out5)
|
229 |
+
btn2.click(get_medical_message, inputs=None, outputs=[out4,out6])
|
230 |
+
#demo = gr.Series(get_name, prepend_hello, append_nice)
|
231 |
+
|
232 |
+
demo.launch(share=True)
|
233 |
+
app = App()
|
234 |
+
# with gr.Blocks() as demo:
|
235 |
+
# with gr.Row():
|
236 |
+
# with gr.Column(scale=1):
|
237 |
+
# text1 = gr.Textbox()
|
238 |
+
# text2 = gr.Textbox()
|
239 |
+
# with gr.Column(scale=4):
|
240 |
+
# btn1 = gr.Button("Button 1")
|
241 |
+
# btn2 = gr.Button("Button 2")
|
242 |
+
# demo.launch()
|
243 |
+
|