# Spaces:
# Runtime error
# Runtime error
import os | |
import gradio as gr | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import skimage | |
import monai as mn | |
import torch | |
from io_utils import LoadImageD | |
# Gradio helper functions
# Module-level state shared with the handlers below; both names start as None.
# current_img is (re)assigned by rotate_btn_fn via a `global` statement.
# NOTE(review): live_preds is not referenced in the code visible here —
# possibly a leftover; confirm before relying on it.
current_img = live_preds = None
def rotate_btn_fn(img_path, xt, yt, zt, add_bone_cmap=False):
    """Return the precomputed rotated output for the selected radiograph.

    No model runs here. The result is read from
    ``data/cached_outputs/<stem>_<angles>.png``, where ``<stem>`` is the input
    file name without its extension and ``<angles>`` is the ``str`` of the
    ``(xt, yt, zt)`` tuple exactly as received from the sliders.

    Args:
        img_path: Path of the selected input image, or ``None`` when the user
            has not picked an example yet.
        xt, yt, zt: Rotation angles about the x, y and z axes.
        add_bone_cmap: If True, map the image through matplotlib's ``bone``
            colormap and return an RGB ``uint8`` array.

    Returns:
        The cached image array, optionally colormapped. The same array is
        also stored in the module-level ``current_img``.

    Raises:
        gr.Error: If no image is selected, or if no precomputed output exists
            for the requested angle combination.
    """
    global current_img
    if img_path is None:
        raise gr.Error('Please select an example image before pressing Rotate!')

    angles = (xt, yt, zt)
    # splitext handles any extension length; the previous [:-4] slice only
    # worked for three-character extensions such as ".png".
    stem = os.path.splitext(os.path.basename(img_path))[0]
    # The tuple's str() (e.g. "(0, 5, -10)") is embedded in the cache file
    # name, so the angle values must be formatted exactly as when the cache
    # was generated.
    out_img_path = f'data/cached_outputs/{stem}_{angles}.png'
    if not os.path.isfile(out_img_path):
        raise gr.Error(f'No precomputed output is available for rotation angles {angles}.')
    out_img = skimage.io.imread(out_img_path)

    if add_bone_cmap:
        # NOTE(review): presumes the cached PNGs are single-channel. If imread
        # returns integers, matplotlib treats them as lookup-table indices;
        # drop alpha and rescale [0, 1] floats to uint8 RGB.
        cmap = plt.get_cmap('bone')
        out_img = (cmap(out_img)[..., :3] * 255).astype(np.uint8)

    # Track the latest output regardless of the colormap option.
    current_img = out_img
    return out_img
# Build the demo UI. It contains:
#   - header text,
#   - an input/output image pair,
#   - an example picker,
#   - three rotation-angle sliders,
#   - a Rotate button wired to rotate_btn_fn.
css_style = "./style.css"
# NOTE(review): `callback` is not referenced in the code visible here; kept in
# case other tooling expects it.
callback = gr.CSVLogger()
with gr.Blocks(css=css_style) as app:
    gr.HTML("RadRotator: 3D Rotation of Radiographs with Diffusion Models", elem_classes="title")
    gr.HTML("Developed by:<br>Pouria Rouzrokh, Bardia Khosravi, Shahriar Faghani, Kellen Mulford, Michael J. Taunton, Bradley J. Erickson, Cody C. Wyles", elem_classes="subtitle")
    gr.HTML("Note: The demo operates on a CPU, and since diffusion models require more computational capacity to function, all predictions are precomputed.", elem_classes="note")
    with gr.TabItem("Demo"):
        with gr.Row():
            # The input is not user-editable; it is filled by clicking an example.
            input_img = gr.Image(type='filepath', label='Input image', sources='upload', interactive=False, elem_classes='imgs')
            output_img = gr.Image(type='pil', label='Output image', interactive=False, elem_classes='imgs')
        with gr.Row():
            # Empty side columns centre the examples gallery. Gradio expects
            # integer scales, so 1:4:1 replaces the equivalent 0.25:1:0.25.
            with gr.Column(scale=1):
                pass
            with gr.Column(scale=4):
                gr.Examples(
                    # os.listdir order is arbitrary; sort so the gallery order
                    # is stable across machines and restarts.
                    examples=sorted(
                        os.path.join("./data/examples", f)
                        for f in os.listdir("./data/examples")
                        if "xr" in f
                    ),
                    inputs=[input_img],
                    label="Xray Examples",
                    elem_id='examples',
                )
            with gr.Column(scale=1):
                pass
        with gr.Row():
            gr.Markdown('Please select an example image, choose your rotation angles, and press Rotate!', elem_classes='text')
        with gr.Row():
            # Angles are limited to -15..15 in steps of 5, presumably matching
            # the combinations that were precomputed into data/cached_outputs.
            with gr.Column(scale=1):
                xt = gr.Slider(label='Rotation angle in x axis:', elem_classes='angle', value=0, minimum=-15, maximum=15, step=5)
            with gr.Column(scale=1):
                yt = gr.Slider(label='Rotation angle in y axis:', elem_classes='angle', value=0, minimum=-15, maximum=15, step=5)
            with gr.Column(scale=1):
                zt = gr.Slider(label='Rotation angle in z axis:', elem_classes='angle', value=0, minimum=-15, maximum=15, step=5)
        with gr.Row():
            rotate_btn = gr.Button("Rotate!", elem_classes='rotate_button')
    # Event listeners must be registered inside the Blocks context.
    rotate_btn.click(fn=rotate_btn_fn, inputs=[input_img, xt, yt, zt], outputs=output_img)
# Best-effort shutdown of any Gradio servers that may still be running
# (presumably from an earlier run in the same interpreter) before launching.
try:
    app.close()
    gr.close_all()
except Exception:
    # Cleanup is optional, so an ordinary failure here must not block the
    # launch below. Catching Exception rather than using a bare `except:` lets
    # KeyboardInterrupt and SystemExit still propagate.
    pass
# Start the Gradio web server for the UI built above.
# NOTE(review): Blocks.launch() is documented to return a tuple
# (app, local_url, share_url), not a Blocks object; confirm nothing treats
# `demo` as the UI itself.
demo = app.launch(
    share=True,        # request a public share link (reportedly ignored on HF Spaces — confirm)
    inline=False,      # do not embed the app inline in a notebook
    show_error=False,  # do not echo unexpected exceptions to the browser
    show_api=False,    # hide the "Use via API" page
    max_threads=4,     # upper bound on worker threads serving requests
)