Spaces:
Build error
Build error
import os | |
import json | |
import torch | |
import random | |
import gradio as gr | |
from glob import glob | |
from omegaconf import OmegaConf | |
from datetime import datetime | |
from safetensors import safe_open | |
from diffusers import AutoencoderKL | |
from diffusers import DDIMScheduler, EulerDiscreteScheduler, PNDMScheduler | |
from diffusers.utils.import_utils import is_xformers_available | |
from transformers import CLIPTextModel, CLIPTokenizer | |
from animatediff.models.unet import UNet3DConditionModel | |
from animatediff.pipelines.pipeline_animation import AnimationPipeline | |
from animatediff.utils.util import save_videos_grid | |
from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint | |
from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora | |
# Index used in the file name of the saved sample video ("<sample_idx>.mp4").
sample_idx = 0

# Sampler name offered in the UI -> scheduler class instantiated for inference.
scheduler_dict = dict(
    Euler=EulerDiscreteScheduler,
    PNDM=PNDMScheduler,
    DDIM=DDIMScheduler,
)
# Extra stylesheet handed to gr.Blocks; sizes the small square icon buttons
# (refresh / random seed) that carry the "toolbutton" element class.
# The first declaration used to read "margin-buttom: 0em 0em 0em 0em;", which
# is not a valid CSS property/value pair and was therefore ignored.
css = """
.toolbutton {
    margin-bottom: 0em;
    max-width: 2.5em;
    min-width: 2.5em !important;
    height: 2.5em;
}
"""
class AnimateController:
    """Model state and inference entry points used by the Gradio callbacks.

    The controller scans ``models/`` under the current working directory for
    checkpoints, keeps the currently loaded tokenizer, text encoder, VAE and
    UNet, and renders animations through ``AnimationPipeline``.
    """

    def __init__(self):
        """Set up model/output directories, scan for checkpoints, load config."""
        # config dirs
        self.basedir = os.getcwd()
        self.stable_diffusion_dir = os.path.join(self.basedir, "models", "StableDiffusion")
        self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module")
        self.personalized_model_dir = os.path.join(self.basedir, "models", "DreamBooth_LoRA")
        self.savedir = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
        self.savedir_sample = os.path.join(self.savedir, "sample")
        os.makedirs(self.savedir, exist_ok=True)

        self.stable_diffusion_list = []
        self.motion_module_list = []
        self.personalized_model_list = []

        self.refresh_stable_diffusion()
        self.refresh_motion_module()
        self.refresh_personalized_model()

        # config models
        self.tokenizer = None
        self.text_encoder = None
        self.vae = None
        self.unet = None
        self.pipeline = None
        self.lora_model_state_dict = {}

        self.inference_config = OmegaConf.load("configs/inference/inference.yaml")

    def refresh_stable_diffusion(self):
        """Rescan the Stable Diffusion directory; entries are sub-directory paths."""
        self.stable_diffusion_list = glob(os.path.join(self.stable_diffusion_dir, "*/"))

    def refresh_motion_module(self):
        """Rescan the motion-module directory; entries are ``*.ckpt`` file names."""
        motion_module_list = glob(os.path.join(self.motion_module_dir, "*.ckpt"))
        self.motion_module_list = [os.path.basename(p) for p in motion_module_list]

    def refresh_personalized_model(self):
        """Rescan the DreamBooth/LoRA directory; entries are ``*.safetensors`` file names."""
        personalized_model_list = glob(os.path.join(self.personalized_model_dir, "*.safetensors"))
        self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list]

    def update_stable_diffusion(self, stable_diffusion_dropdown):
        """Load tokenizer, text encoder, VAE and UNet from the selected model path.

        The text encoder, VAE and UNet are moved to CUDA. An empty selection
        is ignored so that clearing the dropdown does not raise.
        """
        if not stable_diffusion_dropdown:
            return gr.Dropdown.update()

        self.tokenizer = CLIPTokenizer.from_pretrained(stable_diffusion_dropdown, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(stable_diffusion_dropdown, subfolder="text_encoder").cuda()
        self.vae = AutoencoderKL.from_pretrained(stable_diffusion_dropdown, subfolder="vae").cuda()
        self.unet = UNet3DConditionModel.from_pretrained_2d(
            stable_diffusion_dropdown,
            subfolder="unet",
            unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs),
        ).cuda()
        return gr.Dropdown.update()

    def update_motion_module(self, motion_module_dropdown):
        """Load the selected motion-module checkpoint into the current UNet.

        If no UNet is loaded yet, an info message is shown and the dropdown is
        reset. Raises ``gr.Error`` when the checkpoint contains keys the UNet
        does not know.
        """
        if self.unet is None:
            gr.Info("Please select a pretrained model path.")
            return gr.Dropdown.update(value=None)
        if not motion_module_dropdown:
            # Nothing selected (e.g. the dropdown was just reset): nothing to load.
            return gr.Dropdown.update()

        motion_module_path = os.path.join(self.motion_module_dir, motion_module_dropdown)
        # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
        motion_module_state_dict = torch.load(motion_module_path, map_location="cpu")
        _missing, unexpected = self.unet.load_state_dict(motion_module_state_dict, strict=False)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(unexpected) != 0:
            raise gr.Error(f"Motion module contains {len(unexpected)} unexpected keys.")
        return gr.Dropdown.update()

    def update_base_model(self, base_model_dropdown):
        """Load the selected DreamBooth checkpoint into the VAE, UNet and text encoder.

        If no UNet is loaded yet, an info message is shown and the dropdown is
        reset.
        """
        if self.unet is None:
            gr.Info("Please select a pretrained model path.")
            return gr.Dropdown.update(value=None)
        if not base_model_dropdown:
            # Nothing selected (e.g. the dropdown was just reset): nothing to load.
            return gr.Dropdown.update()

        base_model_path = os.path.join(self.personalized_model_dir, base_model_dropdown)
        base_model_state_dict = {}
        with safe_open(base_model_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                base_model_state_dict[key] = f.get_tensor(key)

        converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_model_state_dict, self.vae.config)
        self.vae.load_state_dict(converted_vae_checkpoint)

        converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_model_state_dict, self.unet.config)
        self.unet.load_state_dict(converted_unet_checkpoint, strict=False)

        # The text encoder is replaced, not updated in place; it is moved to
        # CUDA later when animate() calls pipeline.to("cuda").
        self.text_encoder = convert_ldm_clip_checkpoint(base_model_state_dict)
        return gr.Dropdown.update()

    def update_lora_model(self, lora_model_dropdown):
        """Read the selected LoRA weights into ``self.lora_model_state_dict``.

        Selecting ``"none"`` (or nothing) clears the stored LoRA weights.
        """
        self.lora_model_state_dict = {}
        # Compare against "none" BEFORE building the path; once the directory
        # is joined on, the value can never equal "none".
        if lora_model_dropdown and lora_model_dropdown != "none":
            lora_model_path = os.path.join(self.personalized_model_dir, lora_model_dropdown)
            with safe_open(lora_model_path, framework="pt", device="cpu") as f:
                for key in f.keys():
                    self.lora_model_state_dict[key] = f.get_tensor(key)
        return gr.Dropdown.update()

    @staticmethod
    def _parse_seed(seed_textbox):
        """Return the requested integer seed, or None when a random seed is wanted.

        Empty input, ``None`` and ``-1`` (as a number or as text) all mean
        "random". Non-numeric text raises ``ValueError`` from ``int()``.
        """
        if seed_textbox is None:
            return None
        if isinstance(seed_textbox, str):
            seed_textbox = seed_textbox.strip()
            if seed_textbox == "":
                return None
        seed = int(seed_textbox)
        return None if seed == -1 else seed

    def animate(
        self,
        stable_diffusion_dropdown,
        motion_module_dropdown,
        base_model_dropdown,
        lora_alpha_slider,
        prompt_textbox,
        negative_prompt_textbox,
        sampler_dropdown,
        sample_step_slider,
        width_slider,
        length_slider,
        height_slider,
        cfg_scale_slider,
        seed_textbox
    ):
        """Generate one animation, save it as mp4, log its settings, return a video update.

        Raises ``gr.Error`` when the pretrained model, motion module or base
        DreamBooth model has not been selected. The sample is written to
        ``<savedir>/sample/<sample_idx>.mp4`` and the generation settings are
        appended to ``<savedir>/logs.json``.
        """
        global sample_idx

        if self.unet is None:
            raise gr.Error("Please select a pretrained model path.")
        # Falsy check: an empty dropdown may be None, not "".
        if not motion_module_dropdown:
            raise gr.Error("Please select a motion module.")
        if not base_model_dropdown:
            raise gr.Error("Please select a base DreamBooth model.")

        if is_xformers_available():
            self.unet.enable_xformers_memory_efficient_attention()

        scheduler_kwargs = OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs)
        pipeline = AnimationPipeline(
            vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet,
            scheduler=scheduler_dict[sampler_dropdown](**scheduler_kwargs),
        ).to("cuda")

        if self.lora_model_state_dict:
            # NOTE(review): convert_lora presumably merges the LoRA weights into
            # the shared UNet/text encoder; confirm repeated runs do not stack them.
            pipeline = convert_lora(pipeline, self.lora_model_state_dict, alpha=lora_alpha_slider)
        pipeline.to("cuda")

        # The Seed textbox delivers text, so "-1" must be recognised as the
        # "random seed" marker instead of being passed to torch.manual_seed.
        requested_seed = self._parse_seed(seed_textbox)
        if requested_seed is None:
            torch.seed()
        else:
            torch.manual_seed(requested_seed)
        seed = torch.initial_seed()

        sample = pipeline(
            prompt_textbox,
            negative_prompt=negative_prompt_textbox,
            num_inference_steps=sample_step_slider,
            guidance_scale=cfg_scale_slider,
            width=width_slider,
            height=height_slider,
            video_length=length_slider,
        ).videos

        os.makedirs(self.savedir_sample, exist_ok=True)
        save_sample_path = os.path.join(self.savedir_sample, f"{sample_idx}.mp4")
        save_videos_grid(sample, save_sample_path)
        # Advance the index so the next run does not overwrite this sample.
        sample_idx += 1

        sample_config = {
            "prompt": prompt_textbox,
            "n_prompt": negative_prompt_textbox,
            "sampler": sampler_dropdown,
            "num_inference_steps": sample_step_slider,
            "guidance_scale": cfg_scale_slider,
            "width": width_slider,
            "height": height_slider,
            "video_length": length_slider,
            "seed": seed
        }
        json_str = json.dumps(sample_config, indent=4)
        # Append mode: the log accumulates one JSON object per run, separated
        # by a blank line.
        with open(os.path.join(self.savedir, "logs.json"), "a") as f:
            f.write(json_str)
            f.write("\n\n")

        return gr.Video.update(value=save_sample_path)
# Module-level controller instance shared by the UI callbacks defined in ui().
controller = AnimateController()
def ui():
    """Build the Gradio Blocks interface and wire its events to ``controller``.

    Returns the ``gr.Blocks`` app without launching it.
    """
    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            # [AnimateDiff: Animate Your Personalized Text-to-Image Diffusion Models without Specific Tuning](https://arxiv.org/abs/2307.04725)
            Yuwei Guo, Ceyuan Yang*, Anyi Rao, Yaohui Wang, Yu Qiao, Dahua Lin, Bo Dai (*Corresponding Author)<br>
            [Arxiv Report](https://arxiv.org/abs/2307.04725) | [Project Page](https://animatediff.github.io/) | [Github](https://github.com/guoyww/animatediff/)
            """
        )
        with gr.Column(variant="panel"):
            gr.Markdown(
                """
                ### 1. Model checkpoints (select pretrained model path first).
                """
            )
            with gr.Row():
                stable_diffusion_dropdown = gr.Dropdown(
                    label="Pretrained Model Path",
                    choices=controller.stable_diffusion_list,
                    interactive=True,
                )
                stable_diffusion_dropdown.change(fn=controller.update_stable_diffusion, inputs=[stable_diffusion_dropdown], outputs=[stable_diffusion_dropdown])

                stable_diffusion_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")

                def update_stable_diffusion():
                    # Rescan the directory, then push the new choices to the dropdown.
                    controller.refresh_stable_diffusion()
                    return gr.Dropdown.update(choices=controller.stable_diffusion_list)
                stable_diffusion_refresh_button.click(fn=update_stable_diffusion, inputs=[], outputs=[stable_diffusion_dropdown])

            with gr.Row():
                motion_module_dropdown = gr.Dropdown(
                    label="Select motion module",
                    choices=controller.motion_module_list,
                    interactive=True,
                )
                motion_module_dropdown.change(fn=controller.update_motion_module, inputs=[motion_module_dropdown], outputs=[motion_module_dropdown])

                motion_module_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")

                def update_motion_module():
                    controller.refresh_motion_module()
                    return gr.Dropdown.update(choices=controller.motion_module_list)
                motion_module_refresh_button.click(fn=update_motion_module, inputs=[], outputs=[motion_module_dropdown])

                base_model_dropdown = gr.Dropdown(
                    label="Select base Dreambooth model (required)",
                    choices=controller.personalized_model_list,
                    interactive=True,
                )
                base_model_dropdown.change(fn=controller.update_base_model, inputs=[base_model_dropdown], outputs=[base_model_dropdown])

                lora_model_dropdown = gr.Dropdown(
                    label="Select LoRA model (optional)",
                    choices=["none"] + controller.personalized_model_list,
                    value="none",
                    interactive=True,
                )
                lora_model_dropdown.change(fn=controller.update_lora_model, inputs=[lora_model_dropdown], outputs=[lora_model_dropdown])

                lora_alpha_slider = gr.Slider(label="LoRA alpha", value=0.8, minimum=0, maximum=2, interactive=True)

                personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")

                def update_personalized_model():
                    # One rescan feeds both dropdowns; the LoRA list keeps its "none" entry.
                    controller.refresh_personalized_model()
                    return [
                        gr.Dropdown.update(choices=controller.personalized_model_list),
                        gr.Dropdown.update(choices=["none"] + controller.personalized_model_list)
                    ]
                personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[base_model_dropdown, lora_model_dropdown])

        with gr.Column(variant="panel"):
            gr.Markdown(
                """
                ### 2. Configs for AnimateDiff.
                """
            )

            prompt_textbox = gr.Textbox(label="Prompt", lines=2)
            negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2)

            with gr.Row().style(equal_height=False):
                with gr.Column():
                    with gr.Row():
                        sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
                        sample_step_slider = gr.Slider(label="Sampling steps", value=25, minimum=10, maximum=100, step=1)

                    width_slider = gr.Slider(label="Width", value=512, minimum=256, maximum=1024, step=64)
                    height_slider = gr.Slider(label="Height", value=512, minimum=256, maximum=1024, step=64)
                    length_slider = gr.Slider(label="Animation length", value=16, minimum=8, maximum=24, step=1)
                    cfg_scale_slider = gr.Slider(label="CFG Scale", value=7.5, minimum=0, maximum=20)

                    with gr.Row():
                        seed_textbox = gr.Textbox(label="Seed", value=-1)
                        seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
                        # Integer upper bound: random.randint rejects the float 1e8
                        # on Python 3.12+ (and warns from 3.10).
                        seed_button.click(fn=lambda: gr.Textbox.update(value=random.randint(1, 10**8)), inputs=[], outputs=[seed_textbox])

                    generate_button = gr.Button(value="Generate", variant='primary')

                result_video = gr.Video(label="Generated Animation", interactive=False)

            # The input order must match the positional parameters of
            # AnimateController.animate (note: width, length, height).
            generate_button.click(
                fn=controller.animate,
                inputs=[
                    stable_diffusion_dropdown,
                    motion_module_dropdown,
                    base_model_dropdown,
                    lora_alpha_slider,
                    prompt_textbox,
                    negative_prompt_textbox,
                    sampler_dropdown,
                    sample_step_slider,
                    width_slider,
                    length_slider,
                    height_slider,
                    cfg_scale_slider,
                    seed_textbox,
                ],
                outputs=[result_video]
            )

    return demo
if __name__ == "__main__":
    # Script entry point: build the interface, then serve it with share=True.
    demo = ui()
    demo.launch(share=True)