# Source snapshot: commit 785096a ("Update interface") by callum-canavan.
import argparse
from pathlib import Path
import gradio as gr
from icecream import ic
import torch
from diffusers import DiffusionPipeline
from visual_anagrams.views import get_views, VIEW_MAP_NAMES
from visual_anagrams.samplers import sample_stage_1, sample_stage_2
from visual_anagrams.utils import add_args, save_illusion, save_metadata
from visual_anagrams.animate import animate_two_view
# Both DeepFloyd IF stages are loaded once, at import time, in float16 and are
# shared by every request. Stage II is built without a text encoder because
# generate_content encodes prompts with stage I only and hands the same
# embeddings to stage II.
stage_1 = DiffusionPipeline.from_pretrained(
    "DeepFloyd/IF-I-M-v1.0",
    torch_dtype=torch.float16,
    variant="fp16",
)
stage_2 = DiffusionPipeline.from_pretrained(
    "DeepFloyd/IF-II-M-v1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    text_encoder=None,
)

# diffusers model CPU offloading: sub-models stay on the CPU and are moved to
# the accelerator only while they run, trading speed for lower GPU memory use.
stage_1.enable_model_cpu_offload()
stage_2.enable_model_cpu_offload()
def generate_content(
    style,
    prompt_for_original,
    prompt_for_transformed,
    transformation,
    num_inference_steps,
    seed,
):
    """Generate a two-view illusion image and an animation of it.

    Both prompts are prefixed with ``style``. The image is sampled with the
    module-level ``stage_1`` / ``stage_2`` pipelines using two views: the
    identity view and the view selected by ``transformation``.

    Args:
        style: Text prepended to both prompts (e.g. "an oil painting of");
            may be empty, in which case the prompts are used unchanged.
        prompt_for_original: Prompt paired with the identity view.
        prompt_for_transformed: Prompt paired with the transformed view.
        transformation: Key of ``VIEW_MAP_NAMES`` selecting the second view.
        num_inference_steps: Diffusion steps used by both stages. Floats
            such as ``50.0`` are accepted and truncated to ``int``.
        seed: Random seed. ``int(seed) + 42`` seeds torch's global RNG.

    Returns:
        ``(video_path, views_image_path)``: the ``"illusion.mp4"`` animation
        and ``"sample_{size}.views.png"`` (presumably written by
        ``save_illusion``), both relative to the current working directory.
    """
    # gr.Number returns floats unless its precision is 0. A step count and an
    # RNG seed are integers, so coerce explicitly rather than relying on the
    # downstream helpers to tolerate values like 50.0.
    num_inference_steps = int(num_inference_steps)
    seed = int(seed)

    # .strip() drops the leading space when no style is given.
    prompts = [f'{style} {p}'.strip() for p in [prompt_for_original, prompt_for_transformed]]

    # Stage II was loaded without a text encoder, so both stages share the
    # embeddings produced by stage I. encode_prompt yields a
    # (positive, negative) pair per prompt; stack each kind across prompts.
    encoded = [stage_1.encode_prompt(p) for p in prompts]
    prompt_embeds, negative_prompt_embeds = zip(*encoded)
    prompt_embeds = torch.cat(prompt_embeds)
    negative_prompt_embeds = torch.cat(negative_prompt_embeds)

    # View 0 is the untouched image; view 1 is the user-selected transform.
    views = get_views(['identity', VIEW_MAP_NAMES[transformation]])

    generator = torch.manual_seed(seed + 42)

    print("Sample stage 1")
    image = sample_stage_1(
        stage_1,
        prompt_embeds,
        negative_prompt_embeds,
        views,
        num_inference_steps=num_inference_steps,
        generator=generator,
    )
    print("Sample stage 2")
    image = sample_stage_2(
        stage_2,
        image,
        prompt_embeds,
        negative_prompt_embeds,
        views,
        num_inference_steps=num_inference_steps,
        generator=generator,
    )

    # Outputs go to the current working directory (Path("")). File names embed
    # the image's last dimension, presumably its width in pixels.
    save_illusion(image, views, Path(""))
    size = image.shape[-1]
    output_name = "illusion.mp4"
    animate_two_view(
        f"sample_{size}.png",
        views[1],
        prompts[0],
        prompts[1],
        save_video_path=output_name,
    )
    return output_name, f"sample_{size}.views.png"
# Markdown/text shown under the app title. Read as UTF-8 explicitly: without
# an encoding, open() uses the platform's locale encoding, which can fail or
# mis-decode non-ASCII characters on some systems.
with open("description.txt", encoding="utf-8") as f:
    description = f.read()
# Dropdown options: the keys of VIEW_MAP_NAMES, which generate_content maps
# to view names.
choices = list(VIEW_MAP_NAMES.keys())

# Input values are passed to generate_content positionally, so this list must
# stay in the same order as its parameters: style, original prompt,
# transformed prompt, transformation, number of steps, seed.
gradio_app = gr.Interface(
    fn=generate_content,
    title="Multi-View Illusion Diffusion",
    inputs=[
        gr.Textbox(label="Style", placeholder="an oil painting of"),
        gr.Textbox(label="Prompt for original view", placeholder="a dress"),
        gr.Textbox(label="Prompt for transformed view", placeholder="an old man"),
        gr.Dropdown(label="View transformation", choices=choices, value=choices[0]),
        # precision=0 makes Gradio round the value and deliver an int rather
        # than a float.
        gr.Number(
            label="Number of diffusion steps",
            value=50,
            step=1,
            minimum=1,
            maximum=300,
            precision=0,
        ),
        gr.Number(
            label="Random seed",
            value=0,
            step=1,
            minimum=0,
            maximum=100000,
            precision=0,
        ),
    ],
    # Matches generate_content's return value: (video path, image path).
    outputs=[gr.Video(label="Illusion"), gr.Image(label="Before and After")],
    description=description,
)
if __name__ == "__main__":
    # Start the Gradio server for the app defined above. To accept
    # connections from other machines, call launch(server_name="0.0.0.0"),
    # which listens on all network interfaces.
    gradio_app.launch()