Spaces:
Runtime error
Runtime error
import gradio as gr | |
import torch | |
from diffusers.models import UNet2DModel | |
from huggingface_hub import hf_hub_download | |
from oadg.sampling import sample, make_conditional_paths_and_realization, initialize_empty_realizations_and_paths | |
from oadg.sampling import evaluate_entropy | |
# Demo settings, shared by the sampling callback and the UI components.
image_size = 64
batch_size = 1
device = "cpu"

# Fetch the checkpoint from the Hugging Face Hub; the returned value is
# passed to torch.load below as the location of the downloaded file.
path = hf_hub_download(filename="model.pt", repo_id="porestar/oadg_channels_64")

# UNet with two input and two output channels, built from four down blocks
# and four up blocks. Attention is used in the third down block and in the
# second up block only.
model = UNet2DModel(
    in_channels=2,
    out_channels=2,
    sample_size=64,  # same value as image_size above
    block_out_channels=(64, 64, 128, 128),
    layers_per_block=2,
    down_block_types=(
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",
        "AttnUpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
)

# The weights are always deserialised onto the CPU first and only then moved
# to `device`.
# NOTE(review): torch.load unpickles the checkpoint, so the Hub repository
# above has to be trusted.
model.load_state_dict(torch.load(path, map_location=torch.device("cpu")))
model = model.to(device)
def sample_image(img):
    """Produce an entropy map and one sampled realization for the demo.

    Args:
        img: The drawing supplied by the interface, or ``None`` when there
            is no drawing. When present it is binarised (every value
            ``> 0`` becomes 1) and used as conditioning data; it must
            therefore support ``>`` comparison and ``.astype``.

    Returns:
        A ``(entropy, realization)`` pair. Both are reshaped to
        ``(image_size, image_size)`` and multiplied by 255; the entropy is
        additionally cast to ``int``.
    """
    if img is not None:
        # Conditional case: the drawn pixels define the conditioning data.
        conditioning = (img > 0).astype(int)
        state = make_conditional_paths_and_realization(
            conditioning, batch_size=batch_size, device=device
        )
    else:
        # Unconditional case: start from an empty realization.
        state = initialize_empty_realizations_and_paths(
            batch_size, image_size, image_size, device=device
        )
    idx_start, random_paths, realization = state

    # Sampling and entropy evaluation receive exactly the same arguments
    # (and therefore the very same objects), sampling first.
    shared_kwargs = dict(
        batch_size=batch_size,
        image_size=image_size,
        realization=realization,
        idx_start=idx_start,
        random_paths=random_paths,
        device=device,
    )

    sampled = sample(model, **shared_kwargs)
    sampled = sampled.reshape(image_size, image_size) * 255

    entropy = evaluate_entropy(model, **shared_kwargs)
    entropy = (entropy.reshape(image_size, image_size) * 255).astype(int)

    return entropy, sampled
# NOTE(review): this Text component is created but never handed to the
# Interface below -- confirm whether it is still needed.
gr.Text()

# Input component: a greyscale canvas the user draws on.
img = gr.Image(
    label="Drawing Canvas",
    source="canvas",
    image_mode="L",
    shape=(image_size, image_size),
    invert_colors=True,
)

# Output components; both are greyscale images of the same size as the input.
out_realization = gr.Image(
    label="Sample Realization",
    image_mode="L",
    shape=(image_size, image_size),
    invert_colors=True,
)
out_entropy = gr.Image(
    label="Entropy of Drawn Data",
    image_mode="L",
    shape=(image_size, image_size),
    invert_colors=True,
)

# The order of `outputs` mirrors the (entropy, realization) pair returned by
# sample_image.
demo = gr.Interface(
    fn=sample_image,
    inputs=img,
    outputs=[out_entropy, out_realization],
    title="Order Agnostic Autoregressive Diffusion Channels Demo",
    description="""Sample conditional or unconditional images by drawing into the canvas.
Outputs a random sampled realization and predicted entropy under the trained model for the conditioning data.""",
)

demo.launch()