Spaces:
Runtime error
Runtime error
#!/usr/bin/env python | |
from __future__ import annotations | |
import os | |
import gradio as gr | |
from constants import UploadTarget | |
from inference import InferencePipeline | |
from trainer import Trainer | |
def create_training_demo(trainer: Trainer,
                         pipe: InferencePipeline | None = None,
                         disable_run_button: bool = False) -> gr.Blocks:
    """Build the Gradio UI for training a model on a single video.

    Args:
        trainer: Supplies ``log_file`` (tailed into the "Log" box) and
            ``run`` (bound to the "Start Training" button).
        pipe: Optional inference pipeline; when given, its ``clear`` method
            is also bound to the "Start Training" button.
        disable_run_button: If True, the "Start Training" button is created
            non-interactive.

    Returns:
        The assembled ``gr.Blocks`` app; queueing and launching are left to
        the caller.
    """
    from collections import deque

    # How many trailing log lines the "Log" box shows (and is sized for).
    num_log_lines = 10

    def read_log() -> str:
        """Return the last ``num_log_lines`` lines of ``trainer.log_file``.

        The "Log" box polls this every second, so it must not raise: a log
        file that does not exist yet yields ``''`` and undecodable bytes are
        replaced instead of aborting the read.
        """
        try:
            with open(trainer.log_file, encoding='utf-8',
                      errors='replace') as f:
                # A bounded deque keeps only the tail rather than a list of
                # every line in the file.
                return ''.join(deque(f, maxlen=num_log_lines))
        except FileNotFoundError:
            return ''

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                with gr.Box():
                    gr.Markdown('Training Data')
                    training_video = gr.File(label='Training video')
                    training_prompt = gr.Textbox(
                        label='Training prompt',
                        max_lines=1,
                        placeholder='A man is surfing')
                    gr.Markdown(
                        '\n'
                        '- Upload a video and write a `Training Prompt` '
                        'that describes the video.\n')
            with gr.Column():
                with gr.Box():
                    gr.Markdown('Training Parameters')
                    with gr.Row():
                        base_model = gr.Text(
                            label='Base Model',
                            value='CompVis/stable-diffusion-v1-4',
                            max_lines=1)
                        # Hidden, so the UI always submits the default '512'.
                        resolution = gr.Dropdown(choices=['512', '768'],
                                                 value='512',
                                                 label='Resolution',
                                                 visible=False)
                    # Shown only when the HF_TOKEN environment variable is
                    # unset.
                    hf_token = gr.Text(label='Hugging Face Write Token',
                                       type='password',
                                       visible=os.getenv('HF_TOKEN') is None)
                    with gr.Accordion(label='Advanced options', open=False):
                        num_training_steps = gr.Number(
                            label='Number of Training Steps',
                            value=300,
                            precision=0)
                        learning_rate = gr.Number(label='Learning Rate',
                                                  value=0.000035)
                        gradient_accumulation = gr.Number(
                            label='Number of Gradient Accumulation',
                            value=1,
                            precision=0)
                        seed = gr.Slider(label='Seed',
                                         minimum=0,
                                         maximum=100000,
                                         step=1,
                                         randomize=True,
                                         value=0)
                        fp16 = gr.Checkbox(label='FP16', value=True)
                        use_8bit_adam = gr.Checkbox(label='Use 8bit Adam',
                                                    value=False)
                        checkpointing_steps = gr.Number(
                            label='Checkpointing Steps',
                            value=1000,
                            precision=0)
                        validation_epochs = gr.Number(
                            label='Validation Epochs', value=100, precision=0)
                    gr.Markdown(
                        '\n'
                        '- The base model must be a Stable Diffusion model '
                        'compatible with '
                        '[diffusers](https://github.com/huggingface/diffusers)'
                        ' library.\n'
                        '- Expected time to train a model for 300 steps: '
                        '~20 minutes with T4\n'
                        '- You can check the training status by pressing the '
                        '"Open logs" button if you are running this on your '
                        'Space.\n')
        with gr.Row():
            with gr.Column():
                gr.Markdown('Output Model')
                output_model_name = gr.Text(label='Name of your model',
                                            placeholder='The surfer man',
                                            max_lines=1)
                validation_prompt = gr.Text(
                    label='Validation Prompt',
                    placeholder=(
                        'prompt to test the model, e.g: a dog is surfing'))
            with gr.Column():
                gr.Markdown('Upload Settings')
                with gr.Row():
                    upload_to_hub = gr.Checkbox(label='Upload model to Hub',
                                                value=True)
                    use_private_repo = gr.Checkbox(label='Private', value=True)
                    delete_existing_repo = gr.Checkbox(
                        label='Delete existing repo of the same name',
                        value=False)
                    upload_to = gr.Radio(
                        label='Upload to',
                        choices=[target.value for target in UploadTarget],
                        value=UploadTarget.MODEL_LIBRARY.value)
                # Hidden; interactive only when SPACE_ID is set. Its value is
                # still forwarded to trainer.run below.
                pause_space_after_training = gr.Checkbox(
                    label='Pause this Space after training',
                    value=False,
                    interactive=bool(os.getenv('SPACE_ID')),
                    visible=False)
                run_button = gr.Button('Start Training',
                                       interactive=not disable_run_button)
                with gr.Box():
                    gr.Text(label='Log',
                            value=read_log,
                            lines=num_log_lines,
                            max_lines=num_log_lines,
                            every=1)

        # These are two independent listeners on the same button; nothing
        # here chains them, so no ordering between them is enforced.
        if pipe is not None:
            run_button.click(fn=pipe.clear)
        # Values are passed to trainer.run positionally, in this order.
        run_button.click(
            fn=trainer.run,
            inputs=[
                training_video,
                training_prompt,
                output_model_name,
                # NOTE(review): delete_existing_repo is passed twice (here
                # and again below). This slot presumably expects an
                # "overwrite existing model" flag. Confirm against
                # Trainer.run before changing, since removing it would shift
                # every later argument.
                delete_existing_repo,
                validation_prompt,
                base_model,
                resolution,
                num_training_steps,
                learning_rate,
                gradient_accumulation,
                seed,
                fp16,
                use_8bit_adam,
                checkpointing_steps,
                validation_epochs,
                upload_to_hub,
                use_private_repo,
                delete_existing_repo,
                upload_to,
                pause_space_after_training,
                hf_token,
            ])
    return demo
if __name__ == '__main__':
    # Stand-alone entry point: serve the training UI on its own, with no
    # inference pipeline attached.
    trainer = Trainer()
    demo = create_training_demo(trainer=trainer)
    # Enable the request queue (capped at max_size=1, API access disabled)
    # and start the server.
    demo.queue(max_size=1, api_open=False).launch()