import torch

# The first flag below was False when we tested this script, but True makes
# matmuls a lot faster on A100s (TF32 is a speed/precision trade-off on Ampere GPUs):
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

import os
import tempfile

import gradio as gr
from diffusers.models import AutoencoderKL
from diffusers.training_utils import EMAModel
from huggingface_hub import hf_hub_download

from converter import Generator
from diffusion.rectified_flow import RectifiedFlow
from models import FLAV_models
from utils import generate_sample, get_wavs, save_multimodal
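# Mel-spectrogram frames generated per video frame. The derivation below is an
# assumption about the checkpoint's audio setup: 1600 audio samples per video
# frame divided by a vocoder hop length of 160 gives 10 mel frames per frame.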
AUDIO_T_PER_FRAME = 1600 // 160
#################################################################################
#                              Global Model Setup                               #
#################################################################################

# These variables are initialized in setup_models() and used in generate_video().
vae = None
model = None
vocoder = None
audio_scale = 3.50  # passed to get_wavs() when decoding generated audio
def setup_models():
    global vae, model, vocoder
    # Models are loaded on CPU at startup; generate_video() moves them to the
    # runtime device once one is known.
    device = "cpu"

    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema")

    model = FLAV_models["FLAV-B/1"](
        latent_size=256 // 8,
        in_channels=4,
        num_classes=0,
        predict_frames=10,
        causal_attn=True,
    )

    # Load the EMA weights and copy them into the model.
    ckpt_path = hf_hub_download(repo_id="MaverickAlex/R-FLAV", filename="aist-ema.pth")
    state_dict = torch.load(ckpt_path, map_location="cpu")
    ema = EMAModel(model.parameters())
    ema.load_state_dict(state_dict)
    ema.copy_to(model.parameters())

    # Download the vocoder config so it sits next to the weights, then load the
    # vocoder from that directory.
    hf_hub_download(repo_id="MaverickAlex/R-FLAV", filename="vocoder-aist/config.json")
    vocoder_path = hf_hub_download(repo_id="MaverickAlex/R-FLAV", filename="vocoder-aist/vocoder.pt")
    vocoder = Generator.from_pretrained(os.path.dirname(vocoder_path))

    vae.to(device)
    model.to(device)
    vocoder.to(device)
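# Note: this Space runs on ZeroGPU ("Running on Zero"). On ZeroGPU the usual
# pattern is to wrap the GPU-bound entry point with the `spaces` decorator,
# which allocates a GPU only for the duration of each call — a sketch:
#
#   import spaces
#
#   @spaces.GPU
#   def generate_video(...):
#       ...
#
# The rest of this script also works on plain CPU/GPU hosts without it.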
def generate_video(num_frames=10, steps=2, seed=42):
    global vae, model, vocoder

    # Pick the runtime device and move the CPU-loaded models onto it
    # (a no-op if they are already there).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    vae.to(device)
    model.to(device)
    vocoder.to(device)

    torch.manual_seed(seed)

    # Latent shapes for one generation window (the model predicts 10 frames at
    # a time); the total clip length is controlled by video_length below.
    video_latent_size = (1, 10, 4, 256 // 8, 256 // 8)
    audio_latent_size = (1, 10, 1, 256, AUDIO_T_PER_FRAME)

    rectified_flow = RectifiedFlow(num_timesteps=steps,
                                   warmup_timesteps=10,
                                   window_size=10)

    # Generate the paired video/audio latents and decode the video latents.
    video, audio = generate_sample(
        vae=vae,  # these globals are set by setup_models()
        rectified_flow=rectified_flow,
        forward_fn=model.forward,
        video_length=num_frames,
        video_latent_size=video_latent_size,
        audio_latent_size=audio_latent_size,
        y=None,
        cfg_scale=None,
        device=device,
    )

    # Convert the generated spectrograms to waveforms.
    wavs = get_wavs(audio, vocoder, audio_scale, device)

    # Save the first video/waveform pair to a temporary directory. The path
    # must match save_multimodal's output layout, which is assumed to write
    # the muxed clip to <temp_dir>/video/generated_video.mp4 for name "generated".
    temp_dir = tempfile.mkdtemp()
    video_path = os.path.join(temp_dir, "video", "generated_video.mp4")
    vid, wav = video[0], wavs[0]
    save_multimodal(vid, wav, temp_dir, "generated")
    return video_path
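# Example headless usage (no UI), assuming the checkpoint downloads succeed:
#
#   setup_models()
#   path = generate_video(num_frames=10, steps=2, seed=42)
#   print(f"wrote {path}")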
def ui_generate_video(num_frames, steps, seed):
    try:
        return generate_video(int(num_frames), int(steps), int(seed))
    except Exception as e:
        # Surface the failure in the UI instead of silently returning nothing.
        raise gr.Error(f"Generation failed: {e}")
# Create the Gradio interface.
with gr.Blocks(title="FLAV Video Generator") as demo:
    gr.Markdown("# FLAV Video Generator")
    gr.Markdown("Generate videos using the FLAV model")

    with gr.Row():
        with gr.Column():
            num_frames = gr.Slider(minimum=5, maximum=30, step=1, value=10, label="Number of Frames")
            steps = gr.Slider(minimum=1, maximum=20, step=1, value=5, label="Number of Steps (effective sampling steps are this value x 10)")
            seed = gr.Slider(minimum=0, maximum=9999, step=1, value=42, label="Random Seed")
            generate_btn = gr.Button("Generate Video")
        with gr.Column():
            video_output = gr.PlayableVideo(label="Generated Video", width=256, height=256)

    generate_btn.click(
        fn=ui_generate_video,
        inputs=[num_frames, steps, seed],
        outputs=[video_output],
    )
if __name__ == "__main__":
    setup_models()
    demo.launch()
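# For a hosted Space it is common to serialize GPU jobs through Gradio's
# built-in queue, e.g. `demo.queue().launch()`; plain launch() is kept here
# to match the original script.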