File size: 4,088 Bytes
e6d2f8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import sys
import os
import glob
import shutil
import torch
import argparse
import mediapy
import cv2
import numpy as np
import gradio as gr 
from skimage import color, img_as_ubyte
from monai import transforms, data

# Fetch the SwinUNETR research code (model class + sample volumes) unless a
# previous run already cloned it, then make its BTCV package importable.
import subprocess

_PMRC_DIR = "pmrc"
if not os.path.isdir(_PMRC_DIR):
    # argv list (no shell string); check=True surfaces a failed clone here
    # instead of as a confusing ImportError on the next statement.
    subprocess.run(
        [
            "git",
            "clone",
            "https://github.com/darraghdog/Project-MONAI-research-contributions",
            _PMRC_DIR,
        ],
        check=True,
    )
sys.path.append(os.path.join(_PMRC_DIR, "SwinUNETR", "BTCV"))
from swinunetr import SwinUnetrModelForInference, SwinUnetrConfig


# Point mediapy at the ffmpeg binary found on PATH. If none is found, leave
# mediapy's own default lookup in place rather than passing it None.
ffmpeg_path = shutil.which('ffmpeg')
if ffmpeg_path is not None:
    mediapy.set_ffmpeg(ffmpeg_path)

# Load the pretrained tiny SwinUNETR checkpoint and switch it to eval mode.
model = SwinUnetrModelForInference.from_pretrained('darragh/swinunetr-btcv-tiny')
model.eval()

# Sample volumes shipped with the cloned repo, keyed by bare file name (the
# dropdown shows the key; the value is the path handed to the loader).
# os.path.basename is used instead of split('/') so OS-specific separators
# returned by glob are handled too.
input_files = {
    os.path.basename(path): path
    for path in glob.glob('pmrc/SwinUNETR/BTCV/dataset/imagesSampleTs/*.nii.gz')
}

# MONAI preprocessing pipeline applied to {'image': path} records:
#   1. load the volume from disk,
#   2. add a channel dimension,
#   3. resample to (1.5, 1.5, 2.0) voxel spacing with bilinear interpolation,
#   4. linearly map intensities in [-175, 250] onto [0, 1], clipping the rest,
#   5. convert to a tensor.
test_transform = transforms.Compose([
    transforms.LoadImaged(keys=["image"]),
    transforms.AddChanneld(keys=["image"]),
    transforms.Spacingd(keys="image", pixdim=(1.5, 1.5, 2.0), mode="bilinear"),
    transforms.ScaleIntensityRanged(
        keys=["image"],
        a_min=-175.0,
        a_max=250.0,
        b_min=0.0,
        b_max=1.0,
        clip=True,
    ),
    transforms.ToTensord(keys=["image"]),
])

# Create Data Loader
def create_dl(test_files):
    """Preprocess ``test_files`` and serve them through a DataLoader.

    test_files: list of {'image': path} records. The whole list is passed
        through the module-level ``test_transform`` up front (presumably MONAI
        maps the transform over each record -- TODO confirm).

    Returns a DataLoader yielding one record per batch, in input order.
    """
    preprocessed = test_transform(test_files)
    return data.DataLoader(preprocessed, shuffle=False, batch_size=1)

# Inference and video generation
def _overlay_frames(inp, outp):
    """Render one side-by-side RGB frame per slice of a single batch item.

    inp: image tensor indexed as ``inp[0, :, :, idx]`` (channel 0, slice idx);
        values are scaled by 255 and cast to uint8, so they are expected to be
        in [0, 1].
    outp: integer label map indexed as ``outp[:, :, idx]``.

    Returns a list of uint8 frames. Each frame has the grayscale slice (as RGB)
    on the left and the same slice with the labels overlaid on the right.
    """
    frames = []
    for idx in range(inp.shape[-1]):
        seg = outp[:, :, idx].cpu().numpy().astype(np.uint8)
        img = (inp[0, :, :, idx] * 255).cpu().numpy().astype(np.uint8)
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        # label 0 is treated as background and left uncoloured.
        overlay = img_as_ubyte(color.label2rgb(seg, img, bg_label=0))
        frames.append(np.concatenate((img, overlay), 1))
    return frames


def generate_dicom_video(selected_file, n_frames):
    """Segment the chosen sample volume and render the result as an MP4.

    selected_file: key of ``input_files`` (a sample file name from the dropdown).
    n_frames: number of slices to process, taken from the end of the last axis.

    Returns the path of the written video (4 fps).
    """
    import tempfile

    # The Slider value may arrive as a float, and slice bounds must be ints.
    # Clamp to >= 1 because ``-0:`` would silently select every slice.
    n_frames = max(1, int(n_frames))

    # Load and preprocess the selected volume.
    test_files = [{'image': input_files[selected_file]}]
    batch = next(iter(create_dl(test_files)))

    # Keep only the last n_frames slices along the final axis.
    tst_inputs = batch["image"][:, :, :, :, -n_frames:]

    # Inference. The positional args are passed straight to the model;
    # presumably ROI size and sliding-window batch size -- TODO confirm.
    with torch.no_grad():
        outputs = model(tst_inputs, (96, 96, 96), 8, overlap=0.5, mode="gaussian")
        # softmax is monotonic, so argmax over the raw logits gives the same labels.
        tst_outputs = torch.argmax(outputs.logits, dim=1)

    # A fresh directory per call, so concurrent requests never overwrite each
    # other's video file.
    video_path = os.path.join(tempfile.mkdtemp(), 'dicom.mp4')
    for inp, outp in zip(tst_inputs, tst_outputs):
        mediapy.write_video(video_path, _overlay_frames(inp, outp), fps=4)

    return video_path


theme = 'dark-peach'
with gr.Blocks(theme=theme) as demo:

    gr.Markdown('''<center><h1>SwinUnetr BTCV</h1></center>
	This is a Gradio Blocks app of the winning transformer in the Beyond the Cranial Vault (BTCV) Segmentation Challenge, <a href="https://github.com/darraghdog/Project-MONAI-research-contributions/tree/main/SwinUNETR/BTCV">SwinUnetr</a> (tiny version).
	''')
    # gr.Dropdown replaces the deprecated gr.inputs.Dropdown namespace.
    # Preselect the first sample (same order as the choices) so clicking
    # "Generate Video" without a selection does not look up input_files[None].
    selected_dicom_key = gr.Dropdown(
        choices=sorted(input_files),
        value=min(input_files, default=None),
        type="value",
        label="Select a dicom file",
    )
    n_frames = gr.Slider(1, 100, value=32, label="Choose the number of dicom slices to process", step=1)
    button_gen_video = gr.Button("Generate Video")
    output_interpolation = gr.Video(label="Generated Video")
    button_gen_video.click(
        fn=generate_dicom_video,
        inputs=[selected_dicom_key, n_frames],
        outputs=output_interpolation,
    )

demo.launch(debug=True, enable_queue=True)