|
from utils import * |
|
from torch import cuda |
|
|
|
|
|
# Run on the GPU when PyTorch reports one as available, otherwise on the CPU.
if cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Look up the model class registered under "medomni", load the
# 'hyzhou/MedVersa' weights into it (presumably a Hugging Face Hub repo id —
# confirm), move it to `device`, and put it in evaluation mode via .eval().
model_cls = registry.get_model_class('medomni')
model = (
    model_cls.from_pretrained('hyzhou/MedVersa')
    .to(device)
    .eval()
)
|
|
|
|
|
# Demo inputs. Each entry is a 5-element list, unpacked by the demo runs as
#   [image paths, clinical context, prompt, modality, task]
# Prompts reference the images positionally as <img0>, <img1>, ...
examples = [
    # 0: chest X-ray, report generation
    [
        ["./demo_ex/c536f749-2326f755-6a65f28f-469affd2-26392ce9.png"],
        "Age:30-40.\nGender:F.\nIndication: ___-year-old female with end-stage renal disease not on dialysis presents with dyspnea. PICC line placement.\nComparison: None.",
        "How would you characterize the findings from <img0>?",
        "cxr",
        "report generation",
    ],
    # 1: chest X-ray, report generation
    [
        ["./demo_ex/79eee504-b1b60ab8-5e8dd843-b6ed87aa-670747b1.png"],
        "Age:70-80.\nGender:F.\nIndication: Respiratory distress.\nComparison: None.",
        "How would you characterize the findings from <img0>?",
        "cxr",
        "report generation",
    ],
    # 2: chest X-ray with two images, report generation
    [
        [
            "./demo_ex/f39b05b1-f544e51a-cfe317ca-b66a4aa6-1c1dc22d.png",
            "./demo_ex/f3fefc29-68544ac8-284b820d-858b5470-f579b982.png",
        ],
        "Age:80-90.\nGender:F.\nIndication: ___-year-old female with history of chest pain.\nComparison: None.",
        "How would you characterize the findings from <img0><img1>?",
        "cxr",
        "report generation",
    ],
    # 3: chest X-ray, report generation
    [
        ["./demo_ex/1de015eb-891f1b02-f90be378-d6af1e86-df3270c2.png"],
        "Age:40-50.\nGender:M.\nIndication: ___-year-old male with shortness of breath.\nComparison: None.",
        "How would you characterize the findings from <img0>?",
        "cxr",
        "report generation",
    ],
    # 4: chest X-ray, report generation
    # NOTE(review): the indication contains misspellings ("tachyacrdia",
    # "infilatrate"); left unchanged because it is model input — confirm
    # whether it mirrors the source report verbatim.
    [
        ["./demo_ex/bc25fa99-0d3766cc-7704edb7-5c7a4a63-dc65480a.png"],
        "Age:40-50.\nGender:F.\nIndication: History: ___F with tachyacrdia cough doe // infilatrate\nComparison: None.",
        "How would you characterize the findings from <img0>?",
        "cxr",
        "report generation",
    ],
    # 5: dermatology image, classification
    [
        ["./demo_ex/ISIC_0032258.jpg"],
        "Age:70.\nGender:female.\nLocation:back.",
        "What is primary diagnosis?",
        "derm",
        "classification",
    ],
    # 6: same dermatology image, segmentation
    [
        ["./demo_ex/ISIC_0032258.jpg"],
        "Age:70.\nGender:female.\nLocation:back.",
        "Segment the lesion.",
        "derm",
        "segmentation",
    ],
    # 7: CT volume (.nii.gz), liver segmentation, no clinical context
    [
        ["./demo_ex/Case_01013_0000.nii.gz"],
        "",
        "Segment the liver.",
        "ct",
        "segmentation",
    ],
    # 8: CT volume (.nii.gz), liver segmentation, no clinical context
    [
        ["./demo_ex/Case_00840_0000.nii.gz"],
        "",
        "Segment the liver.",
        "ct",
        "segmentation",
    ],
]
|
|
|
# Decoding settings passed positionally to generate_predictions by the demo
# runs in this script. Sampling is enabled, with top-p 0.9 and a low
# temperature of 0.1; both penalties are left at 1.
num_beams, do_sample, min_length = 1, True, 1
top_p = 0.9
repetition_penalty, length_penalty = 1, 1
temperature = 0.1
|
|
|
|
|
def run_example(index):
    """Run the loaded model on ``examples[index]`` and return its predictions.

    Reads the module-level ``model``, ``device``, ``examples`` and decoding
    settings (``num_beams`` ... ``temperature``) defined earlier in this script
    at call time, and forwards them to ``generate_predictions`` in the same
    positional order the original per-example cells used.

    Args:
        index: Position in ``examples``; negative indices count from the end.

    Returns:
        Whatever ``generate_predictions`` returns; callers here unpack it as
        ``(seg_mask_2d, seg_mask_3d, output_text)``.

    Raises:
        IndexError: if ``index`` is out of range for ``examples``.
    """
    # Each example is [image paths, clinical context, prompt, modality, task].
    images, context, prompt, modality, task = examples[index]
    return generate_predictions(
        model, images, context, prompt, modality, task,
        num_beams, do_sample, min_length, top_p,
        repetition_penalty, length_penalty, temperature, device,
    )


# Example 0 (chest X-ray, report generation): print the generated text.
seg_mask_2d, seg_mask_3d, output_text = run_example(0)
print(output_text)

# Example 6 (dermatology, segmentation): print the text and the first 2D mask's shape.
# NOTE(review): assumes seg_mask_2d is a non-empty sequence for segmentation tasks — confirm.
seg_mask_2d, seg_mask_3d, output_text = run_example(6)
print(output_text)
print(seg_mask_2d[0].shape)

# Example -2 (CT, liver segmentation): print the text, the number of 3D masks,
# and the first 3D mask's shape.
seg_mask_2d, seg_mask_3d, output_text = run_example(-2)
print(output_text)
print(len(seg_mask_3d))
print(seg_mask_3d[0].shape)
|
|
|
|