# MedVersa_Internal / inference.py
# Last commit: 908ff76 by hyzhou ("support cpu"); file size 4.1 kB.
from utils import *
from torch import cuda
# --- Launch Model ---
# Prefer the GPU when PyTorch reports one is available; otherwise use the CPU.
device = 'cpu'
if cuda.is_available():
    device = 'cuda'
# 'medomni' is the registry name of the MedVersa architecture.
model_cls = registry.get_model_class('medomni')
# Load the pretrained 'hyzhou/MedVersa' weights, move them to `device`, and
# put the model in evaluation mode for inference.
model = model_cls.from_pretrained('hyzhou/MedVersa')
model = model.to(device).eval()
# --- Define examples ---
# Each entry is a 5-item list, unpacked further down in this script as
#   [image paths, clinical context, prompt, modality, task].
# Prompts refer to the attached images with <img0>, <img1>, ... placeholders.
# Strings repeated across entries are factored out; every entry still gets its
# own image-path list object.
_CXR_SINGLE_VIEW_PROMPT = "How would you characterize the findings from <img0>?"
_CXR_TASK = "report generation"
_DERM_CONTEXT = "Age:70.\nGender:female.\nLocation:back."
_LIVER_PROMPT = "Segment the liver."

examples = [
    # Chest X-ray report generation (one image).
    [
        ["./demo_ex/c536f749-2326f755-6a65f28f-469affd2-26392ce9.png"],
        "Age:30-40.\nGender:F.\nIndication: ___-year-old female with end-stage renal disease not on dialysis presents with dyspnea. PICC line placement.\nComparison: None.",
        _CXR_SINGLE_VIEW_PROMPT,
        "cxr",
        _CXR_TASK,
    ],
    [
        ["./demo_ex/79eee504-b1b60ab8-5e8dd843-b6ed87aa-670747b1.png"],
        "Age:70-80.\nGender:F.\nIndication: Respiratory distress.\nComparison: None.",
        _CXR_SINGLE_VIEW_PROMPT,
        "cxr",
        _CXR_TASK,
    ],
    # Chest X-ray report generation with two images (<img0><img1>).
    [
        [
            "./demo_ex/f39b05b1-f544e51a-cfe317ca-b66a4aa6-1c1dc22d.png",
            "./demo_ex/f3fefc29-68544ac8-284b820d-858b5470-f579b982.png",
        ],
        "Age:80-90.\nGender:F.\nIndication: ___-year-old female with history of chest pain.\nComparison: None.",
        "How would you characterize the findings from <img0><img1>?",
        "cxr",
        _CXR_TASK,
    ],
    [
        ["./demo_ex/1de015eb-891f1b02-f90be378-d6af1e86-df3270c2.png"],
        "Age:40-50.\nGender:M.\nIndication: ___-year-old male with shortness of breath.\nComparison: None.",
        _CXR_SINGLE_VIEW_PROMPT,
        "cxr",
        _CXR_TASK,
    ],
    [
        ["./demo_ex/bc25fa99-0d3766cc-7704edb7-5c7a4a63-dc65480a.png"],
        # Indication text is kept verbatim (including its original spelling).
        "Age:40-50.\nGender:F.\nIndication: History: ___F with tachyacrdia cough doe // infilatrate\nComparison: None.",
        _CXR_SINGLE_VIEW_PROMPT,
        "cxr",
        _CXR_TASK,
    ],
    # Dermatology: the same image, first classified, then segmented.
    [
        ["./demo_ex/ISIC_0032258.jpg"],
        _DERM_CONTEXT,
        "What is primary diagnosis?",
        "derm",
        "classification",
    ],
    [
        ["./demo_ex/ISIC_0032258.jpg"],
        _DERM_CONTEXT,
        "Segment the lesion.",
        "derm",
        "segmentation",
    ],
    # Abdominal CT volumes (.nii.gz): liver segmentation, no clinical context.
    [
        ["./demo_ex/Case_01013_0000.nii.gz"],
        "",
        _LIVER_PROMPT,
        "ct",
        "segmentation",
    ],
    [
        ["./demo_ex/Case_00840_0000.nii.gz"],
        "",
        _LIVER_PROMPT,
        "ct",
        "segmentation",
    ],
]
# --- Define hyperparams ---
# Decoding settings passed positionally to generate_predictions() by every run
# below: sampling with a single beam, top_p=0.9 and temperature=0.1, with both
# penalties set to 1. NOTE(review): these look like Hugging Face `generate`
# arguments; confirm how generate_predictions in utils forwards them.
num_beams, do_sample, min_length = 1, True, 1
top_p, temperature = 0.9, 0.1
repetition_penalty = length_penalty = 1
# --- Generate a report for a chest X-ray image ---
index = 0  # examples[0]: single-image CXR report generation
demo_ex = examples[index]
# Unpack the 5-field record directly. Unlike per-index access, this raises
# ValueError if a record ever has the wrong number of fields.
images, context, prompt, modality, task = demo_ex
seg_mask_2d, seg_mask_3d, output_text = generate_predictions(
    model, images, context, prompt, modality, task,
    num_beams, do_sample, min_length, top_p,
    repetition_penalty, length_penalty, temperature,
    device,
)
# Only the generated text is used for report generation.
print(output_text)
# --- Segment the lesion in the dermatology image ---
index = 6  # examples[6]: derm "Segment the lesion." (segmentation task)
demo_ex = examples[index]
# Unpack the 5-field record directly. Unlike per-index access, this raises
# ValueError if a record ever has the wrong number of fields.
images, context, prompt, modality, task = demo_ex
seg_mask_2d, seg_mask_3d, output_text = generate_predictions(
    model, images, context, prompt, modality, task,
    num_beams, do_sample, min_length, top_p,
    repetition_penalty, length_penalty, temperature,
    device,
)
print(output_text)
print(seg_mask_2d[0].shape) # H, W
# --- Segment the liver in the abdomen CT scan ---
index = -2  # examples[-2]: CT volume Case_01013, "Segment the liver."
demo_ex = examples[index]
# Unpack the 5-field record directly. Unlike per-index access, this raises
# ValueError if a record ever has the wrong number of fields.
images, context, prompt, modality, task = demo_ex
seg_mask_2d, seg_mask_3d, output_text = generate_predictions(
    model, images, context, prompt, modality, task,
    num_beams, do_sample, min_length, top_p,
    repetition_penalty, length_penalty, temperature,
    device,
)
print(output_text)
print(len(seg_mask_3d)) # Number of slices
print(seg_mask_3d[0].shape) # H, W