"""Demo: run MedVersa on the example inputs under ./demo_ex.

Runs three tasks: chest X-ray report generation, dermatology lesion
segmentation, and abdominal CT liver segmentation.

``registry`` and ``generate_predictions`` are not defined in this script;
they are expected to come from the project-local ``utils`` star-import below.
"""
from utils import *
from torch import cuda

# --- Launch Model ---
device = 'cuda' if cuda.is_available() else 'cpu'
model_cls = registry.get_model_class('medomni')  # medomni is the architecture name :)
model = model_cls.from_pretrained('hyzhou/MedVersa').to(device).eval()

# --- Define examples ---
# Each entry is [image paths, clinical context, prompt, modality, task].
# NOTE(review): the CXR prompts read "...findings from ?" -- this looks like an
# image placeholder token (e.g. "<img0>") was stripped from the prompt text.
# Confirm against the upstream MedVersa example before relying on these outputs.
examples = [
    [
        ["./demo_ex/c536f749-2326f755-6a65f28f-469affd2-26392ce9.png"],
        "Age:30-40.\nGender:F.\nIndication: ___-year-old female with end-stage renal disease not on dialysis presents with dyspnea. PICC line placement.\nComparison: None.",
        "How would you characterize the findings from ?",
        "cxr",
        "report generation",
    ],
    [
        ["./demo_ex/79eee504-b1b60ab8-5e8dd843-b6ed87aa-670747b1.png"],
        "Age:70-80.\nGender:F.\nIndication: Respiratory distress.\nComparison: None.",
        "How would you characterize the findings from ?",
        "cxr",
        "report generation",
    ],
    [
        [
            "./demo_ex/f39b05b1-f544e51a-cfe317ca-b66a4aa6-1c1dc22d.png",
            "./demo_ex/f3fefc29-68544ac8-284b820d-858b5470-f579b982.png",
        ],
        "Age:80-90.\nGender:F.\nIndication: ___-year-old female with history of chest pain.\nComparison: None.",
        "How would you characterize the findings from ?",
        "cxr",
        "report generation",
    ],
    [
        ["./demo_ex/1de015eb-891f1b02-f90be378-d6af1e86-df3270c2.png"],
        "Age:40-50.\nGender:M.\nIndication: ___-year-old male with shortness of breath.\nComparison: None.",
        "How would you characterize the findings from ?",
        "cxr",
        "report generation",
    ],
    [
        ["./demo_ex/bc25fa99-0d3766cc-7704edb7-5c7a4a63-dc65480a.png"],
        "Age:40-50.\nGender:F.\nIndication: History: ___F with tachyacrdia cough doe // infilatrate\nComparison: None.",
        "How would you characterize the findings from ?",
        "cxr",
        "report generation",
    ],
    [
        ["./demo_ex/ISIC_0032258.jpg"],
        "Age:70.\nGender:female.\nLocation:back.",
        "What is primary diagnosis?",
        "derm",
        "classification",
    ],
    [
        ["./demo_ex/ISIC_0032258.jpg"],
        "Age:70.\nGender:female.\nLocation:back.",
        "Segment the lesion.",
        "derm",
        "segmentation",
    ],
    [
        ["./demo_ex/Case_01013_0000.nii.gz"],
        "",
        "Segment the liver.",
        "ct",
        "segmentation",
    ],
    [
        ["./demo_ex/Case_00840_0000.nii.gz"],
        "",
        "Segment the liver.",
        "ct",
        "segmentation",
    ],
]

# --- Define hyperparams ---
num_beams = 1
do_sample = True
min_length = 1
top_p = 0.9
repetition_penalty = 1
length_penalty = 1
temperature = 0.1


def run_example(index):
    """Run ``generate_predictions`` on ``examples[index]``.

    Uses the module-level ``model``, ``device`` and generation hyperparameters.
    Returns whatever ``generate_predictions`` returns; this script unpacks it
    as ``(seg_mask_2d, seg_mask_3d, output_text)``.
    """
    images, context, prompt, modality, task = examples[index]
    return generate_predictions(
        model, images, context, prompt, modality, task,
        num_beams, do_sample, min_length, top_p,
        repetition_penalty, length_penalty, temperature, device,
    )


# --- Generate a report for a chest X-ray image ---
seg_mask_2d, seg_mask_3d, output_text = run_example(0)
print(output_text)

# --- Segment the lesion in the dermatology image ---
seg_mask_2d, seg_mask_3d, output_text = run_example(6)
print(output_text)
print(seg_mask_2d[0].shape)  # H, W

# --- Segment the liver in the abdomen CT scan ---
# Index -2 selects the Case_01013 CT example.
seg_mask_2d, seg_mask_3d, output_text = run_example(-2)
print(output_text)
print(len(seg_mask_3d))  # Number of slices
print(seg_mask_3d[0].shape)  # H, W