support cpu

- README.md +3 -3
- __pycache__/utils.cpython-39.pyc +0 -0
- demo.py +10 -42
- inference.py +6 -6
- utils.py +7 -8
README.md
CHANGED
@@ -22,12 +22,12 @@ conda activate medversa
 ## Inference
 ``` python
 from utils import *
+from torch import cuda
 
 # --- Launch Model ---
-device = 'cuda'
+device = 'cuda' if cuda.is_available() else 'cpu'
 model_cls = registry.get_model_class('medomni') # medomni is the architecture name :)
-model = model_cls.from_pretrained('hyzhou/MedVersa').to(device)
-model.eval()
+model = model_cls.from_pretrained('hyzhou/MedVersa').to(device).eval()
 
 # --- Define examples ---
 examples = [
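For context, the idiom this commit adopts: `torch.cuda.is_available()` picks the device at runtime, and `.to(device).eval()` chains because both methods return the module. A minimal runnable sketch of that pattern (the toy `nn.Linear` is a stand-in for the MedVersa model, not part of the repo):

```python
import torch
from torch import nn

# Fall back to CPU when no GPU is present -- the core of this commit.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# nn.Module.to() and .eval() both return the module, so the calls chain,
# which is why one statement now replaces the old .to(device) / model.eval() pair.
model = nn.Linear(4, 2).to(device).eval()

x = torch.randn(1, 4, device=device)
with torch.no_grad():
    print(model(x).device)  # cuda:0 on a GPU machine, cpu otherwise
```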
__pycache__/utils.cpython-39.pyc
CHANGED
Binary files a/__pycache__/utils.cpython-39.pyc and b/__pycache__/utils.cpython-39.pyc differ
demo.py
CHANGED
@@ -1,6 +1,7 @@
 import gradio as gr
 import argparse
 import torch
+from torch import cuda
 import torch.nn.functional as F
 import torchvision.transforms.functional as TF
 from torchvision import transforms
@@ -32,15 +33,14 @@ def parse_args():
     args = parser.parse_args()
     return args
 
-device = 'cuda'
+device = 'cuda' if cuda.is_available() else 'cpu'
 # Launch model
 args = parse_args()
 cfg = Config(args)
 
 model_config = cfg.model_cfg
 model_cls = registry.get_model_class(model_config.arch)
-model = model_cls.from_pretrained('hyzhou/MedVersa').to(device)
-model.eval()
+model = model_cls.from_pretrained('hyzhou/MedVersa').to(device).eval()
 global global_images
 global_images = None
 
@@ -146,7 +146,7 @@ def task_seg_2d(model, preds, hidden_states, image):
         seg_feats = model.model_seg_2d.decoder(*feats)
         seg_preds = model.model_seg_2d.segmentation_head(seg_feats)
         seg_probs = F.sigmoid(seg_preds)
-        seg_mask = seg_probs.cpu().squeeze().numpy() >= 0.5
+        seg_mask = seg_probs.to(torch.float32).cpu().squeeze().numpy() >= 0.5
         return seg_mask
     else:
         return None
@@ -165,7 +165,7 @@ def task_seg_3d(model, preds, hidden_states, img_embeds_list):
     new_img_embeds_list[-1] = last_feats
     seg_preds = model.visual_encoder_3d(encoder_only=False, x_=new_img_embeds_list)
     seg_probs = F.sigmoid(seg_preds)
-    seg_mask = seg_probs.cpu().squeeze().numpy() >= 0.5
+    seg_mask = seg_probs.to(torch.float32).cpu().squeeze().numpy() >= 0.5
     return seg_mask
 
 def task_det_2d(model, preds, hidden_states):
@@ -175,7 +175,7 @@ def task_det_2d(model, preds, hidden_states):
     if target_states:
         target_states = torch.cat(target_states).squeeze()
         det_states = model.text_det(target_states).detach().cpu()
-        return det_states.numpy()
+        return det_states.to(torch.float32).numpy()
     return torch.zeros_like(indices)
 
 class StoppingCriteriaSub(StoppingCriteria):
@@ -240,7 +240,7 @@ def load_and_preprocess_image(image):
         transforms.ToTensor(),
         transforms.Normalize(mean, std)
     ])
-    image = transform(image).type(torch.bfloat16).
+    image = transform(image).type(torch.bfloat16).unsqueeze(0)
    return image
 
 def load_and_preprocess_volume(image):
@@ -249,7 +249,7 @@ def load_and_preprocess_volume(image):
     transform = tio.Compose([
         tio.ZNormalization(masking_method=tio.ZNormalization.mean),
     ])
-    image = transform(image.unsqueeze(0)).type(torch.bfloat16)
+    image = transform(image.unsqueeze(0)).type(torch.bfloat16)
     return image
 
 def read_image(image_path):
@@ -328,14 +328,14 @@ def generate(image_path, image, context, modal, num_imgs, prompt, num_beams, do_
 def generate_predictions(images, context, prompt, modality, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature):
     num_imgs = len(images)
     modal = modality.lower()
-    image_tensors = [read_image(img) for img in images]
+    image_tensors = [read_image(img).to(device) for img in images]
     if modality == 'ct':
         time.sleep(2)
     else:
         time.sleep(1)
     image_tensor = torch.cat(image_tensors)
 
-    with torch.autocast(
+    with torch.autocast(device):
         with torch.no_grad():
             generated_image, seg_mask_2d, seg_mask_3d, output_text = generate(images, image_tensor, context, modal, num_imgs, prompt, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature)
 
@@ -388,38 +388,6 @@ def gradio_interface(chatbot, images, context, prompt, modality, num_beams, do_s
 
     return chatbot, snapshot, gr.update(maximum=len(output_images)-1)
 
-# my_dict = {}
-# def gradio_interface(images, task, context, prompt, modality, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature):
-#     if not images:
-#         return None, "Error: At least one image is required to proceed."
-#     if not prompt or not task or not modality:
-#         return None, "Error: Please provide prompt, select task and modality to proceed."
-
-#     generated_images, seg_mask_2d, seg_mask_3d, output_text = generate_predictions(images, task, context, prompt, modality, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature)
-#     output_images = []
-
-#     input_images = [np.asarray(Image.open(img.name).convert('RGB')).astype(np.uint8) if img.name.endswith(('.jpg', '.jpeg', '.png')) else f"{img.name} (3D Volume)" for img in images]
-#     if generated_images is not None:
-#         for generated_image in generated_images:
-#             output_images.append(np.asarray(generated_image).astype(np.uint8))
-#         snapshot = (output_images[0], [])
-#         if seg_mask_2d is not None:
-#             snapshot = (output_images[0], [(seg_mask_2d[0], "Mask")])
-#         if seg_mask_3d is not None:
-#             snapshot = (output_images[0], [(seg_mask_3d[0], "Mask")])
-#     else:
-#         output_images = input_images.copy()
-#         snapshot = (output_images[0], [])
-
-#     my_dict['image'] = output_images
-#     my_dict['mask'] = None
-#     if seg_mask_2d is not None:
-#         my_dict['mask'] = seg_mask_2d
-#     if seg_mask_3d is not None:
-#         my_dict['mask'] = seg_mask_3d
-
-#     return output_text, snapshot, gr.update(maximum=len(output_images)-1)
-
 def render(x):
     if x > len(my_dict['image'])-1:
         x = len(my_dict['image'])-1
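The `.to(torch.float32)` casts added above are needed because NumPy has no bfloat16 dtype, so calling `.numpy()` on a bfloat16 tensor raises. A minimal sketch of the failure and the fix, assuming only stock PyTorch:

```python
import torch

probs = torch.rand(2, 3).to(torch.bfloat16)  # mimics seg_probs under bfloat16
try:
    probs.numpy()  # NumPy has no bfloat16 dtype, so this raises
except (TypeError, RuntimeError) as err:
    print(err)

# The pattern the diff adopts: cast to float32 first, then threshold.
mask = probs.to(torch.float32).cpu().numpy() >= 0.5
print(mask.dtype)  # bool
```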
inference.py
CHANGED
@@ -1,10 +1,10 @@
 from utils import *
+from torch import cuda
 
 # --- Launch Model ---
-device = 'cuda'
+device = 'cuda' if cuda.is_available() else 'cpu'
 model_cls = registry.get_model_class('medomni') # medomni is the architecture name :)
-model = model_cls.from_pretrained('hyzhou/MedVersa').to(device)
-model.eval()
+model = model_cls.from_pretrained('hyzhou/MedVersa').to(device).eval()
 
 # --- Define examples ---
 examples = [
@@ -85,14 +85,14 @@ temperature = 0.1
 index = 0
 demo_ex = examples[index]
 images, context, prompt, modality, task = demo_ex[0], demo_ex[1], demo_ex[2], demo_ex[3], demo_ex[4]
-seg_mask_2d, seg_mask_3d, output_text = generate_predictions(model, images, context, prompt, modality, task, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature)
+seg_mask_2d, seg_mask_3d, output_text = generate_predictions(model, images, context, prompt, modality, task, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature, device)
 print(output_text)
 
 # --- Segment the lesion in the dermatology image ---
 index = 6
 demo_ex = examples[index]
 images, context, prompt, modality, task = demo_ex[0], demo_ex[1], demo_ex[2], demo_ex[3], demo_ex[4]
-seg_mask_2d, seg_mask_3d, output_text = generate_predictions(model, images, context, prompt, modality, task, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature)
+seg_mask_2d, seg_mask_3d, output_text = generate_predictions(model, images, context, prompt, modality, task, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature, device)
 print(output_text)
 print(seg_mask_2d[0].shape) # H, W
 
@@ -100,7 +100,7 @@ print(seg_mask_2d[0].shape) # H, W
 index = -2
 demo_ex = examples[index]
 images, context, prompt, modality, task = demo_ex[0], demo_ex[1], demo_ex[2], demo_ex[3], demo_ex[4]
-seg_mask_2d, seg_mask_3d, output_text = generate_predictions(model, images, context, prompt, modality, task, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature)
+seg_mask_2d, seg_mask_3d, output_text = generate_predictions(model, images, context, prompt, modality, task, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature, device)
 print(output_text)
 print(len(seg_mask_3d)) # Number of slices
 print(seg_mask_3d[0].shape) # H, W
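Since `generate_predictions` in utils.py now takes `device` as a trailing parameter, every call site above passes it explicitly instead of relying on a module-level global. A tiny analogue of that threading pattern (the `*_stub` names are hypothetical, not from the repo):

```python
import torch
from torch import nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'

def read_image_stub(path, device):
    # Stand-in for read_image(img).to(device): preprocess, then move to the device.
    return torch.rand(1, 192, device=device)

def generate_predictions_stub(model, paths, device):
    # Mirrors the new signature: device arrives as an explicit trailing argument.
    batch = torch.cat([read_image_stub(p, device) for p in paths])
    with torch.no_grad():
        return model(batch)

model = nn.Linear(192, 2).to(device).eval()
print(generate_predictions_stub(model, ['a.png', 'b.png'], device).shape)  # torch.Size([2, 2])
```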
utils.py
CHANGED
@@ -133,7 +133,7 @@ def task_seg_2d(model, preds, hidden_states, image):
         seg_feats = model.model_seg_2d.decoder(*feats)
         seg_preds = model.model_seg_2d.segmentation_head(seg_feats)
         seg_probs = F.sigmoid(seg_preds)
-        seg_mask = seg_probs.cpu().squeeze().numpy() >= 0.5
+        seg_mask = seg_probs.to(dtype=torch.float32).cpu().squeeze().numpy() >= 0.5
         return seg_mask
     else:
         return None
@@ -152,7 +152,7 @@ def task_seg_3d(model, preds, hidden_states, img_embeds_list):
     new_img_embeds_list[-1] = last_feats
     seg_preds = model.visual_encoder_3d(encoder_only=False, x_=new_img_embeds_list)
     seg_probs = F.sigmoid(seg_preds)
-    seg_mask = seg_probs.cpu().squeeze().numpy() >= 0.5
+    seg_mask = seg_probs.to(dtype=torch.float32).cpu().squeeze().numpy() >= 0.5
     return seg_mask
 
 def task_det_2d(model, preds, hidden_states):
@@ -227,7 +227,7 @@ def load_and_preprocess_image(image):
         transforms.ToTensor(),
         transforms.Normalize(mean, std)
     ])
-    image = transform(image).type(torch.bfloat16).
+    image = transform(image).type(torch.bfloat16).unsqueeze(0)
     return image
 
 def load_and_preprocess_volume(image):
@@ -236,7 +236,7 @@ def load_and_preprocess_volume(image):
     transform = tio.Compose([
         tio.ZNormalization(masking_method=tio.ZNormalization.mean),
     ])
-    image = transform(image.unsqueeze(0)).type(torch.bfloat16)
+    image = transform(image.unsqueeze(0)).type(torch.bfloat16)
     return image
 
 def read_image(image_path):
@@ -285,7 +285,6 @@ def generate(model, image_path, image, context, modal, task, num_imgs, prompt, n
         seg_mask = task_seg_2d(model, preds, hidden_states, image)
         output_image, seg_mask_2d = seg_2d_process(image_path, seg_mask)
     if sum(preds == model.seg_token_idx_3d):
-        ipdb.set_trace()
         seg_mask = task_seg_3d(model, preds, hidden_states, img_embeds_list)
         output_image, seg_mask_3d = seg_3d_process(image_path, seg_mask)
     if sum(preds == model.det_token_idx):
@@ -304,17 +303,17 @@ def generate(model, image_path, image, context, modal, task, num_imgs, prompt, n
     output_text = 'The main diagnosis is melanoma.'
     return output_image, seg_mask_2d, seg_mask_3d, output_text
 
-def generate_predictions(model, images, context, prompt, modality, task, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature):
+def generate_predictions(model, images, context, prompt, modality, task, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature, device):
     num_imgs = len(images)
     modal = modality.lower()
-    image_tensors = [read_image(img) for img in images]
+    image_tensors = [read_image(img).to(device) for img in images]
     if modality == 'ct':
         time.sleep(2)
     else:
         time.sleep(1)
     image_tensor = torch.cat(image_tensors)
 
-    with torch.autocast(
+    with torch.autocast(device):
         with torch.no_grad():
             generated_image, seg_mask_2d, seg_mask_3d, output_text = generate(model, images, image_tensor, context, modal, task, num_imgs, prompt, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature)
 
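One detail worth noting about the `torch.autocast(device)` lines: `torch.autocast`'s first positional argument is the device type string, so passing `device` ('cuda' or 'cpu') selects the matching autocast backend on either machine, with float16 on CUDA and bfloat16 on CPU by default. A minimal sketch:

```python
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# torch.autocast takes the device *type* ('cuda' or 'cpu') positionally,
# which is why the same device string works on both backends.
with torch.autocast(device):
    a = torch.randn(8, 8, device=device)
    b = torch.randn(8, 8, device=device)
    c = a @ b  # runs in the autocast dtype

print(c.dtype)  # torch.float16 on CUDA, torch.bfloat16 on CPU
```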