import torch import torchvision import os import os.path as osp import spaces import random from argparse import ArgumentParser from datetime import datetime import gradio as gr from foleycrafter.utils.util import build_foleycrafter, read_frames_with_moviepy from foleycrafter.pipelines.auffusion_pipeline import denormalize_spectrogram from foleycrafter.pipelines.auffusion_pipeline import Generator from foleycrafter.models.time_detector.model import VideoOnsetNet from foleycrafter.models.specvqgan.onset_baseline.utils import torch_utils from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor from huggingface_hub import snapshot_download from diffusers import DDIMScheduler, EulerDiscreteScheduler, PNDMScheduler import soundfile as sf from moviepy.editor import AudioFileClip, VideoFileClip os.environ['GRADIO_TEMP_DIR'] = './tmp' sample_idx = 0 scheduler_dict = { "DDIM": DDIMScheduler, "Euler": EulerDiscreteScheduler, "PNDM": PNDMScheduler, } css = """ .toolbutton { margin-buttom: 0em 0em 0em 0em; max-width: 2.5em; min-width: 2.5em !important; height: 2.5em; } """ parser = ArgumentParser() parser.add_argument("--config", type=str, default="example/config/base.yaml") parser.add_argument("--server-name", type=str, default="0.0.0.0") parser.add_argument("--port", type=int, default=7860) parser.add_argument("--share", type=bool, default=True) parser.add_argument("--save-path", default="samples") args = parser.parse_args() N_PROMPT = ( "" ) class FoleyController: def __init__(self): # config dirs self.basedir = os.getcwd() self.model_dir = os.path.join(self.basedir, "models") self.savedir = os.path.join(self.basedir, args.save_path, datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S")) self.savedir_sample = os.path.join(self.savedir, "sample") os.makedirs(self.savedir, exist_ok=True) self.pipeline = None self.loaded = False self.load_model() def load_model(self): gr.Info("Start Load Models...") print("Start Load Models...") # download ckpt pretrained_model_name_or_path = 'auffusion/auffusion-full-no-adapter' if not os.path.isdir(pretrained_model_name_or_path): pretrained_model_name_or_path = snapshot_download(pretrained_model_name_or_path, local_dir='models/auffusion') fc_ckpt = 'ymzhang319/FoleyCrafter' if not os.path.isdir(fc_ckpt): fc_ckpt = snapshot_download(fc_ckpt, local_dir='models/') # set model config temporal_ckpt_path = osp.join(self.model_dir, 'temporal_adapter.ckpt') # load vocoder vocoder_config_path= "./models/auffusion" self.vocoder = Generator.from_pretrained( vocoder_config_path, subfolder="vocoder") # load time detector time_detector_ckpt = osp.join(osp.join(self.model_dir, 'timestamp_detector.pth.tar')) time_detector = VideoOnsetNet(False) self.time_detector, _ = torch_utils.load_model(time_detector_ckpt, time_detector, strict=True) self.pipeline = build_foleycrafter() ckpt = torch.load(temporal_ckpt_path) # load temporal adapter if 'state_dict' in ckpt.keys(): ckpt = ckpt['state_dict'] load_gligen_ckpt = {} for key, value in ckpt.items(): if key.startswith('module.'): load_gligen_ckpt[key[len('module.'):]] = value else: load_gligen_ckpt[key] = value m, u = self.pipeline.controlnet.load_state_dict(load_gligen_ckpt, strict=False) print(f"### Control Net missing keys: {len(m)}; \n### unexpected keys: {len(u)};") self.image_processor = CLIPImageProcessor() self.image_encoder = CLIPVisionModelWithProjection.from_pretrained('h94/IP-Adapter', subfolder='models/image_encoder') self.pipeline.load_ip_adapter(fc_ckpt, subfolder='semantic', weight_name='semantic_adapter.bin', image_encoder_folder=None) gr.Info("Load Finish!") print("Load Finish!") self.loaded = True return "Load" @spaces.GPU def foley( self, input_video, prompt_textbox, negative_prompt_textbox, ip_adapter_scale, temporal_scale, sampler_dropdown, sample_step_slider, cfg_scale_slider, seed_textbox, ): device = 'cuda' # move to gpu self.time_detector = controller.time_detector.to(device) self.pipeline = controller.pipeline.to(device) self.vocoder = controller.vocoder.to(device) self.image_encoder = controller.image_encoder.to(device) vision_transform_list = [ torchvision.transforms.Resize((128, 128)), torchvision.transforms.CenterCrop((112, 112)), torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ] video_transform = torchvision.transforms.Compose(vision_transform_list) # if not self.loaded: # raise gr.Error("Error with loading model") generator = torch.Generator() if seed_textbox != "": torch.manual_seed(int(seed_textbox)) generator.manual_seed(int(seed_textbox)) max_frame_nums = 150 frames, duration = read_frames_with_moviepy(input_video, max_frame_nums=max_frame_nums) if duration >= 10: duration = 10 time_frames = torch.FloatTensor(frames).permute(0, 3, 1, 2).to(device) time_frames = video_transform(time_frames) time_frames = {'frames': time_frames.unsqueeze(0).permute(0, 2, 1, 3, 4)} preds = self.time_detector(time_frames) preds = torch.sigmoid(preds) # duration time_condition = [-1 if preds[0][int(i / (1024 / 10 * duration) * max_frame_nums)] < 0.5 else 1 for i in range(int(1024 / 10 * duration))] time_condition = time_condition + [-1] * (1024 - len(time_condition)) # w -> b c h w time_condition = torch.FloatTensor(time_condition).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(1, 1, 256, 1) # Note that clip need fewer frames frames = frames[::10] images = self.image_processor(images=frames, return_tensors="pt").to(device) image_embeddings = self.image_encoder(**images).image_embeds image_embeddings = torch.mean(image_embeddings, dim=0, keepdim=True).unsqueeze(0).unsqueeze(0) neg_image_embeddings = torch.zeros_like(image_embeddings) image_embeddings = torch.cat([neg_image_embeddings, image_embeddings], dim=1) self.pipeline.set_ip_adapter_scale(ip_adapter_scale) sample = self.pipeline( prompt=prompt_textbox, negative_prompt=negative_prompt_textbox, ip_adapter_image_embeds=image_embeddings, image=time_condition, controlnet_conditioning_scale=float(temporal_scale), num_inference_steps=sample_step_slider, height=256, width=1024, output_type="pt", generator=generator, ) name = 'output' audio_img = sample.images[0] audio = denormalize_spectrogram(audio_img) audio = self.vocoder.inference(audio, lengths=160000)[0] audio_save_path = osp.join(self.savedir_sample, 'audio') os.makedirs(audio_save_path, exist_ok=True) audio = audio[:int(duration * 16000)] save_path = osp.join(audio_save_path, f'{name}.wav') sf.write(save_path, audio, 16000) audio = AudioFileClip(osp.join(audio_save_path, f'{name}.wav')) video = VideoFileClip(input_video) audio = audio.subclip(0, duration) video.audio = audio video = video.subclip(0, duration) video.write_videofile(osp.join(self.savedir_sample, f'{name}.mp4')) save_sample_path = os.path.join(self.savedir_sample, f"{name}.mp4") return save_sample_path controller = FoleyController() device = "cuda" if torch.cuda.is_available() else "cpu" with gr.Blocks(css=css) as demo: gr.HTML( '
**Tips**:
\
# 1. With strong temporal visual cues in input video, you can scale up the **Temporal Align Scale**.
\
# 2. **Visual content scale** is the level of semantic alignment with visual content.