Spaces:
Runtime error
Runtime error
import os | |
from io import BytesIO | |
import base64 | |
from functools import partial | |
from PIL import Image, ImageOps | |
import gradio as gr | |
from makeavid_sd.inference import InferenceUNetPseudo3D, FlaxDPMSolverMultistepScheduler, jnp | |
# When True, the `if _preheat:` section further down runs two warm-up
# generations at import time, before the UI is built.
_preheat: bool = False

# Keys of parameter sets that generate() has already run, in the form
# (no_hint_image, inference_steps, height, width, num_frames).
# check_if_compiled() looks keys up here to decide whether to warn the user
# that the current settings will trigger a compilation.
_seen_compilations = set()

# The video model, built once at import time. The optional
# HUGGING_FACE_HUB_TOKEN environment variable is forwarded as hf_auth_token
# (None when it is unset).
_model = InferenceUNetPseudo3D(
    model_path='TempoFunk/makeavid-sd-jax',
    scheduler_cls=FlaxDPMSolverMultistepScheduler,
    dtype=jnp.float16,
    hf_auth_token=os.environ.get('HUGGING_FACE_HUB_TOKEN'),
)
def _fit_image(img, mode, size):
    """Return *img* converted to PIL *mode* and cropped/resized to *size*.

    ``size`` is a PIL ``(width, height)`` tuple. Conversion and fitting are
    each skipped when the image already matches.
    """
    if img.mode != mode:
        img = img.convert(mode)
    if img.size != size:
        img = ImageOps.fit(img, size, method=Image.Resampling.LANCZOS)
    return img


def _frames_to_webp_data_url(frames, fps):
    """Encode *frames* as a looping animated WebP and return it as a data URL.

    ``frames[0]`` is saved with the remaining frames appended. Each frame is
    shown for ``round(1000 / fps)`` milliseconds.
    """
    # The context manager releases the buffer even if encoding raises.
    with BytesIO() as buffer:
        frames[0].save(
            buffer,
            format='webp',
            save_all=True,
            append_images=frames[1:],
            loop=0,  # loop forever
            duration=round(1000 / fps),
            allow_mixed=True,
        )
        payload = base64.b64encode(buffer.getvalue()).decode()
    return 'data:image/webp;base64,' + payload


# The parameters of this gradio callback are deliberately left without type
# annotations: per the original author, gradio misbehaves when they are present.
def generate(
    prompt='An elderly man having a great time in the park.',
    neg_prompt='',
    image=None,
    inference_steps=20,
    cfg=12.0,
    seed=0,
    fps=24,
    num_frames=24,
    height=512,
    width=512
) -> str:
    """Generate a video with the model and return it as an animated-WebP data URL.

    Args:
        prompt: Text prompt. It is repeated ``_model.device_count`` times.
        neg_prompt: Negative prompt, passed through unchanged.
        image: ``None`` (no hint image and no mask), or a dict with keys
            ``'image'`` (hint image) and ``'mask'``. Either value may be
            ``None``. The hint is converted to RGB and the mask to ``L``.
            Both are fitted to ``(width, height)``.
        inference_steps: Number of sampler steps (cast to ``int``).
        cfg: Guidance scale, passed through unchanged.
        seed: Random seed (cast to ``int``). Negative values are replaced by
            their absolute value.
        fps: Playback frame rate used for the WebP frame duration.
        num_frames, height, width: Output shape (each cast to ``int``).

    Returns:
        A ``'data:image/webp;base64,...'`` string.
    """
    height = int(height)
    width = int(width)
    num_frames = int(num_frames)
    inference_steps = int(inference_steps)
    seed = abs(int(seed))

    if image is None:
        hint_image, mask_image = None, None
    else:
        hint_image, mask_image = image['image'], image['mask']

    size = (width, height)  # PIL sizes are (width, height)
    if hint_image is not None:
        hint_image = _fit_image(hint_image, 'RGB', size)
    if mask_image is not None:
        mask_image = _fit_image(mask_image, 'L', size)

    frames = _model.generate(
        prompt=[prompt] * _model.device_count,
        neg_prompt=neg_prompt,
        hint_image=hint_image,
        mask_image=mask_image,
        inference_steps=inference_steps,
        cfg=cfg,
        height=height,
        width=width,
        num_frames=num_frames,
        seed=seed,
    )
    # Record this parameter combination so check_if_compiled() stops warning
    # about compilation for it. The key is recorded only after a successful run.
    _seen_compilations.add((hint_image is None, inference_steps, height, width, num_frames))
    return _frames_to_webp_data_url(frames, fps)
def check_if_compiled(image, inference_steps, height, width, num_frames, message):
    """Return ``''`` if these parameters were already generated with, else ``message``.

    The lookup key is built exactly as generate() records it in
    ``_seen_compilations``: ``(no_hint_image, int(inference_steps),
    int(height), int(width), int(num_frames))``. Every numeric value is
    truncated with ``int()`` so that UI values match what generate() stored.

    Args:
        image: ``None`` or a dict whose ``'image'`` entry is the hint image
            (or ``None``).
        inference_steps, height, width, num_frames: Generation parameters.
        message: Text to return when the combination has not been seen.
    """
    hint_image = None if image is None else image['image']
    key = (hint_image is None, int(inference_steps), int(height), int(width), int(num_frames))
    return '' if key in _seen_compilations else str(message)
if _preheat:
    # Optional warm-up: call generate() once without and once with a hint
    # image. "Input image vs. no input image" is one of the parameters the UI
    # text lists as requiring compilation. Both runs use generate()'s default
    # frame count and size (24 frames, 512x512) with 20 steps.
    # NOTE(review): the UI width/height sliders default to 448, so the UI's
    # default settings are not warmed by these runs -- confirm this is intended.
    print('\npreheating the oven')
    generate(
        prompt='preheating the oven',
        neg_prompt='',
        image=dict(image=None, mask=None),
        inference_steps=20,
        cfg=12.0,
        seed=0,
    )
    print('Entertaining the guests with sailor songs played on an old piano.')
    dada = generate(
        prompt='Entertaining the guests with sailor songs played on an old harmonium.',
        neg_prompt='',
        # An all-black 512x512 hint image exercises the "with hint image" path.
        image=dict(image=Image.new('RGB', size=(512, 512), color=(0, 0, 0)), mask=None),
        inference_steps=20,
        cfg=12.0,
        seed=0,
    )
    print('dinner is ready\n')
# Build the web UI: intro text, input controls, output panel and event wiring.
with gr.Blocks(title='Make-A-Video Stable Diffusion JAX', analytics_enabled=False) as demo:
    variant = 'panel'
    with gr.Row():
        with gr.Column():
            intro1 = gr.Markdown("""
# Make-A-Video Stable Diffusion JAX
We have extended a pretrained LDM inpainting image generation model with temporal convolutions and attention.
We take advantage of the extra 5 input channels of the inpaint model to guide the video generation with a hint image and mask.
The hint image can be given by the user, otherwise it is generated by a generative image model.
The temporal convolution and attention are a port of [Make-A-Video Pytorch](https://github.com/lucidrains/make-a-video-pytorch/blob/main/make_a_video_pytorch) to FLAX.
It is a pseudo 3D convolution that separately convolves across the spatial dimension in 2D and over the temporal dimension in 1D.
Temporal attention is purely self attention and also separately attends to time and space.
Only the new temporal layers have been fine tuned on a dataset of videos themed around dance.
The model has been trained for 60 epochs on a dataset of 10,000 videos with 120 frames each, randomly selecting a 24 frame range from each sample.
See model and dataset links in the metadata.
Model implementation and training code can be found at [https://github.com/lopho/makeavid-sd-tpu](https://github.com/lopho/makeavid-sd-tpu)
""")
        with gr.Column():
            intro3 = gr.Markdown("""
**Please be patient. The model might have to compile with current parameters.**
This can take up to 5 minutes on the first run, and 2-3 minutes on later runs.
The compilation will be cached and consecutive runs with the same parameters
will be much faster.
Changes to the following parameters require the model to compile:
- Number of frames
- Width & Height
- Steps
- Input image vs. no input image
""")
    with gr.Row(variant=variant):
        with gr.Column(variant=variant):
            with gr.Row():
                #cancel_button = gr.Button(value = 'Cancel')
                submit_button = gr.Button(value='Make A Video', variant='primary')
            prompt_input = gr.Textbox(
                label='Prompt',
                value='They are dancing in the club while sweat drips from the ceiling.',
                interactive=True,
            )
            neg_prompt_input = gr.Textbox(
                label='Negative prompt (optional)',
                value='',
                interactive=True,
            )
            inference_steps_input = gr.Slider(
                label='Steps',
                minimum=2,
                maximum=100,
                value=20,
                step=1,
            )
            cfg_input = gr.Slider(
                label='Guidance scale',
                minimum=1.0,
                maximum=20.0,
                step=0.1,
                value=15.0,
                interactive=True,
            )
            seed_input = gr.Number(
                label='Random seed',
                value=0,
                interactive=True,
                precision=0,
            )
            # The sketch tool yields {'image': ..., 'mask': ...}, or None when
            # nothing is uploaded; generate() accepts both forms. The former
            # `optional=True` argument is omitted: gradio 3 ignores it
            # (a deprecated no-op that only emits a warning).
            image_input = gr.Image(
                label='Input image (optional)',
                interactive=True,
                image_mode='RGB',
                type='pil',
                source='upload',
                tool='sketch',
            )
            num_frames_input = gr.Slider(
                label='Number of frames to generate',
                minimum=1,
                maximum=24,
                step=1,
                value=24,
            )
            # NOTE(review): step=1 permits sizes that are not multiples of 8;
            # confirm InferenceUNetPseudo3D accepts arbitrary sizes.
            width_input = gr.Slider(
                label='Width',
                minimum=64,
                maximum=512,
                step=1,
                value=448,
            )
            height_input = gr.Slider(
                label='Height',
                minimum=64,
                maximum=512,
                step=1,
                value=448,
            )
            fps_input = gr.Slider(
                label='Output FPS',
                minimum=1,
                maximum=1000,
                step=1,
                value=12,
            )
        with gr.Column(variant=variant):
            will_trigger = gr.Markdown('')
            patience = gr.Markdown('')
            image_output = gr.Image(
                label='Output',
                value='example.webp',
                interactive=False,
            )

    # Inputs that make up the compilation key, in check_if_compiled() order.
    trigger_inputs = [image_input, inference_steps_input, height_input, width_input, num_frames_input]
    trigger_check_fun = partial(check_if_compiled, message='Current parameters will trigger compilation.')
    patience_check_fun = partial(
        check_if_compiled,
        message='Please be patient. The model has to be compiled with current parameters.',
    )
    # Refresh the compilation warning whenever a key-affecting input changes.
    for trigger_component in (height_input, width_input, num_frames_input, image_input, inference_steps_input):
        trigger_component.change(fn=trigger_check_fun, inputs=trigger_inputs, outputs=will_trigger)
    will_trigger.value = trigger_check_fun(
        image_input.value,
        inference_steps_input.value,
        height_input.value,
        width_input.value,
        num_frames_input.value,
    )
    ev = submit_button.click(
        fn=patience_check_fun,
        inputs=trigger_inputs,
        outputs=patience,
    ).then(
        fn=generate,
        inputs=[
            prompt_input,
            neg_prompt_input,
            image_input,
            inference_steps_input,
            cfg_input,
            seed_input,
            fps_input,
            num_frames_input,
            height_input,
            width_input,
        ],
        outputs=image_output,
        postprocess=False,  # generate() already returns a data URL
    ).then(
        fn=trigger_check_fun,
        inputs=trigger_inputs,
        outputs=will_trigger,
    ).then(
        # Clear the "please be patient" notice once generation has finished.
        # After a successful run the parameters are recorded, so this returns ''.
        fn=patience_check_fun,
        inputs=trigger_inputs,
        outputs=patience,
    )
    #cancel_button.click(fn = lambda: None, cancels = ev)
# Enable the request queue (one worker, at most 32 waiting requests), then
# start serving. Blocks.queue() returns the Blocks instance, so the calls chain.
demo.queue(concurrency_count=1, max_size=32).launch()