|
|
|
import os |
|
from io import BytesIO |
|
import base64 |
|
from functools import partial |
|
|
|
from PIL import Image, ImageOps |
|
import gradio as gr |
|
|
|
from makeavid_sd.inference import InferenceUNetPseudo3D, FlaxDPMSolverMultistepScheduler, jnp |
|
|
|
|
|
# When True, two warm-up generations are run at import time (see the
# `if _preheat:` section of this module).
_preheat: bool = False

# Parameter combinations that generate() has already been run with. Each entry
# is a tuple: (no hint image?, inference_steps, height, width, num_frames).
_seen_compilations: set = set()

# Single shared inference pipeline. The Hugging Face token is optional and is
# read from the environment; a missing variable yields None.
_model = InferenceUNetPseudo3D(
    model_path = 'TempoFunk/makeavid-sd-jax',
    scheduler_cls = FlaxDPMSolverMultistepScheduler,
    dtype = jnp.float16,
    hf_auth_token = os.environ.get('HUGGING_FACE_HUB_TOKEN'),
)
|
|
|
|
|
def _prepare_image(img, mode, size):
    """Return *img* converted to *mode* and fitted to *size*; None passes through."""
    if img is None:
        return None
    if img.mode != mode:
        img = img.convert(mode)
    if img.size != size:
        img = ImageOps.fit(img, size, method = Image.Resampling.LANCZOS)
    return img


def generate(
    prompt = 'An elderly man having a great time in the park.',
    neg_prompt = '',
    image = None,
    inference_steps = 20,
    cfg = 12.0,
    seed = 0,
    fps = 24,
    num_frames = 24,
    height = 512,
    width = 512
) -> str:
    """Generate a clip with the shared model and return it as a WebP data URI.

    Args:
        prompt: Text prompt; it is repeated once per model device.
        neg_prompt: Negative prompt, passed to the model unchanged.
        image: None, or a dict with optional 'image' (hint) and 'mask' PIL
            images. The hint is converted to RGB and the mask to L, and both
            are fitted to (width, height) before being handed to the model.
        inference_steps: Number of inference steps (coerced to int).
        cfg: Guidance scale, passed to the model unchanged.
        seed: Random seed (coerced to int; a negative seed is negated).
        fps: Playback rate; each frame is shown for round(1000 / fps).
        num_frames: Number of frames to generate (coerced to int).
        height: Frame height (coerced to int).
        width: Frame width (coerced to int).

    Returns:
        A 'data:image/webp;base64,...' string holding the animated WebP.

    Raises:
        ValueError: If fps is not positive.
    """
    height = int(height)
    width = int(width)
    num_frames = int(num_frames)
    inference_steps = int(inference_steps)
    seed = abs(int(seed))
    # Fail fast: a bad fps would otherwise only surface when the frame duration
    # is computed, i.e. after the (potentially very slow) model call.
    if fps <= 0:
        raise ValueError(f'fps must be positive, got {fps!r}')

    if image is None:
        image = {}
    size = (width, height)
    hint_image = _prepare_image(image.get('image'), 'RGB', size)
    mask_image = _prepare_image(image.get('mask'), 'L', size)

    images = _model.generate(
        prompt = [prompt] * _model.device_count,
        neg_prompt = neg_prompt,
        hint_image = hint_image,
        mask_image = mask_image,
        inference_steps = inference_steps,
        cfg = cfg,
        height = height,
        width = width,
        num_frames = num_frames,
        seed = seed
    )
    # Remember this parameter combination so check_if_compiled() can tell the
    # user whether a later run with the same settings was already seen.
    _seen_compilations.add((hint_image is None, inference_steps, height, width, num_frames))

    # The context manager closes the buffer even if encoding the frames fails.
    with BytesIO() as buffer:
        images[0].save(
            buffer,
            format = 'webp',
            save_all = True,
            append_images = images[1:],
            loop = 0,
            duration = round(1000 / fps),
            allow_mixed = True
        )
        encoded = base64.b64encode(buffer.getvalue()).decode()
    return 'data:image/webp;base64,' + encoded
|
|
|
def check_if_compiled(image, inference_steps, height, width, num_frames, message):
    """Return *message* unless generate() has already run with these parameters.

    The lookup key is built the same way generate() records it in
    `_seen_compilations`: whether a hint image is absent, followed by the
    int-normalised step count, height, width and frame count.

    Args:
        image: None, or a dict whose 'image' entry is the hint image (or None).
        inference_steps: Number of inference steps.
        height: Frame height.
        width: Frame width.
        num_frames: Number of frames.
        message: Text to return when the combination has not been seen yet.

    Returns:
        '' if the combination is in `_seen_compilations`, otherwise *message*
        as a string.
    """
    hint_image = None if image is None else image.get('image')
    # generate() casts every numeric component to int before storing the key,
    # so do the same here to keep both sides of the lookup in agreement.
    key = (
        hint_image is None,
        int(inference_steps),
        int(height),
        int(width),
        int(num_frames),
    )
    if key in _seen_compilations:
        return ''
    return str(message)
|
|
|
# Optional warm-up: run generate() once without and once with a hint image so
# both variants have been executed before the UI is served.
# NOTE(review): both calls rely on generate()'s default 512x512 size and 24
# frames, while the UI sliders start at 448x448 - presumably the default UI
# settings are not covered by this warm-up; confirm whether that is intended.
if _preheat:
    print('\npreheating the oven')
    # Run 1: text only, no hint image and no mask.
    generate(
        prompt = 'preheating the oven', neg_prompt = '',
        image = dict(image = None, mask = None),
        inference_steps = 20, cfg = 12.0, seed = 0,
    )
    print('Entertaining the guests with sailor songs played on an old piano.')
    # Run 2: same settings, but with a solid black 512x512 RGB hint image.
    dada = generate(
        prompt = 'Entertaining the guests with sailor songs played on an old harmonium.',
        neg_prompt = '',
        image = dict(image = Image.new('RGB', (512, 512), (0, 0, 0)), mask = None),
        inference_steps = 20, cfg = 12.0, seed = 0,
    )
    print('dinner is ready\n')
|
|
|
# UI definition: an intro row, then a row with the input controls on the left
# and the status messages plus the output image on the right.
with gr.Blocks(title = 'Make-A-Video Stable Diffusion JAX', analytics_enabled = False) as demo:
    variant = 'panel'
    with gr.Row():
        with gr.Column():
            intro1 = gr.Markdown("""
# Make-A-Video Stable Diffusion JAX
**Please be patient. The model might have to compile with current parameters.**

This can take up to 5 minutes on the first run, and 2-3 minutes on later runs.
The compilation will be cached and consecutive runs with the same parameters
will be much faster.
""")
        with gr.Column():
            intro2 = gr.Markdown("""
The following parameters require the model to compile
- Number of frames
- Width & Height
- Steps
- Input image vs. no input image
""")

    with gr.Row(variant = variant):
        # Left column: action buttons and every generation parameter.
        with gr.Column(variant = variant):
            with gr.Row():
                cancel_button = gr.Button(value = 'Cancel')
                submit_button = gr.Button(value = 'Make A Video', variant = 'primary')
            prompt_input = gr.Textbox(
                label = 'Prompt',
                value = 'They are dancing in the club while sweat drips from the ceiling.',
                interactive = True
            )
            neg_prompt_input = gr.Textbox(
                label = 'Negative prompt (optional)',
                value = '',
                interactive = True
            )
            inference_steps_input = gr.Slider(
                label = 'Steps',
                minimum = 1,
                maximum = 100,
                value = 20,
                step = 1
            )
            cfg_input = gr.Slider(
                label = 'Guidance scale',
                minimum = 1.0,
                maximum = 20.0,
                step = 0.1,
                value = 15.0,
                interactive = True
            )
            seed_input = gr.Number(
                label = 'Random seed',
                value = 0,
                interactive = True,
                precision = 0
            )
            # generate() and check_if_compiled() read this input as a dict with
            # 'image' and 'mask' entries.
            image_input = gr.Image(
                label = 'Input image (optional)',
                interactive = True,
                image_mode = 'RGB',
                type = 'pil',
                optional = True,
                source = 'upload',
                tool = 'sketch'
            )
            num_frames_input = gr.Slider(
                label = 'Number of frames to generate',
                minimum = 1,
                maximum = 24,
                step = 1,
                value = 24
            )
            width_input = gr.Slider(
                label = 'Width',
                minimum = 64,
                maximum = 512,
                step = 1,
                value = 448
            )
            height_input = gr.Slider(
                label = 'Height',
                minimum = 64,
                maximum = 512,
                step = 1,
                value = 448
            )
            fps_input = gr.Slider(
                label = 'Output FPS',
                minimum = 1,
                maximum = 1000,
                step = 1,
                value = 12
            )
        # Right column: compilation notices and the generated animation.
        with gr.Column(variant = variant):
            will_trigger = gr.Markdown('')
            patience = gr.Markdown('')
            image_output = gr.Image(
                label = 'Output',
                value = 'example.webp',
                interactive = False
            )

    # Inputs that feed check_if_compiled(), in its positional parameter order.
    trigger_inputs = [ image_input, inference_steps_input, height_input, width_input, num_frames_input ]
    trigger_check_fun = partial(check_if_compiled, message = 'Current parameters will trigger compilation.')
    # Refresh the "will trigger compilation" notice whenever one of these sliders changes.
    height_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
    width_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
    num_frames_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
    inference_steps_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
    # Initial notice for the default control values.
    will_trigger.value = trigger_check_fun(image_input.value, inference_steps_input.value, height_input.value, width_input.value, num_frames_input.value)

    # Submit chain: (1) show a patience notice if these parameters have not
    # been seen yet, (2) generate the clip, (3) refresh the compilation notice.
    compile_notice_event = submit_button.click(
        fn = partial(
            check_if_compiled,
            message = 'Please be patient. The model has to be compiled with current parameters.'
        ),
        inputs = trigger_inputs,
        outputs = patience
    )
    generate_event = compile_notice_event.then(
        fn = generate,
        inputs = [
            prompt_input,
            neg_prompt_input,
            image_input,
            inference_steps_input,
            cfg_input,
            seed_input,
            fps_input,
            num_frames_input,
            height_input,
            width_input
        ],
        outputs = image_output,
        # generate() already returns a ready-made data URI, so the output
        # component's own post-processing is skipped.
        postprocess = False
    )
    ev = generate_event.then(
        fn = trigger_check_fun,
        inputs = trigger_inputs,
        outputs = will_trigger
    )
    # A Button component is not callable; cancellation is wired up by
    # registering a click listener whose `cancels` list names the events to
    # stop. Every step of the chain is listed so that a running generate()
    # step is covered, not only the final notice refresh.
    cancel_button.click(
        fn = None,
        inputs = None,
        outputs = None,
        cancels = [ compile_notice_event, generate_event, ev ]
    )
|
|
|
# Enable request queueing (one worker, at most 16 waiting requests), then
# start serving the app.
demo.queue(concurrency_count = 1, max_size = 16).launch()
|
|
|
|