# R-FLAV / app.py
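# Gradio demo for R-FLAV: samples video and audio latents jointly with a
# rectified-flow model, decodes the video latents with an SD VAE, and converts
# the audio latents to a waveform with a vocoder.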
import torch
# the first flag below was False when we tested this script but True makes A100 training a lot faster:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
import os
from diffusers.models import AutoencoderKL
from models import FLAV_models
from diffusion.rectified_flow import RectifiedFlow
from diffusers.training_utils import EMAModel
from converter import Generator
from utils import *
import tempfile
import gradio as gr
from huggingface_hub import hf_hub_download
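# Mel-spectrogram time steps per video frame. Assuming 1600 audio samples per
# video frame and a 160-sample STFT hop (an assumption, not stated in this file),
# each frame spans 1600 // 160 = 10 spectrogram columns.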
AUDIO_T_PER_FRAME = 1600 // 160
#################################################################################
#                              Global Model Setup                              #
#################################################################################
# These variables are initialized in setup_models() and used by generate_video()
vae = None
model = None
vocoder = None
audio_scale = 3.50
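# Loads the Stable Diffusion VAE, the FLAV backbone (with its EMA weights), and
# the AIST vocoder from the MaverickAlex/R-FLAV Hub repo. Everything is kept on
# CPU, which is the device this demo assumes at load time.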
def setup_models():
    global vae, model, vocoder
    device = "cpu"

    # Stable Diffusion VAE used to decode the generated video latents.
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema")

    model = FLAV_models["FLAV-B/1"](
        latent_size=256 // 8,
        in_channels=4,
        num_classes=0,
        predict_frames=10,
        causal_attn=True,
    )

    # Load the EMA checkpoint from the Hub and copy its weights into the model.
    ckpt_path = hf_hub_download(repo_id="MaverickAlex/R-FLAV", filename="aist-ema.pth")
    state_dict = torch.load(ckpt_path, map_location="cpu")
    ema = EMAModel(model.parameters())
    ema.load_state_dict(state_dict)
    ema.copy_to(model.parameters())

    # Download the vocoder config alongside its weights, then load the vocoder
    # from the directory that hf_hub_download placed both files in.
    hf_hub_download(repo_id="MaverickAlex/R-FLAV", filename="vocoder-aist/config.json")
    vocoder_path = hf_hub_download(repo_id="MaverickAlex/R-FLAV", filename="vocoder-aist/vocoder.pt")
    vocoder_path = vocoder_path.replace("vocoder.pt", "")
    vocoder = Generator.from_pretrained(vocoder_path)

    vae.to(device)
    model.to(device)
    vocoder.to(device)
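# Samples a clip with rectified flow: latents are generated in 10-frame windows
# (matching predict_frames above), decoded to RGB with the VAE, converted to a
# waveform with the vocoder, and written to disk by save_multimodal (from utils).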
def generate_video(num_frames=10, steps=2, seed=42):
    global vae, model, vocoder

    # Setup device. Note that setup_models() keeps the models on CPU, so
    # generate_sample is expected to handle any device placement it needs.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.manual_seed(seed)

    # Set up generation parameters (latents are produced 10 frames at a time)
    video_latent_size = (1, 10, 4, 256 // 8, 256 // 8)
    audio_latent_size = (1, 10, 1, 256, AUDIO_T_PER_FRAME)

    rectified_flow = RectifiedFlow(num_timesteps=steps,
                                   warmup_timesteps=10,
                                   window_size=10)

    # Generate sample
    video, audio = generate_sample(
        vae=vae,  # These globals are set by setup_models()
        rectified_flow=rectified_flow,
        forward_fn=model.forward,
        video_length=num_frames,
        video_latent_size=video_latent_size,
        audio_latent_size=audio_latent_size,
        y=None,
        cfg_scale=None,
        device=device,
    )

    # Convert to wav
    wavs = get_wavs(audio, vocoder, audio_scale, device)

    # Save to temporary files
    temp_dir = tempfile.mkdtemp()
    video_path = os.path.join(temp_dir, "video", "generated_video.mp4")

    # Use the first video and wav
    vid, wav = video[0], wavs[0]
    save_multimodal(vid, wav, temp_dir, "generated")

    return video_path
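# Thin wrapper for the Gradio callback: casts the slider values to int and
# returns None on failure so the UI shows an empty video instead of a traceback.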
def ui_generate_video(num_frames, steps, seed):
    try:
        return generate_video(int(num_frames), int(steps), int(seed))
    except Exception:
        return None
# Create Gradio interface
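# Two-column layout: generation controls (frame count, sampling steps, seed) on
# the left, a playable 256x256 video output on the right.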
with gr.Blocks(title="FLAV Video Generator") as demo:
    gr.Markdown("# FLAV Video Generator")
    gr.Markdown("Generate videos using the FLAV model")

    num_frames = None
    steps = None
    seed = None
    video_output = None

    with gr.Row():
        with gr.Column():
            num_frames = gr.Slider(minimum=5, maximum=30, step=1, value=10, label="Number of Frames")
            steps = gr.Slider(minimum=1, maximum=20, step=1, value=5, label="Number of Steps (multiplied by a factor of 10)")
            seed = gr.Slider(minimum=0, maximum=9999, step=1, value=42, label="Random Seed")
            generate_btn = gr.Button("Generate Video")

        with gr.Column():
            video_output = gr.PlayableVideo(label="Generated Video", width=256, height=256)

    generate_btn.click(
        fn=ui_generate_video,
        inputs=[num_frames, steps, seed],
        outputs=[video_output]
    )
if __name__ == "__main__":
    setup_models()
    demo.launch()
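# To run this demo locally (a sketch; assumes the repo's other modules --
# models, diffusion, converter, utils -- are on the Python path and the usual
# dependencies such as torch, diffusers, gradio, and huggingface_hub are installed):
#
#   python app.py
#
# Gradio serves the interface on http://127.0.0.1:7860 by default.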