import os import torch import random import gradio as gr from glob import glob from omegaconf import OmegaConf from safetensors import safe_open from diffusers import AutoencoderKL from diffusers import EulerDiscreteScheduler, DDIMScheduler from diffusers.utils.import_utils import is_xformers_available from transformers import CLIPTextModel, CLIPTokenizer from animatediff.models.unet import UNet3DConditionModel from animatediff.pipelines.pipeline_animation import AnimationFreeInitPipeline from animatediff.utils.util import save_videos_grid from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint from diffusers.training_utils import set_seed from animatediff.utils.freeinit_utils import get_freq_filter from collections import namedtuple pretrained_model_path = "models/StableDiffusion/stable-diffusion-v1-5" inference_config_path = "configs/inference/inference-v1.yaml" css = """ .toolbutton { margin-buttom: 0em 0em 0em 0em; max-width: 2.5em; min-width: 2.5em !important; height: 2.5em; } """ examples = [ # 0-RealisticVision [ "realisticVisionV51_v20Novae.safetensors", "mm_sd_v14.ckpt", "A panda standing on a surfboard in the ocean under moonlight.", "worst quality, low quality, nsfw, logo", 512, 512, "2005563494988190", "butterworth", 0.25, 0.25, 3, ["use_fp16"] ], # 1-ToonYou [ "toonyou_beta3.safetensors", "mm_sd_v14.ckpt", "(best quality, masterpiece), 1girl, looking at viewer, blurry background, upper body, contemporary, dress", "(worst quality, low quality)", 512, 512, "478028150728261", "butterworth", 0.25, 0.25, 3, ["use_fp16"] ], # 2-Lyriel [ "lyriel_v16.safetensors", "mm_sd_v14.ckpt", "hypercars cyberpunk moving, muted colors, swirling color smokes, legend, cityscape, space", "3d, cartoon, anime, sketches, worst quality, low quality, nsfw, logo", 512, 512, "1566149281915957", "butterworth", 0.25, 0.25, 3, ["use_fp16"] ], # 3-RCNZ [ "rcnzCartoon3d_v10.safetensors", "mm_sd_v14.ckpt", "A cute raccoon playing guitar in a boat on the ocean", "worst quality, low quality, nsfw, logo", 512, 512, "1566149281915957", "butterworth", 0.25, 0.25, 3, ["use_fp16"] ], # 4-MajicMix [ "majicmixRealistic_v5Preview.safetensors", "mm_sd_v14.ckpt", "1girl, reading book", "(ng_deepnegative_v1_75t:1.2), (badhandv4:1), (worst quality:2), (low quality:2), (normal quality:2), lowres, bad anatomy, bad hands, watermark, moles", 512, 512, "2005563494988190", "butterworth", 0.25, 0.25, 3, ["use_fp16"] ], # # 5-RealisticVision # [ # "realisticVisionV51_v20Novae.safetensors", # "mm_sd_v14.ckpt", # "A panda standing on a surfboard in the ocean in sunset.", # "worst quality, low quality, nsfw, logo", # 512, 512, "2005563494988190", # "butterworth", 0.25, 0.25, 3, # ["use_fp16"] # ] ] # clean unrelated ckpts # ckpts = [ # "realisticVisionV40_v20Novae.safetensors", # "majicmixRealistic_v5Preview.safetensors", # "rcnzCartoon3d_v10.safetensors", # "lyriel_v16.safetensors", # "toonyou_beta3.safetensors" # ] # for path in glob(os.path.join("models", "DreamBooth_LoRA", "*.safetensors")): # for ckpt in ckpts: # if path.endswith(ckpt): break # else: # print(f"### Cleaning {path} ...") # os.system(f"rm -rf {path}") # os.system(f"rm -rf {os.path.join('models', 'DreamBooth_LoRA', '*.safetensors')}") # os.system(f"bash download_bashscripts/1-ToonYou.sh") # os.system(f"bash download_bashscripts/2-Lyriel.sh") # os.system(f"bash download_bashscripts/3-RcnzCartoon.sh") # os.system(f"bash download_bashscripts/4-MajicMix.sh") # os.system(f"bash download_bashscripts/5-RealisticVision.sh") # clean Gradio cache print(f"### Cleaning cached examples ...") os.system(f"rm -rf gradio_cached_examples/") class AnimateController: def __init__(self): # config dirs self.basedir = os.getcwd() self.stable_diffusion_dir = os.path.join(self.basedir, "models", "StableDiffusion") self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module") self.personalized_model_dir = os.path.join(self.basedir, "models", "DreamBooth_LoRA") self.savedir = os.path.join(self.basedir, "samples") os.makedirs(self.savedir, exist_ok=True) self.base_model_list = [] self.motion_module_list = [] self.filter_type_list = [ "butterworth", "gaussian", "box", "ideal" ] self.selected_base_model = None self.selected_motion_module = None self.selected_filter_type = None self.set_width = None self.set_height = None self.set_d_s = None self.set_d_t = None self.refresh_motion_module() self.refresh_personalized_model() # config models self.inference_config = OmegaConf.load(inference_config_path) self.tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer") self.text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder").cuda() self.vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").cuda() self.unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).cuda() self.freq_filter = None self.update_base_model(self.base_model_list[-2]) self.update_motion_module(self.motion_module_list[0]) self.update_filter(512, 512, self.filter_type_list[0], 0.25, 0.25) def refresh_motion_module(self): motion_module_list = glob(os.path.join(self.motion_module_dir, "*.ckpt")) self.motion_module_list = sorted([os.path.basename(p) for p in motion_module_list]) def refresh_personalized_model(self): base_model_list = glob(os.path.join(self.personalized_model_dir, "*.safetensors")) self.base_model_list = sorted([os.path.basename(p) for p in base_model_list]) def update_base_model(self, base_model_dropdown): self.selected_base_model = base_model_dropdown base_model_dropdown = os.path.join(self.personalized_model_dir, base_model_dropdown) base_model_state_dict = {} with safe_open(base_model_dropdown, framework="pt", device="cpu") as f: for key in f.keys(): base_model_state_dict[key] = f.get_tensor(key) converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_model_state_dict, self.vae.config) self.vae.load_state_dict(converted_vae_checkpoint) converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_model_state_dict, self.unet.config) self.unet.load_state_dict(converted_unet_checkpoint, strict=False) self.text_encoder = convert_ldm_clip_checkpoint(base_model_state_dict) return gr.Dropdown.update() def update_motion_module(self, motion_module_dropdown): self.selected_motion_module = motion_module_dropdown motion_module_dropdown = os.path.join(self.motion_module_dir, motion_module_dropdown) motion_module_state_dict = torch.load(motion_module_dropdown, map_location="cpu") _, unexpected = self.unet.load_state_dict(motion_module_state_dict, strict=False) assert len(unexpected) == 0 return gr.Dropdown.update() # def update_filter(self, shape, method, n, d_s, d_t): def update_filter(self, width_slider, height_slider, filter_type_dropdown, d_s_slider, d_t_slider): self.set_width = width_slider self.set_height = height_slider self.selected_filter_type = filter_type_dropdown self.set_d_s = d_s_slider self.set_d_t = d_t_slider vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) shape = [1, 4, 16, self.set_width//vae_scale_factor, self.set_height//vae_scale_factor] self.freq_filter = get_freq_filter( shape, device="cuda", filter_type=self.selected_filter_type, n=4, d_s=self.set_d_s, d_t=self.set_d_t ) def animate( self, base_model_dropdown, motion_module_dropdown, prompt_textbox, negative_prompt_textbox, width_slider, height_slider, seed_textbox, # freeinit params filter_type_dropdown, d_s_slider, d_t_slider, num_iters_slider, # speed up speed_up_options ): # set global seed set_seed(42) d_s = float(d_s_slider) d_t = float(d_t_slider) num_iters = int(num_iters_slider) if self.selected_base_model != base_model_dropdown: self.update_base_model(base_model_dropdown) if self.selected_motion_module != motion_module_dropdown: self.update_motion_module(motion_module_dropdown) self.set_width = width_slider self.set_height = height_slider self.selected_filter_type = filter_type_dropdown self.set_d_s = d_s self.set_d_t = d_t if self.set_width != width_slider or self.set_height != height_slider or self.selected_filter_type != filter_type_dropdown or self.set_d_s != d_s or self.set_d_t != d_t: self.update_filter(width_slider, height_slider, filter_type_dropdown, d_s, d_t) if is_xformers_available(): self.unet.enable_xformers_memory_efficient_attention() pipeline = AnimationFreeInitPipeline( vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet, scheduler=DDIMScheduler(**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs)) ).to("cuda") # (freeinit) initialize frequency filter for noise reinitialization ------------- pipeline.freq_filter = self.freq_filter # ------------------------------------------------------------------------------- if int(seed_textbox) > 0: seed = int(seed_textbox) else: seed = random.randint(1, 1e16) torch.manual_seed(int(seed)) assert seed == torch.initial_seed() print(f"### seed: {seed}") generator = torch.Generator(device="cuda") generator.manual_seed(seed) sample_output = pipeline( prompt_textbox, negative_prompt = negative_prompt_textbox, num_inference_steps = 25, guidance_scale = 7.5, width = width_slider, height = height_slider, video_length = 16, num_iters = num_iters, use_fast_sampling = True if "use_coarse_to_fine_sampling" in speed_up_options else False, save_intermediate = False, return_orig = True, use_fp16 = True if "use_fp16" in speed_up_options else False ) orig_sample = sample_output.orig_videos sample = sample_output.videos save_sample_path = os.path.join(self.savedir, f"sample.mp4") save_videos_grid(sample, save_sample_path) save_orig_sample_path = os.path.join(self.savedir, f"sample_orig.mp4") save_videos_grid(orig_sample, save_orig_sample_path) # save_compare_path = os.path.join(self.savedir, f"compare.mp4") # save_videos_grid(torch.concat([orig_sample, sample]), save_compare_path) json_config = { "prompt": prompt_textbox, "n_prompt": negative_prompt_textbox, "width": width_slider, "height": height_slider, "seed": seed, "base_model": base_model_dropdown, "motion_module": motion_module_dropdown, "filter_type": filter_type_dropdown, "d_s": d_s, "d_t": d_t, "num_iters": num_iters, "use_fp16": True if "use_fp16" in speed_up_options else False, "use_coarse_to_fine_sampling": True if "use_coarse_to_fine_sampling" in speed_up_options else False } # return gr.Video.update(value=save_compare_path), gr.Json.update(value=json_config) # return gr.Video.update(value=save_orig_sample_path), gr.Video.update(value=save_sample_path), gr.Video.update(value=save_compare_path), gr.Json.update(value=json_config) return gr.Video.update(value=save_orig_sample_path), gr.Video.update(value=save_sample_path), gr.Json.update(value=json_config) controller = AnimateController() def ui(): with gr.Blocks(css=css) as demo: # gr.Markdown('# FreeInit') gr.Markdown( """