import os from io import BytesIO import base64 from functools import partial from PIL import Image, ImageOps import gradio as gr from makeavid_sd.inference import InferenceUNetPseudo3D, FlaxDPMSolverMultistepScheduler, jnp _preheat: bool = False _seen_compilations = set() _model = InferenceUNetPseudo3D( model_path = 'TempoFunk/makeavid-sd-jax', scheduler_cls = FlaxDPMSolverMultistepScheduler, dtype = jnp.float16, hf_auth_token = os.environ.get('HUGGING_FACE_HUB_TOKEN', None) ) # gradio is illiterate. type hints make it go poopoo in pantsu. def generate( prompt = 'An elderly man having a great time in the park.', neg_prompt = '', image = { 'image': None, 'mask': None }, inference_steps = 20, cfg = 12.0, seed = 0, fps = 24, num_frames = 24, height = 512, width = 512 ) -> str: height = int(height) width = int(width) num_frames = int(num_frames) seed = int(seed) if seed < 0: seed = -seed inference_steps = int(inference_steps) if image is not None: hint_image = image['image'] mask_image = image['mask'] else: hint_image = None mask_image = None if hint_image is not None: if hint_image.mode != 'RGB': hint_image = hint_image.convert('RGB') if hint_image.size != (width, height): hint_image = ImageOps.fit(hint_image, (width, height), method = Image.Resampling.LANCZOS) if mask_image is not None: if mask_image.mode != 'L': mask_image = mask_image.convert('L') if mask_image.size != (width, height): mask_image = ImageOps.fit(mask_image, (width, height), method = Image.Resampling.LANCZOS) images = _model.generate( prompt = [prompt] * _model.device_count, neg_prompt = neg_prompt, hint_image = hint_image, mask_image = mask_image, inference_steps = inference_steps, cfg = cfg, height = height, width = width, num_frames = num_frames, seed = seed ) _seen_compilations.add((hint_image is None, inference_steps, height, width, num_frames)) buffer = BytesIO() images[0].save( buffer, format = 'webp', save_all = True, append_images = images[1:], loop = 0, duration = round(1000 / fps), allow_mixed = True ) data = base64.b64encode(buffer.getvalue()).decode() data = 'data:image/webp;base64,' + data buffer.close() return data def check_if_compiled(image, inference_steps, height, width, num_frames, message): height = int(height) width = int(width) hint_image = None if image is None else image['image'] if (hint_image is None, inference_steps, height, width, num_frames) in _seen_compilations: return '' else: return f"""{message}""" if _preheat: print('\npreheating the oven') generate( prompt = 'preheating the oven', neg_prompt = '', image = { 'image': None, 'mask': None }, inference_steps = 20, cfg = 12.0, seed = 0 ) print('Entertaining the guests with sailor songs played on an old piano.') dada = generate( prompt = 'Entertaining the guests with sailor songs played on an old harmonium.', neg_prompt = '', image = { 'image': Image.new('RGB', size = (512, 512), color = (0, 0, 0)), 'mask': None }, inference_steps = 20, cfg = 12.0, seed = 0 ) print('dinner is ready\n') with gr.Blocks(title = 'Make-A-Video Stable Diffusion JAX', analytics_enabled = False) as demo: variant = 'panel' with gr.Row(): with gr.Column(): intro1 = gr.Markdown(""" # Make-A-Video Stable Diffusion JAX **Please be patient. The model might have to compile with current parameters.** This can take up to 5 minutes on the first run, and 2-3 minutes on later runs. The compilation will be cached and consecutive runs with the same parameters will be much faster. """) with gr.Column(): intro2 = gr.Markdown(""" The following parameters require the model to compile - Number of frames - Width & Height - Steps - Input image vs. no input image """) with gr.Row(variant = variant): with gr.Column(variant = variant): with gr.Row(): cancel_button = gr.Button(value = 'Cancel') submit_button = gr.Button(value = 'Make A Video', variant = 'primary') prompt_input = gr.Textbox( label = 'Prompt', value = 'They are dancing in the club while sweat drips from the ceiling.', interactive = True ) neg_prompt_input = gr.Textbox( label = 'Negative prompt (optional)', value = '', interactive = True ) inference_steps_input = gr.Slider( label = 'Steps', minimum = 1, maximum = 100, value = 20, step = 1 ) cfg_input = gr.Slider( label = 'Guidance scale', minimum = 1.0, maximum = 20.0, step = 0.1, value = 15.0, interactive = True ) seed_input = gr.Number( label = 'Random seed', value = 0, interactive = True, precision = 0 ) image_input = gr.Image( label = 'Input image (optional)', interactive = True, image_mode = 'RGB', type = 'pil', optional = True, source = 'upload', tool = 'sketch' ) num_frames_input = gr.Slider( label = 'Number of frames to generate', minimum = 1, maximum = 24, step = 1, value = 24 ) width_input = gr.Slider( label = 'Width', minimum = 64, maximum = 512, step = 1, value = 448 ) height_input = gr.Slider( label = 'Height', minimum = 64, maximum = 512, step = 1, value = 448 ) fps_input = gr.Slider( label = 'Output FPS', minimum = 1, maximum = 1000, step = 1, value = 12 ) with gr.Column(variant = variant): will_trigger = gr.Markdown('') patience = gr.Markdown('') image_output = gr.Image( label = 'Output', value = 'example.webp', interactive = False ) trigger_inputs = [ image_input, inference_steps_input, height_input, width_input, num_frames_input ] trigger_check_fun = partial(check_if_compiled, message = 'Current parameters will trigger compilation.') height_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger) width_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger) num_frames_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger) inference_steps_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger) will_trigger.value = trigger_check_fun(image_input.value, inference_steps_input.value, height_input.value, width_input.value, num_frames_input.value) ev = submit_button.click( fn = partial( check_if_compiled, message = 'Please be patient. The model has to be compiled with current parameters.' ), inputs = trigger_inputs, outputs = patience ).then( fn = generate, inputs = [ prompt_input, neg_prompt_input, image_input, inference_steps_input, cfg_input, seed_input, fps_input, num_frames_input, height_input, width_input ], outputs = image_output, postprocess = False ).then( fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger ) cancel_button(cancels = ev) demo.queue(concurrency_count = 1, max_size = 16) demo.launch()