File size: 3,285 Bytes
f161c8a
 
5a72667
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d9018e
5a72667
f80665b
5a72667
 
 
 
 
 
 
01e170d
 
 
 
 
 
 
 
 
 
 
5a72667
 
 
 
1d9018e
5a72667
1d9018e
5a72667
1d9018e
5a72667
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# arXiv paper: https://arxiv.org/abs/2404.13000

import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import skimage
import monai as mn
import torch

from io_utils import LoadImageD

# Gradio helper functions

# Module-level state shared with the Gradio callbacks.
# `current_img` is overwritten by `rotate_btn_fn`, but only on its bone-colormap path.
current_img = None
# NOTE(review): `live_preds` is not referenced anywhere in the visible part of this file —
# presumably a leftover from a live-inference variant; confirm before removing.
live_preds = None

def rotate_btn_fn(img_path, xt, yt, zt, add_bone_cmap=False):
    """Look up the precomputed rotated image for an example and a set of angles.

    Nothing is computed here: the result is read from
    ``data/cached_outputs/<example-name>_<(xt, yt, zt)>.png``.

    Args:
        img_path: File path of the selected input image, or ``None``/empty
            when no image has been selected yet.
        xt: Rotation angle around the x axis, as used in the cache file name.
        yt: Rotation angle around the y axis, as used in the cache file name.
        zt: Rotation angle around the z axis, as used in the cache file name.
        add_bone_cmap: When true, map the image through matplotlib's ``bone``
            colormap and return it as a ``uint8`` RGB array.

    Returns:
        The image read from the cache (optionally colormapped), or ``None``
        when ``img_path`` is missing.

    Any error raised while reading the cached file (for example when no
    cached output exists for the requested combination) is propagated.
    """
    global current_img

    # The Rotate button can be pressed before an example is selected; in that
    # case there is no path to derive a cache key from, so return no image
    # instead of failing inside os.path.basename().
    if not img_path:
        return None

    angles = (xt, yt, zt)
    # The cache key is the input file name minus its last four characters
    # (presumably a three-letter extension such as ".png" — TODO confirm),
    # followed by the tuple repr of the angles. This must stay in sync with
    # however the cached files were named, so it is kept exactly as-is.
    out_img_path = f'data/cached_outputs/{os.path.basename(img_path)[:-4]}_{angles}.png'
    out_img = skimage.io.imread(out_img_path)
    if not add_bone_cmap:
        return out_img

    # Applying the colormap adds a trailing channel axis; keep the first three
    # channels, scale by 255 and convert to uint8.
    cmap = plt.get_cmap('bone')
    out_img = cmap(out_img)
    out_img = (out_img[..., :3] * 255).astype(np.uint8)
    # Only the colormapped result is remembered in the module-level global.
    current_img = out_img
    return out_img
    
# ---------------------------------------------------------------------------
# UI definition: a single "Demo" tab with an input/output image pair, a list
# of example radiographs, three angle sliders and a Rotate button that is
# wired to `rotate_btn_fn`.
# ---------------------------------------------------------------------------
css_style = "./style.css"
# NOTE(review): `callback` is not used in the visible part of this file.
callback = gr.CSVLogger()
with gr.Blocks(css=css_style, title="RadRotator") as app:
    gr.HTML("RadRotator: 3D Rotation of Radiographs with Diffusion Models", elem_classes="title")
    gr.HTML("Developed by:<br>Pouria Rouzrokh, Bardia Khosravi, Shahriar Faghani, Kellen Mulford, Michael J. Taunton, Bradley J. Erickson, Cody C. Wyles<br><a href='https://pouriarouzrokh.github.io/RadRotator'>[Our website]</a>, <a href='https://arxiv.org/abs/2404.13000'>[arXiv Paper]</a>", elem_classes="note")
    gr.HTML("Note: The demo operates on a CPU, and since diffusion models require more computational capacity to function, all predictions are precomputed.", elem_classes="note")
    
    with gr.TabItem("Demo"):
        with gr.Row():
            # Both images are non-interactive: the input is only ever filled
            # in from the examples list below.
            input_img = gr.Image(type='filepath', label='Input image', sources='upload', interactive=False, elem_classes='imgs')
            output_img = gr.Image(type='pil', label='Output image', interactive=False, elem_classes='imgs')
        with gr.Row():
            # Empty side columns centre the examples. Gradio expects integer
            # `scale` values, so the former 0.25 : 1 : 0.25 split is written
            # as the equivalent 1 : 4 : 1.
            with gr.Column(scale=1):
                pass
            with gr.Column(scale=4):
                gr.Examples(
                    # Only files whose name contains "xr" are offered; sorted
                    # so the order does not depend on the file system.
                    examples = [os.path.join("./data/examples", f) for f in sorted(os.listdir("./data/examples")) if "xr" in f], 
                    inputs = [input_img],
                    label = "Xray Examples",
                    elem_id='examples',
                )
            with gr.Column(scale=1):
                pass
        with gr.Row():
            gr.Markdown('Please select an example image, choose your rotation angles, and press Rotate!', elem_classes='text')
        with gr.Row():
            # Each slider covers -15..15 in steps of 5; the selected values are
            # passed to `rotate_btn_fn` and become part of the cache file name.
            with gr.Column(scale=1):
                xt = gr.Slider(label='x axis (medial/lateral rotation):', elem_classes='angle', value=0, minimum=-15, maximum=15, step=5)
            with gr.Column(scale=1):
                yt = gr.Slider(label='y axis (inlet/outlet rotation):', elem_classes='angle', value=0, minimum=-15, maximum=15, step=5)
            with gr.Column(scale=1):
                zt = gr.Slider(label='z axis (plane rotation):', elem_classes='angle', value=0, minimum=-15, maximum=15, step=5)
        with gr.Row():
            rotate_btn = gr.Button("Rotate!", elem_classes='rotate_button')
        # `add_bone_cmap` is not wired to any input, so the callback always
        # runs with its default (False).
        rotate_btn.click(fn=rotate_btn_fn, inputs=[input_img, xt, yt, zt], outputs=output_img)
        
# Best-effort cleanup of any Gradio server that is still running before the
# app is launched. Failures here are deliberately ignored, but only ordinary
# exceptions: KeyboardInterrupt / SystemExit must still be able to stop the
# script, which a bare `except:` would prevent.
try:
    app.close()
    gr.close_all()
except Exception:
    pass

# Start the server; whatever `launch` returns is kept in `demo`.
demo = app.launch(
    max_threads=4,
    share=True,
    inline=False,
    show_api=False,
    show_error=False,
)