#!/usr/bin/env python from __future__ import annotations import argparse import os import glob import pickle import sys import importlib from typing import List, Tuple import gradio as gr import numpy as np import torch import torch.nn as nn from beat_interpolator import beat_interpolator def build_models(): modules = glob.glob('examples/models/*') modules = [ getattr( importlib.import_module( module.replace('/', '.'), package=None ), 'create' )() for module in modules if '.py' not in module and '__' not in module ] attrs = [ (module['name'], module) for module in modules ] mnist_idx = -1 for i in range(len(attrs)): name, _ = attrs[i] if name == 'MNIST': mnist_idx = i if mnist_idx > -1: mnist_attr = attrs.pop(mnist_idx) attrs.insert(0, mnist_attr) return attrs def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument('--device', type=str, default='cpu') parser.add_argument('--theme', type=str) parser.add_argument('--share', action='store_true') parser.add_argument('--port', type=int) parser.add_argument('--disable-queue', dest='enable_queue', action='store_false') return parser.parse_args() def main(): args = parse_args() enable_queue = args.enable_queue model_attrs = build_models() with gr.Blocks(theme=args.theme) as demo: gr.Markdown('''

Beat-Interpolator

Play DL models with music beats.


This is a Gradio Blocks app of HighCWu/beat-interpolator. ''') with gr.Tabs(): for name, model_attr in model_attrs: with gr.TabItem(name): generator = model_attr['generator'] latent_dim = model_attr['latent_dim'] default_fps = model_attr['fps'] max_fps = model_attr['fps'] if enable_queue else 60 batch_size = model_attr['batch_size'] strength = model_attr['strength'] default_max_duration = model_attr['max_duration'] max_duration = model_attr['max_duration'] if enable_queue else 360 use_peak = model_attr['use_peak'] def build_interpolate( generator, latent_dim, batch_size ): def interpolate( wave_path, seed, fps=default_fps, strength=strength, max_duration=default_max_duration, use_peak=use_peak): return beat_interpolator( wave_path, generator, latent_dim, int(seed), int(fps), batch_size, strength, max_duration, use_peak) return interpolate interpolate = build_interpolate(generator, latent_dim, batch_size) with gr.Row(): with gr.Box(): with gr.Column(): with gr.Row(): wave_in = gr.Audio( type="filepath", label="Music" ) with gr.Row(): example_audios = gr.Dataset( components=[wave_in], samples=[['examples/example.mp3']] ) with gr.Row(): seed_in = gr.Number( value=128, label='Seed' ) with gr.Row(): fps_in = gr.Slider( value=default_fps, minimum=4, maximum=max_fps, label="FPS" ) with gr.Row(): strength_in = gr.Slider( value=strength, maximum=1, label="Strength" ) with gr.Row(): max_duration_in = gr.Slider( value=default_max_duration, minimum=5, maximum=max_duration, label="Max Duration" ) with gr.Row(): peak_in = gr.Checkbox(value=use_peak, label="Use peak") with gr.Row(): generate_button = gr.Button('Generate') with gr.Box(): with gr.Column(): with gr.Row(): interpolated_video = gr.Video(label='Output Video') generate_button.click(interpolate, inputs=[ wave_in, seed_in, fps_in, strength_in, max_duration_in, peak_in ], outputs=[interpolated_video]) example_audios.click( fn=lambda examples: gr.Audio.update(value=examples[0]), inputs=example_audios, outputs=example_audios.components ) gr.Markdown( '
visitor badge
' ) demo.launch( enable_queue=args.enable_queue, server_port=args.port, share=args.share, ) if __name__ == '__main__': main()