import torch

# The first flag below was False when we tested this script, but True makes A100 training a lot faster:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

import os
import spaces
from diffusers.models import AutoencoderKL
from models import FLAV_models
from diffusion.rectified_flow import RectifiedFlow
from diffusers.training_utils import EMAModel
from converter import Generator
from utils import *
import tempfile
import gradio as gr
from huggingface_hub import hf_hub_download

# Audio time steps per video frame.
AUDIO_T_PER_FRAME = 1600 // 160

#################################################################################
#                              Global Model Setup                               #
#################################################################################

# These variables are initialized in setup_models() and used by generate_video().
vae = None
model = None
vocoder = None
audio_scale = 3.50


def setup_models():
    global vae, model, vocoder
    device = "cuda"

    # Image VAE used to decode the generated video latents.
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema")
    vae.eval()

    model = FLAV_models["FLAV-B/1"](
        latent_size=256 // 8,
        in_channels=4,
        num_classes=0,
        predict_frames=10,
        causal_attn=True,
    )

    # Load the EMA weights of the AIST checkpoint into the model.
    ckpt_path = hf_hub_download(repo_id="MaverickAlex/R-FLAV", filename="aist-ema.pth")
    state_dict = torch.load(ckpt_path, map_location="cpu")
    ema = EMAModel(model.parameters())
    ema.load_state_dict(state_dict)
    ema.copy_to(model.parameters())

    # Download the vocoder config next to the weights, then load the vocoder
    # from the resulting directory.
    hf_hub_download(repo_id="MaverickAlex/R-FLAV", filename="vocoder-aist/config.json")
    vocoder_path = hf_hub_download(repo_id="MaverickAlex/R-FLAV", filename="vocoder-aist/vocoder.pt")
    vocoder_path = vocoder_path.replace("vocoder.pt", "")
    vocoder = Generator.from_pretrained(vocoder_path)

    vae.to(device)
    model.to(device)
    vocoder.to(device)


@spaces.GPU
def generate_video(num_frames=10, steps=2, seed=42):
    global vae, model, vocoder

    # Set up device and seed
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.manual_seed(seed)

    # Set up generation parameters (latent shapes for one 10-frame window)
    video_latent_size = (1, 10, 4, 256 // 8, 256 // 8)
    audio_latent_size = (1, 10, 1, 256, AUDIO_T_PER_FRAME)
    rectified_flow = RectifiedFlow(num_timesteps=steps, warmup_timesteps=10, window_size=10)

    # Generate sample (these globals are set by setup_models)
    video, audio = generate_sample(
        vae=vae,
        rectified_flow=rectified_flow,
        forward_fn=model.forward,
        video_length=num_frames,
        video_latent_size=video_latent_size,
        audio_latent_size=audio_latent_size,
        y=None,
        cfg_scale=None,
        device=device,
    )

    # Convert audio to wav
    wavs = get_wavs(audio, vocoder, audio_scale, device)

    # Save to temporary files; use the first video and wav of the batch
    temp_dir = tempfile.mkdtemp()
    video_path = os.path.join(temp_dir, "video", "generated_video.mp4")
    vid, wav = video[0], wavs[0]
    save_multimodal(vid, wav, temp_dir, "generated")

    return video_path


def ui_generate_video(num_frames, steps, seed):
    try:
        return generate_video(int(num_frames), int(steps), int(seed))
    except Exception as e:
        print(f"Generation failed: {e}")
        return None


# Create Gradio interface
with gr.Blocks(title="FLAV Video Generator") as demo:
    gr.Markdown("# FLAV Video Generator")
    gr.Markdown("Generate videos using the FLAV model")

    with gr.Row():
        with gr.Column():
            num_frames = gr.Slider(minimum=5, maximum=30, step=1, value=10, label="Number of Frames")
            steps = gr.Slider(minimum=1, maximum=20, step=1, value=5, label="Number of Steps (multiplied by a factor of 10)")
            seed = gr.Slider(minimum=0, maximum=9999, step=1, value=42, label="Random Seed")
            generate_btn = gr.Button("Generate Video")
        with gr.Column():
            video_output = gr.PlayableVideo(label="Generated Video", width=256, height=256)

    generate_btn.click(
        fn=ui_generate_video,
        inputs=[num_frames, steps, seed],
        outputs=[video_output],
    )

if __name__ == "__main__":
    setup_models()
    demo.launch()