#!/usr/bin/env python
from __future__ import annotations
import os
import gradio as gr
# from inference import InferencePipeline
# from FateZero import test_fatezero
from inference_fatezero import merge_config_then_run
# class InferenceUtil:
# def __init__(self, hf_token: str | None):
# self.hf_token = hf_token
# def load_model_info(self, model_id: str) -> tuple[str, str]:
# # todo FIXME
# try:
# card = InferencePipeline.get_model_card(model_id, self.hf_token)
# except Exception:
# return '', ''
# base_model = getattr(card.data, 'base_model', '')
# training_prompt = getattr(card.data, 'training_prompt', '')
# return base_model, training_prompt
# TITLE = '# [FateZero](http://fate-zero-edit.github.io/)'
# Optional Hugging Face access token taken from the environment (None when
# the variable is unset). NOTE(review): no live code in this file reads it.
HF_TOKEN = os.environ.get('HF_TOKEN')

# Pipeline object built once at import time; its `run` method is the
# callback used by the examples, the prompt submit and the Generate button.
pipe = merge_config_then_run()
with gr.Blocks(css='style.css') as demo:
    # Page header: project title and one-line summary.
    gr.HTML(
        """
FateZero : Fusing Attentions for Zero-shot Text-based Video Editing
FateZero is a first zero-shot framework for text-driven video editing via pretrained diffusion models without training.
""")
    # Hint for users who want to avoid the shared queue.
    gr.HTML("""
For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
""")

    with gr.Row():
        # Left column: source video, temporal/spatial cropping, data and
        # model selection, prompts and sampler settings.
        with gr.Column():
            with gr.Accordion('Input Video', open=True):
                user_input_video = gr.File(label='Input Source Video')

            with gr.Accordion('Temporal Crop offset and Sampling Stride', open=False):
                # minimum=1 (was 0): zero frames or a zero sampling stride
                # cannot yield a clip. NOTE(review): assumes these are used as
                # a frame count and a frame-index step downstream — confirm.
                n_sample_frame = gr.Slider(label='Number of Frames in Video',
                                           minimum=1,
                                           maximum=32,
                                           step=1,
                                           value=8)
                stride = gr.Slider(label='Temporal sampling stride in Video',
                                   minimum=1,
                                   maximum=20,
                                   step=1,
                                   value=1)
                start_sample_frame = gr.Number(label='Start frame in the video',
                                               value=0,
                                               precision=0)

            with gr.Accordion('Spatial Crop offset', open=False):
                left_crop = gr.Number(label='Left crop',
                                      value=0,
                                      precision=0)
                right_crop = gr.Number(label='Right crop',
                                       value=0,
                                       precision=0)
                top_crop = gr.Number(label='Top crop',
                                     value=0,
                                     precision=0)
                bottom_crop = gr.Number(label='Bottom crop',
                                        value=0,
                                        precision=0)

            # Crop offsets in (left, right, top, bottom) order.
            offset_list = [
                left_crop,
                right_crop,
                top_crop,
                bottom_crop,
            ]
            # Frame-sampling and crop controls; appended to the end of the
            # pipeline inputs below.
            ImageSequenceDataset_list = [
                start_sample_frame,
                n_sample_frame,
                stride,
            ] + offset_list

            data_path = gr.Dropdown(
                label='provided data path',
                choices=[
                    'FateZero/data/teaser_car-turn',
                    'FateZero/data/style/sunflower',
                    # add shape editing ckpt here
                ],
                value='FateZero/data/teaser_car-turn')
            model_id = gr.Dropdown(
                label='Model ID',
                choices=[
                    'CompVis/stable-diffusion-v1-4',
                    # add shape editing ckpt here
                ],
                value='CompVis/stable-diffusion-v1-4')

            with gr.Accordion('Text Prompt', open=True):
                source_prompt = gr.Textbox(label='Source Prompt',
                                           info='A good prompt describes each frame and most objects in video. Especially, it has the object or attribute that we want to edit or preserve.',
                                           max_lines=1,
                                           placeholder='Example: "a silver jeep driving down a curvy road in the countryside"',
                                           value='a silver jeep driving down a curvy road in the countryside')
                target_prompt = gr.Textbox(label='Target Prompt',
                                           info='A reasonable composition of video may achieve better results(e.g., "sunflower" video with "Van Gogh" prompt is better than "sunflower" with "Monet")',
                                           max_lines=1,
                                           placeholder='Example: "watercolor painting of a silver jeep driving down a curvy road in the countryside"',
                                           value='watercolor painting of a silver jeep driving down a curvy road in the countryside')

            with gr.Accordion('DDIM Parameters', open=True):
                # minimum=1 (was 0): at least one sampling step is needed.
                num_steps = gr.Slider(label='Number of Steps',
                                      info='larger value has better editing capacity, but takes more time and memory',
                                      minimum=1,
                                      maximum=50,
                                      step=1,
                                      value=10)
                guidance_scale = gr.Slider(label='CFG Scale',
                                           minimum=0,
                                           maximum=50,
                                           step=0.1,
                                           value=7.5)

            run_button = gr.Button('Generate')

        # Right column: output video and attention-fusion controls.
        with gr.Column():
            result = gr.Video(label='Result')
            result.style(height=512, width=512)
            with gr.Accordion('FateZero Parameters for attention fusing', open=True):
                cross_replace_steps = gr.Slider(label='cross-attention replace steps',
                                                info='More steps, replace more cross attention to preserve semantic layout.',
                                                minimum=0.0,
                                                maximum=1.0,
                                                step=0.1,
                                                value=0.7)
                self_replace_steps = gr.Slider(label='self-attention replace steps',
                                               info='More steps, replace more spatial-temporal self-attention to preserve geometry and motion.',
                                               minimum=0.0,
                                               maximum=1.0,
                                               step=0.1,
                                               value=0.7)
                enhance_words = gr.Textbox(label='words to be enhanced',
                                           info='Amplify the target-words cross attention',
                                           max_lines=1,
                                           placeholder='Example: "watercolor "',
                                           value='watercolor')
                enhance_words_value = gr.Slider(label='Amplify the target cross-attention',
                                                info='larger value, more elements of target words',
                                                minimum=0.0,
                                                maximum=20.0,
                                                step=1,
                                                value=10)

    # Positional arguments for pipe.run. This single list is shared by the
    # examples and both event listeners so the call sites cannot drift apart.
    # NOTE(review): the order must match pipe.run's parameters — confirm
    # against inference_fatezero.
    inputs = [
        model_id,
        data_path,
        source_prompt,
        target_prompt,
        cross_replace_steps,
        self_replace_steps,
        enhance_words,
        enhance_words_value,
        num_steps,
        guidance_scale,
        user_input_video,
        *ImageSequenceDataset_list,
    ]

    with gr.Row():
        from example import style_example
        # Each example row supplies values positionally for `inputs`.
        examples = style_example
        # Examples are cached only when the SYSTEM env var is 'spaces'.
        gr.Examples(examples=examples,
                    inputs=inputs,
                    outputs=result,
                    fn=pipe.run,
                    cache_examples=os.getenv('SYSTEM') == 'spaces')

    # Submitting the target prompt and clicking Generate both run the
    # pipeline with the same inputs.
    target_prompt.submit(fn=pipe.run, inputs=inputs, outputs=result)
    run_button.click(fn=pipe.run, inputs=inputs, outputs=result)

demo.queue().launch()