import random import shutil import uuid from pathlib import Path import cv2 import gradio as gr import mediapy import mlflow.pytorch import numpy as np import torch from skimage import img_as_ubyte from models.ddim import DDIMSampler import nibabel as nib ffmpeg_path = shutil.which("ffmpeg") mediapy.set_ffmpeg(ffmpeg_path) # Loading model device = torch.device("cpu") vqvae = mlflow.pytorch.load_model( "./trained_models/vae/", map_location=device, ) vqvae.eval() diffusion = mlflow.pytorch.load_model( "./trained_models/ddpm/", map_location=device, ) diffusion.eval() diffusion = diffusion.to(device) vqvae = vqvae.to(device) def sample_fn( gender_radio, age_slider, ventricular_slider, brain_slider, ): print("Sampling brain!") print(f"Gender: {gender_radio}") print(f"Age: {age_slider}") print(f"Ventricular volume: {ventricular_slider}") print(f"Brain volume: {brain_slider}") age_slider = (age_slider - 44) / (82 - 44) cond = torch.Tensor([[gender_radio, age_slider, ventricular_slider, brain_slider]]) latent_shape = [1, 3, 20, 28, 20] cond_crossatten = cond.unsqueeze(1) cond_concat = cond.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) cond_concat = cond_concat.expand(list(cond.shape[0:2]) + list(latent_shape[2:])) conditioning = { "c_concat": [cond_concat.float().to(device)], "c_crossattn": [cond_crossatten.float().to(device)], } ddim = DDIMSampler(diffusion) num_timesteps = 50 latent_vectors, _ = ddim.sample( num_timesteps, conditioning=conditioning, batch_size=1, shape=list(latent_shape[1:]), eta=1.0, ) with torch.no_grad(): x_hat = vqvae.reconstruct_ldm_outputs(latent_vectors).cpu() return x_hat.numpy() def create_videos_and_file( gender_radio, age_slider, ventricular_slider, brain_slider, ): output_dir = Path( f"./outputs/{str(uuid.uuid4())}" ) output_dir.mkdir(exist_ok=True) image_data = sample_fn( gender_radio, age_slider, ventricular_slider, brain_slider, ) image_data = image_data[0, 0, 5:-5, 5:-5, :-15] image_data = (image_data - image_data.min()) / (image_data.max() - image_data.min()) image_data = (image_data * 255).astype(np.uint8) # Write frames to video with mediapy.VideoWriter( f"{str(output_dir)}/brain_axial.mp4", shape=(150, 214), fps=12, crf=18 ) as w: for idx in range(image_data.shape[2]): img = image_data[:, :, idx] img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) frame = img_as_ubyte(img) w.add_image(frame) with mediapy.VideoWriter( f"{str(output_dir)}/brain_sagittal.mp4", shape=(145, 214), fps=12, crf=18 ) as w: for idx in range(image_data.shape[0]): img = np.rot90(image_data[idx, :, :]) img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) frame = img_as_ubyte(img) w.add_image(frame) with mediapy.VideoWriter( f"{str(output_dir)}/brain_coronal.mp4", shape=(145, 150), fps=12, crf=18 ) as w: for idx in range(image_data.shape[1]): img = np.rot90(np.flip(image_data, axis=1)[:, idx, :]) img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) frame = img_as_ubyte(img) w.add_image(frame) # Create file affine = np.array( [ [-1.0, 0.0, 0.0, 96.48149872], [0.0, 1.0, 0.0, -141.47715759], [0.0, 0.0, 1.0, -156.55375671], [0.0, 0.0, 0.0, 1.0], ] ) empty_header = nib.Nifti1Header() sample_nii = nib.Nifti1Image(image_data, affine, empty_header) nib.save(sample_nii, f"{str(output_dir)}/my_brain.nii.gz") # time.sleep(2) return ( f"{str(output_dir)}/brain_axial.mp4", f"{str(output_dir)}/brain_sagittal.mp4", f"{str(output_dir)}/brain_coronal.mp4", f"{str(output_dir)}/my_brain.nii.gz", ) def randomise(): random_age = round(random.uniform(44.0, 82.0), 2) return ( random.choice(["Female", "Male"]), random_age, round(random.uniform(0, 1.0), 2), round(random.uniform(0, 1.0), 2), ) def unrest_randomise(): random_age = round(random.uniform(18.0, 100.0), 2) return ( random.choice([1, 0]), random_age, round(random.uniform(-1.0, 2.0), 2), round(random.uniform(-1.0, 2.0), 2), ) # TEXT title = "Generating Brain Imaging with Diffusion Models" description = """