hyzhou commited on
Commit
908ff76
1 Parent(s): 3aa9d03

support cpu

Browse files
Files changed (5) hide show
  1. README.md +3 -3
  2. __pycache__/utils.cpython-39.pyc +0 -0
  3. demo.py +10 -42
  4. inference.py +6 -6
  5. utils.py +7 -8
README.md CHANGED
@@ -22,12 +22,12 @@ conda activate medversa
22
  ## Inference
23
  ``` python
24
  from utils import *
 
25
 
26
  # --- Launch Model ---
27
- device = 'cuda:0'
28
  model_cls = registry.get_model_class('medomni') # medomni is the architecture name :)
29
- model = model_cls.from_pretrained('hyzhou/MedVersa').to(device)
30
- model.eval()
31
 
32
  # --- Define examples ---
33
  examples = [
 
22
  ## Inference
23
  ``` python
24
  from utils import *
25
+ from torch import cuda
26
 
27
  # --- Launch Model ---
28
+ device = 'cuda' if cuda.is_available() else 'cpu'
29
  model_cls = registry.get_model_class('medomni') # medomni is the architecture name :)
30
+ model = model_cls.from_pretrained('hyzhou/MedVersa').to(device).eval()
 
31
 
32
  # --- Define examples ---
33
  examples = [
__pycache__/utils.cpython-39.pyc CHANGED
Binary files a/__pycache__/utils.cpython-39.pyc and b/__pycache__/utils.cpython-39.pyc differ
 
demo.py CHANGED
@@ -1,6 +1,7 @@
1
  import gradio as gr
2
  import argparse
3
  import torch
 
4
  import torch.nn.functional as F
5
  import torchvision.transforms.functional as TF
6
  from torchvision import transforms
@@ -32,15 +33,14 @@ def parse_args():
32
  args = parser.parse_args()
33
  return args
34
 
35
- device = 'cuda:0'
36
  # Launch model
37
  args = parse_args()
38
  cfg = Config(args)
39
 
40
  model_config = cfg.model_cfg
41
  model_cls = registry.get_model_class(model_config.arch)
42
- model = model_cls.from_pretrained('hyzhou/MedVersa').to(device)
43
- model.eval()
44
  global global_images
45
  global_images = None
46
 
@@ -146,7 +146,7 @@ def task_seg_2d(model, preds, hidden_states, image):
146
  seg_feats = model.model_seg_2d.decoder(*feats)
147
  seg_preds = model.model_seg_2d.segmentation_head(seg_feats)
148
  seg_probs = F.sigmoid(seg_preds)
149
- seg_mask = seg_probs.cpu().squeeze().numpy() >= 0.5
150
  return seg_mask
151
  else:
152
  return None
@@ -165,7 +165,7 @@ def task_seg_3d(model, preds, hidden_states, img_embeds_list):
165
  new_img_embeds_list[-1] = last_feats
166
  seg_preds = model.visual_encoder_3d(encoder_only=False, x_=new_img_embeds_list)
167
  seg_probs = F.sigmoid(seg_preds)
168
- seg_mask = seg_probs.cpu().squeeze().numpy() >= 0.5
169
  return seg_mask
170
 
171
  def task_det_2d(model, preds, hidden_states):
@@ -175,7 +175,7 @@ def task_det_2d(model, preds, hidden_states):
175
  if target_states:
176
  target_states = torch.cat(target_states).squeeze()
177
  det_states = model.text_det(target_states).detach().cpu()
178
- return det_states.numpy()
179
  return torch.zeros_like(indices)
180
 
181
  class StoppingCriteriaSub(StoppingCriteria):
@@ -240,7 +240,7 @@ def load_and_preprocess_image(image):
240
  transforms.ToTensor(),
241
  transforms.Normalize(mean, std)
242
  ])
243
- image = transform(image).type(torch.bfloat16).cuda().unsqueeze(0)
244
  return image
245
 
246
  def load_and_preprocess_volume(image):
@@ -249,7 +249,7 @@ def load_and_preprocess_volume(image):
249
  transform = tio.Compose([
250
  tio.ZNormalization(masking_method=tio.ZNormalization.mean),
251
  ])
252
- image = transform(image.unsqueeze(0)).type(torch.bfloat16).cuda()
253
  return image
254
 
255
  def read_image(image_path):
@@ -328,14 +328,14 @@ def generate(image_path, image, context, modal, num_imgs, prompt, num_beams, do_
328
  def generate_predictions(images, context, prompt, modality, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature):
329
  num_imgs = len(images)
330
  modal = modality.lower()
331
- image_tensors = [read_image(img) for img in images]
332
  if modality == 'ct':
333
  time.sleep(2)
334
  else:
335
  time.sleep(1)
336
  image_tensor = torch.cat(image_tensors)
337
 
338
- with torch.autocast("cuda"):
339
  with torch.no_grad():
340
  generated_image, seg_mask_2d, seg_mask_3d, output_text = generate(images, image_tensor, context, modal, num_imgs, prompt, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature)
341
 
@@ -388,38 +388,6 @@ def gradio_interface(chatbot, images, context, prompt, modality, num_beams, do_s
388
 
389
  return chatbot, snapshot, gr.update(maximum=len(output_images)-1)
390
 
391
- # my_dict = {}
392
- # def gradio_interface(images, task, context, prompt, modality, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature):
393
- # if not images:
394
- # return None, "Error: At least one image is required to proceed."
395
- # if not prompt or not task or not modality:
396
- # return None, "Error: Please provide prompt, select task and modality to proceed."
397
-
398
- # generated_images, seg_mask_2d, seg_mask_3d, output_text = generate_predictions(images, task, context, prompt, modality, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature)
399
- # output_images = []
400
-
401
- # input_images = [np.asarray(Image.open(img.name).convert('RGB')).astype(np.uint8) if img.name.endswith(('.jpg', '.jpeg', '.png')) else f"{img.name} (3D Volume)" for img in images]
402
- # if generated_images is not None:
403
- # for generated_image in generated_images:
404
- # output_images.append(np.asarray(generated_image).astype(np.uint8))
405
- # snapshot = (output_images[0], [])
406
- # if seg_mask_2d is not None:
407
- # snapshot = (output_images[0], [(seg_mask_2d[0], "Mask")])
408
- # if seg_mask_3d is not None:
409
- # snapshot = (output_images[0], [(seg_mask_3d[0], "Mask")])
410
- # else:
411
- # output_images = input_images.copy()
412
- # snapshot = (output_images[0], [])
413
-
414
- # my_dict['image'] = output_images
415
- # my_dict['mask'] = None
416
- # if seg_mask_2d is not None:
417
- # my_dict['mask'] = seg_mask_2d
418
- # if seg_mask_3d is not None:
419
- # my_dict['mask'] = seg_mask_3d
420
-
421
- # return output_text, snapshot, gr.update(maximum=len(output_images)-1)
422
-
423
  def render(x):
424
  if x > len(my_dict['image'])-1:
425
  x = len(my_dict['image'])-1
 
1
  import gradio as gr
2
  import argparse
3
  import torch
4
+ from torch import cuda
5
  import torch.nn.functional as F
6
  import torchvision.transforms.functional as TF
7
  from torchvision import transforms
 
33
  args = parser.parse_args()
34
  return args
35
 
36
+ device = 'cuda' if cuda.is_available() else 'cpu'
37
  # Launch model
38
  args = parse_args()
39
  cfg = Config(args)
40
 
41
  model_config = cfg.model_cfg
42
  model_cls = registry.get_model_class(model_config.arch)
43
+ model = model_cls.from_pretrained('hyzhou/MedVersa').to(device).eval()
 
44
  global global_images
45
  global_images = None
46
 
 
146
  seg_feats = model.model_seg_2d.decoder(*feats)
147
  seg_preds = model.model_seg_2d.segmentation_head(seg_feats)
148
  seg_probs = F.sigmoid(seg_preds)
149
+ seg_mask = seg_probs.to(torch.float32).cpu().squeeze().numpy() >= 0.5
150
  return seg_mask
151
  else:
152
  return None
 
165
  new_img_embeds_list[-1] = last_feats
166
  seg_preds = model.visual_encoder_3d(encoder_only=False, x_=new_img_embeds_list)
167
  seg_probs = F.sigmoid(seg_preds)
168
+ seg_mask = seg_probs.to(torch.float32).cpu().squeeze().numpy() >= 0.5
169
  return seg_mask
170
 
171
  def task_det_2d(model, preds, hidden_states):
 
175
  if target_states:
176
  target_states = torch.cat(target_states).squeeze()
177
  det_states = model.text_det(target_states).detach().cpu()
178
+ return det_states.to(torch.float32).numpy()
179
  return torch.zeros_like(indices)
180
 
181
  class StoppingCriteriaSub(StoppingCriteria):
 
240
  transforms.ToTensor(),
241
  transforms.Normalize(mean, std)
242
  ])
243
+ image = transform(image).type(torch.bfloat16).unsqueeze(0)
244
  return image
245
 
246
  def load_and_preprocess_volume(image):
 
249
  transform = tio.Compose([
250
  tio.ZNormalization(masking_method=tio.ZNormalization.mean),
251
  ])
252
+ image = transform(image.unsqueeze(0)).type(torch.bfloat16)
253
  return image
254
 
255
  def read_image(image_path):
 
328
  def generate_predictions(images, context, prompt, modality, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature):
329
  num_imgs = len(images)
330
  modal = modality.lower()
331
+ image_tensors = [read_image(img).to(device) for img in images]
332
  if modality == 'ct':
333
  time.sleep(2)
334
  else:
335
  time.sleep(1)
336
  image_tensor = torch.cat(image_tensors)
337
 
338
+ with torch.autocast(device):
339
  with torch.no_grad():
340
  generated_image, seg_mask_2d, seg_mask_3d, output_text = generate(images, image_tensor, context, modal, num_imgs, prompt, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature)
341
 
 
388
 
389
  return chatbot, snapshot, gr.update(maximum=len(output_images)-1)
390
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
391
  def render(x):
392
  if x > len(my_dict['image'])-1:
393
  x = len(my_dict['image'])-1
inference.py CHANGED
@@ -1,10 +1,10 @@
1
  from utils import *
 
2
 
3
  # --- Launch Model ---
4
- device = 'cuda:0'
5
  model_cls = registry.get_model_class('medomni') # medomni is the architecture name :)
6
- model = model_cls.from_pretrained('hyzhou/MedVersa').to(device)
7
- model.eval()
8
 
9
  # --- Define examples ---
10
  examples = [
@@ -85,14 +85,14 @@ temperature = 0.1
85
  index = 0
86
  demo_ex = examples[index]
87
  images, context, prompt, modality, task = demo_ex[0], demo_ex[1], demo_ex[2], demo_ex[3], demo_ex[4]
88
- seg_mask_2d, seg_mask_3d, output_text = generate_predictions(model, images, context, prompt, modality, task, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature)
89
  print(output_text)
90
 
91
  # --- Segment the lesion in the dermatology image ---
92
  index = 6
93
  demo_ex = examples[index]
94
  images, context, prompt, modality, task = demo_ex[0], demo_ex[1], demo_ex[2], demo_ex[3], demo_ex[4]
95
- seg_mask_2d, seg_mask_3d, output_text = generate_predictions(model, images, context, prompt, modality, task, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature)
96
  print(output_text)
97
  print(seg_mask_2d[0].shape) # H, W
98
 
@@ -100,7 +100,7 @@ print(seg_mask_2d[0].shape) # H, W
100
  index = -2
101
  demo_ex = examples[index]
102
  images, context, prompt, modality, task = demo_ex[0], demo_ex[1], demo_ex[2], demo_ex[3], demo_ex[4]
103
- seg_mask_2d, seg_mask_3d, output_text = generate_predictions(model, images, context, prompt, modality, task, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature)
104
  print(output_text)
105
  print(len(seg_mask_3d)) # Number of slices
106
  print(seg_mask_3d[0].shape) # H, W
 
1
  from utils import *
2
+ from torch import cuda
3
 
4
  # --- Launch Model ---
5
+ device = 'cuda' if cuda.is_available() else 'cpu'
6
  model_cls = registry.get_model_class('medomni') # medomni is the architecture name :)
7
+ model = model_cls.from_pretrained('hyzhou/MedVersa').to(device).eval()
 
8
 
9
  # --- Define examples ---
10
  examples = [
 
85
  index = 0
86
  demo_ex = examples[index]
87
  images, context, prompt, modality, task = demo_ex[0], demo_ex[1], demo_ex[2], demo_ex[3], demo_ex[4]
88
+ seg_mask_2d, seg_mask_3d, output_text = generate_predictions(model, images, context, prompt, modality, task, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature, device)
89
  print(output_text)
90
 
91
  # --- Segment the lesion in the dermatology image ---
92
  index = 6
93
  demo_ex = examples[index]
94
  images, context, prompt, modality, task = demo_ex[0], demo_ex[1], demo_ex[2], demo_ex[3], demo_ex[4]
95
+ seg_mask_2d, seg_mask_3d, output_text = generate_predictions(model, images, context, prompt, modality, task, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature, device)
96
  print(output_text)
97
  print(seg_mask_2d[0].shape) # H, W
98
 
 
100
  index = -2
101
  demo_ex = examples[index]
102
  images, context, prompt, modality, task = demo_ex[0], demo_ex[1], demo_ex[2], demo_ex[3], demo_ex[4]
103
+ seg_mask_2d, seg_mask_3d, output_text = generate_predictions(model, images, context, prompt, modality, task, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature, device)
104
  print(output_text)
105
  print(len(seg_mask_3d)) # Number of slices
106
  print(seg_mask_3d[0].shape) # H, W
utils.py CHANGED
@@ -133,7 +133,7 @@ def task_seg_2d(model, preds, hidden_states, image):
133
  seg_feats = model.model_seg_2d.decoder(*feats)
134
  seg_preds = model.model_seg_2d.segmentation_head(seg_feats)
135
  seg_probs = F.sigmoid(seg_preds)
136
- seg_mask = seg_probs.cpu().squeeze().numpy() >= 0.5
137
  return seg_mask
138
  else:
139
  return None
@@ -152,7 +152,7 @@ def task_seg_3d(model, preds, hidden_states, img_embeds_list):
152
  new_img_embeds_list[-1] = last_feats
153
  seg_preds = model.visual_encoder_3d(encoder_only=False, x_=new_img_embeds_list)
154
  seg_probs = F.sigmoid(seg_preds)
155
- seg_mask = seg_probs.cpu().squeeze().numpy() >= 0.5
156
  return seg_mask
157
 
158
  def task_det_2d(model, preds, hidden_states):
@@ -227,7 +227,7 @@ def load_and_preprocess_image(image):
227
  transforms.ToTensor(),
228
  transforms.Normalize(mean, std)
229
  ])
230
- image = transform(image).type(torch.bfloat16).cuda().unsqueeze(0)
231
  return image
232
 
233
  def load_and_preprocess_volume(image):
@@ -236,7 +236,7 @@ def load_and_preprocess_volume(image):
236
  transform = tio.Compose([
237
  tio.ZNormalization(masking_method=tio.ZNormalization.mean),
238
  ])
239
- image = transform(image.unsqueeze(0)).type(torch.bfloat16).cuda()
240
  return image
241
 
242
  def read_image(image_path):
@@ -285,7 +285,6 @@ def generate(model, image_path, image, context, modal, task, num_imgs, prompt, n
285
  seg_mask = task_seg_2d(model, preds, hidden_states, image)
286
  output_image, seg_mask_2d = seg_2d_process(image_path, seg_mask)
287
  if sum(preds == model.seg_token_idx_3d):
288
- ipdb.set_trace()
289
  seg_mask = task_seg_3d(model, preds, hidden_states, img_embeds_list)
290
  output_image, seg_mask_3d = seg_3d_process(image_path, seg_mask)
291
  if sum(preds == model.det_token_idx):
@@ -304,17 +303,17 @@ def generate(model, image_path, image, context, modal, task, num_imgs, prompt, n
304
  output_text = 'The main diagnosis is melanoma.'
305
  return output_image, seg_mask_2d, seg_mask_3d, output_text
306
 
307
- def generate_predictions(model, images, context, prompt, modality, task, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature):
308
  num_imgs = len(images)
309
  modal = modality.lower()
310
- image_tensors = [read_image(img) for img in images]
311
  if modality == 'ct':
312
  time.sleep(2)
313
  else:
314
  time.sleep(1)
315
  image_tensor = torch.cat(image_tensors)
316
 
317
- with torch.autocast("cuda"):
318
  with torch.no_grad():
319
  generated_image, seg_mask_2d, seg_mask_3d, output_text = generate(model, images, image_tensor, context, modal, task, num_imgs, prompt, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature)
320
 
 
133
  seg_feats = model.model_seg_2d.decoder(*feats)
134
  seg_preds = model.model_seg_2d.segmentation_head(seg_feats)
135
  seg_probs = F.sigmoid(seg_preds)
136
+ seg_mask = seg_probs.to(dtype=torch.float32).cpu().squeeze().numpy() >= 0.5
137
  return seg_mask
138
  else:
139
  return None
 
152
  new_img_embeds_list[-1] = last_feats
153
  seg_preds = model.visual_encoder_3d(encoder_only=False, x_=new_img_embeds_list)
154
  seg_probs = F.sigmoid(seg_preds)
155
+ seg_mask = seg_probs.to(dtype=torch.float32).cpu().squeeze().numpy() >= 0.5
156
  return seg_mask
157
 
158
  def task_det_2d(model, preds, hidden_states):
 
227
  transforms.ToTensor(),
228
  transforms.Normalize(mean, std)
229
  ])
230
+ image = transform(image).type(torch.bfloat16).unsqueeze(0)
231
  return image
232
 
233
  def load_and_preprocess_volume(image):
 
236
  transform = tio.Compose([
237
  tio.ZNormalization(masking_method=tio.ZNormalization.mean),
238
  ])
239
+ image = transform(image.unsqueeze(0)).type(torch.bfloat16)
240
  return image
241
 
242
  def read_image(image_path):
 
285
  seg_mask = task_seg_2d(model, preds, hidden_states, image)
286
  output_image, seg_mask_2d = seg_2d_process(image_path, seg_mask)
287
  if sum(preds == model.seg_token_idx_3d):
 
288
  seg_mask = task_seg_3d(model, preds, hidden_states, img_embeds_list)
289
  output_image, seg_mask_3d = seg_3d_process(image_path, seg_mask)
290
  if sum(preds == model.det_token_idx):
 
303
  output_text = 'The main diagnosis is melanoma.'
304
  return output_image, seg_mask_2d, seg_mask_3d, output_text
305
 
306
+ def generate_predictions(model, images, context, prompt, modality, task, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature, device):
307
  num_imgs = len(images)
308
  modal = modality.lower()
309
+ image_tensors = [read_image(img).to(device) for img in images]
310
  if modality == 'ct':
311
  time.sleep(2)
312
  else:
313
  time.sleep(1)
314
  image_tensor = torch.cat(image_tensors)
315
 
316
+ with torch.autocast(device):
317
  with torch.no_grad():
318
  generated_image, seg_mask_2d, seg_mask_3d, output_text = generate(model, images, image_tensor, context, modal, task, num_imgs, prompt, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature)
319