Tune-A-Video-Training-UI / app_training.py
hysts's picture
hysts HF staff
Migrate from yapf to black
ecfdc8b
#!/usr/bin/env python
from __future__ import annotations
import os
import gradio as gr
from constants import UploadTarget
from inference import InferencePipeline
from trainer import Trainer
def create_training_demo(
trainer: Trainer, pipe: InferencePipeline | None = None, disable_run_button: bool = False
) -> gr.Blocks:
def read_log() -> str:
with open(trainer.log_file) as f:
lines = f.readlines()
return "".join(lines[-10:])
with gr.Blocks() as demo:
with gr.Row():
with gr.Column():
with gr.Box():
gr.Markdown("Training Data")
training_video = gr.File(label="Training video")
training_prompt = gr.Textbox(label="Training prompt", max_lines=1, placeholder="A man is surfing")
gr.Markdown(
"""
- Upload a video and write a `Training Prompt` that describes the video.
"""
)
with gr.Column():
with gr.Box():
gr.Markdown("Training Parameters")
with gr.Row():
base_model = gr.Text(label="Base Model", value="CompVis/stable-diffusion-v1-4", max_lines=1)
resolution = gr.Dropdown(
choices=["512", "768"], value="512", label="Resolution", visible=False
)
hf_token = gr.Text(
label="Hugging Face Write Token", type="password", visible=os.getenv("HF_TOKEN") is None
)
with gr.Accordion(label="Advanced options", open=False):
num_training_steps = gr.Number(label="Number of Training Steps", value=300, precision=0)
learning_rate = gr.Number(label="Learning Rate", value=0.000035)
gradient_accumulation = gr.Number(
label="Number of Gradient Accumulation", value=1, precision=0
)
seed = gr.Slider(label="Seed", minimum=0, maximum=100000, step=1, randomize=True, value=0)
fp16 = gr.Checkbox(label="FP16", value=True)
use_8bit_adam = gr.Checkbox(label="Use 8bit Adam", value=False)
checkpointing_steps = gr.Number(label="Checkpointing Steps", value=1000, precision=0)
validation_epochs = gr.Number(label="Validation Epochs", value=100, precision=0)
gr.Markdown(
"""
- The base model must be a Stable Diffusion model compatible with [diffusers](https://github.com/huggingface/diffusers) library.
- Expected time to train a model for 300 steps: ~20 minutes with T4
- You can check the training status by pressing the "Open logs" button if you are running this on your Space.
"""
)
with gr.Row():
with gr.Column():
gr.Markdown("Output Model")
output_model_name = gr.Text(label="Name of your model", placeholder="The surfer man", max_lines=1)
validation_prompt = gr.Text(
label="Validation Prompt", placeholder="prompt to test the model, e.g: a dog is surfing"
)
with gr.Column():
gr.Markdown("Upload Settings")
with gr.Row():
upload_to_hub = gr.Checkbox(label="Upload model to Hub", value=True)
use_private_repo = gr.Checkbox(label="Private", value=True)
delete_existing_repo = gr.Checkbox(label="Delete existing repo of the same name", value=False)
upload_to = gr.Radio(
label="Upload to",
choices=[_.value for _ in UploadTarget],
value=UploadTarget.MODEL_LIBRARY.value,
)
pause_space_after_training = gr.Checkbox(
label="Pause this Space after training",
value=False,
interactive=bool(os.getenv("SPACE_ID")),
visible=False,
)
run_button = gr.Button("Start Training", interactive=not disable_run_button)
with gr.Box():
gr.Text(label="Log", value=read_log, lines=10, max_lines=10, every=1)
if pipe is not None:
run_button.click(fn=pipe.clear)
run_button.click(
fn=trainer.run,
inputs=[
training_video,
training_prompt,
output_model_name,
delete_existing_repo,
validation_prompt,
base_model,
resolution,
num_training_steps,
learning_rate,
gradient_accumulation,
seed,
fp16,
use_8bit_adam,
checkpointing_steps,
validation_epochs,
upload_to_hub,
use_private_repo,
delete_existing_repo,
upload_to,
pause_space_after_training,
hf_token,
],
)
return demo
if __name__ == "__main__":
trainer = Trainer()
demo = create_training_demo(trainer)
demo.queue(api_open=False, max_size=1).launch()