import gradio as gr
import os
from omegaconf import OmegaConf
from imagen_pytorch import Unet3D, ElucidatedImagen, ImagenTrainer, Imagen
import torch
import numpy as np
import cv2
from PIL import Image
import torchvision.transforms as T
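
# Gradio demo for "Feature-Conditioned Cascaded Video Diffusion Models for
# Precise Echocardiogram Synthesis": sample a conditioning frame from a pool of
# echocardiogram images, pick a target LVEF, and generate a 4-chamber
# ultrasound video with the 1SCM cascaded diffusion model.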
device = "cuda" if torch.cuda.is_available() else "cpu"
exp_path = "model"
class BetterCenterCrop(T.CenterCrop):
    def __call__(self, img):
        h = img.shape[-2]
        w = img.shape[-1]
        dim = min(h, w)
        return T.functional.center_crop(img, dim)
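
# Note: unlike torchvision's CenterCrop, BetterCenterCrop ignores the size
# passed to its constructor and always crops to the largest centered square,
# so the subsequent Resize((112, 112)) produces an undistorted square frame.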
class ImageLoader:
    def __init__(self, path) -> None:
        self.path = path
        self.all_files = os.listdir(path)
        self.transform = T.Compose([
            T.ToTensor(),
            BetterCenterCrop((112, 112)),
            T.Resize((112, 112)),
        ])

    def get_image(self):
        idx = np.random.randint(0, len(self.all_files))
        img = Image.open(os.path.join(self.path, self.all_files[idx]))
        return img
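
# The "echo_images" folder holds the pool of conditioning frames (per the demo
# text below, 100 frames taken from the EchoNet-Dynamic dataset); get_image()
# samples one of them uniformly at random.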
class Context:
    def __init__(self, path, device):
        self.path = path
        self.config_path = os.path.join(path, "config.yaml")
        self.weight_path = os.path.join(path, "merged.pt")

        self.config = OmegaConf.load(self.config_path)
        self.config.dataset.num_frames = int(self.config.dataset.fps * self.config.dataset.duration)

        self.im_load = ImageLoader("echo_images")

        # Build the cascade: every U-Net after the first is conditioned on the
        # low-resolution output of the previous stage.
        unets = []
        for i, (k, v) in enumerate(self.config.unets.items()):
            unets.append(Unet3D(**v, lowres_cond=(i > 0)))  # type: ignore

        imagen_klass = ElucidatedImagen if self.config.imagen.elucidated else Imagen
        del self.config.imagen.elucidated  # not a constructor argument

        imagen = imagen_klass(
            unets=unets,
            **OmegaConf.to_container(self.config.imagen),  # type: ignore
        )

        self.trainer = ImagenTrainer(
            imagen=imagen,
            **self.config.trainer,
        ).to(device)

        print("Loading weights from", self.weight_path)
        self.trainer.load(self.weight_path)
        print("Loaded weights from", self.weight_path)
    def reshape_image(self, image):
        try:
            image = self.im_load.transform(image).multiply(255).byte().permute(1, 2, 0).numpy()
            return image
        except Exception:
            # The change event also fires when the image is cleared; fail quietly.
            return None
    def load_random_image(self):
        print("Loading random image")
        image = self.im_load.get_image()
        return image
    def generate_video(self, image, lvef, cond_scale):
        print("Generating video")
        print(f"lvef: {lvef}, cond_scale: {cond_scale}")

        image = self.im_load.transform(image).unsqueeze(0)

        # The target LVEF is rescaled to [0, 1] and passed as a single "text"
        # embedding; the conditioning image provides the anatomy.
        sample_kwargs = {
            "text_embeds": torch.tensor([[[lvef / 100.0]]]),
            "cond_scale": cond_scale,
            "cond_images": image,
        }

        self.trainer.eval()
        with torch.no_grad():
            video = self.trainer.sample(
                batch_size=1,
                video_frames=self.config.dataset.num_frames,
                **sample_kwargs,
                use_tqdm=True,
            ).detach().cpu()  # B x C x F x H x W

        if video.shape[-3:] != (64, 112, 112):
            video = torch.nn.functional.interpolate(video, size=(64, 112, 112), mode='trilinear', align_corners=False)

        video = video.repeat((1, 1, 5, 1, 1))  # loop the clip 5 times along the frame axis - easier to see

        os.makedirs("tmp", exist_ok=True)  # make sure the output folder exists
        uid = np.random.randint(0, 10)  # reduce the chance of overwriting when several users run the app at once
        path = f"tmp/{uid}.mp4"

        # F x H x W x C; cv2 expects BGR frames, but no RGB->BGR conversion is done here
        video = video.multiply(255).byte().squeeze(0).permute(1, 2, 3, 0).numpy()
        out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), 32, (112, 112))
        for frame in video:
            out.write(frame)
        out.release()

        return path
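
# Illustrative usage outside Gradio (a minimal sketch, assuming the "model"
# directory and the "echo_images" pool described above are present):
#
#   ctx = Context("model", device)
#   frame = ctx.load_random_image()
#   clip_path = ctx.generate_video(frame, lvef=60, cond_scale=1)  # path to an .mp4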
context = Context(exp_path, device)
with gr.Blocks(css="style.css") as demo:
    with gr.Row():
        gr.Label("Feature-Conditioned Cascaded Video Diffusion Models for Precise Echocardiogram Synthesis")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                with gr.Column(scale=3, variant="panel"):
                    text = gr.Markdown(value="This is a live demo of our work on cardiac ultrasound video generation. The model is trained on 4-chamber cardiac ultrasound videos and can generate realistic 4-chamber videos given a target Left Ventricle Ejection Fraction (LVEF). Please start by sampling a random frame from the pool of 100 images taken from the EchoNet-Dynamic dataset; it will act as the conditioning image and defines the anatomy of the video. Then set the target LVEF and click the button to generate a video. The process takes 30s to 60s. The model running here corresponds to the 1SCM from the paper. **Click on the video to play it.** [Code is available here](https://github.com/HReynaud/EchoDiffusion)")
                with gr.Column(scale=1, min_width=226):
                    image = gr.Image(interactive=True)
                with gr.Column(scale=1, min_width=226):
                    video = gr.Video(interactive=False)

            slider_ef = gr.Slider(minimum=10, maximum=90, step=1, label="Target LVEF", value=60, interactive=True)
            slider_cond = gr.Slider(minimum=0, maximum=20, step=1, label="Conditioning scale (values above 1 increase generation time to ~60s)", value=1, interactive=True)
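            # cond_scale is the classifier-free guidance scale forwarded to
            # trainer.sample(); values above 1 trade extra sampling cost for
            # stronger adherence to the LVEF conditioning.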
            with gr.Row():
                img_btn = gr.Button(value="❶ Get a random cardiac ultrasound image (4Ch)")
                run_btn = gr.Button(value="❷ Generate a video (~30s) 🚀")

    image.change(context.reshape_image, inputs=[image], outputs=[image])
    img_btn.click(context.load_random_image, inputs=[], outputs=[image])
    run_btn.click(context.generate_video, inputs=[image, slider_ef, slider_cond], outputs=[video])
if __name__ == "__main__":
    demo.queue()
    demo.launch()