# kohya_ss / library / class_basic_training.py
# Uploaded by Ateras ("Upload folder using huggingface_hub"), commit fe6327d.
# Hugging Face file page: scanned, no virus detected; size 4.47 kB.
import gradio as gr
import os
class BasicTraining:
    """Gradio controls for the basic training parameters.

    Constructing an instance creates the widgets in the current Gradio
    layout context: batch size, epochs, precision, CPU threads, seed, latent
    caching, learning-rate schedule, optimizer, resolution and bucketing.
    It keeps a reference to each widget as an attribute so callers can read
    or wire them up.

    NOTE(review): presumably instantiated inside an active ``gr.Blocks``
    layout; the calling code is not visible in this file.
    """

    # Entries of the "LR Scheduler" dropdown, in display order.
    _LR_SCHEDULER_CHOICES = (
        'adafactor',
        'constant',
        'constant_with_warmup',
        'cosine',
        'cosine_with_restarts',
        'linear',
        'polynomial',
    )

    # Entries of the "Optimizer" dropdown, in display order.
    _OPTIMIZER_CHOICES = (
        'AdamW',
        'AdamW8bit',
        'Adafactor',
        'DAdaptation',
        'DAdaptAdaGrad',
        'DAdaptAdam',
        'DAdaptAdan',
        'DAdaptAdanIP',
        'DAdaptAdamPreprint',
        'DAdaptLion',
        'DAdaptSGD',
        'Lion',
        'Lion8bit',
        'Prodigy',
        'SGDNesterov',
        'SGDNesterov8bit',
    )

    @staticmethod
    def _cpu_thread_bounds(cpu_count, preferred: int = 2):
        """Return ``(maximum, value)`` for the CPU-thread slider.

        ``os.cpu_count()`` returns ``None`` when the count cannot be
        determined, so fall back to a single thread. The default value is
        clamped so it never exceeds the slider's maximum (for example on a
        single-CPU machine, where the old hard-coded 2 would be out of range).
        """
        maximum = max(1, cpu_count or 1)
        return maximum, min(preferred, maximum)

    def __init__(
        self,
        learning_rate_value='1e-6',
        lr_scheduler_value='constant',
        lr_warmup_value='0',
        finetuning: bool = False,
    ):
        """Create the basic-training widgets.

        Args:
            learning_rate_value: Initial value of the "Learning rate" field.
            lr_scheduler_value: Initially selected "LR Scheduler" entry.
            lr_warmup_value: Initial value of the "LR warmup (% of steps)"
                slider (range 0-100).
            finetuning: When True, the row holding max resolution, stop text
                encoder training and bucket settings is created hidden.
        """
        self.learning_rate_value = learning_rate_value
        self.lr_scheduler_value = lr_scheduler_value
        self.lr_warmup_value = lr_warmup_value
        self.finetuning = finetuning

        max_threads, default_threads = self._cpu_thread_bounds(os.cpu_count())

        # Batch size, epoch count, checkpoint cadence and caption extension.
        with gr.Row():
            self.train_batch_size = gr.Slider(
                minimum=1,
                maximum=64,
                label='Train batch size',
                value=1,
                step=1,
            )
            self.epoch = gr.Number(label='Epoch', value=1, precision=0)
            self.save_every_n_epochs = gr.Number(
                label='Save every N epochs', value=1, precision=0
            )
            self.caption_extension = gr.Textbox(
                label='Caption Extension',
                placeholder='(Optional) Extension for caption files. default: .caption',
            )

        # Precision, CPU threads, seed and latent caching.
        with gr.Row():
            self.mixed_precision = gr.Dropdown(
                label='Mixed precision',
                choices=['no', 'fp16', 'bf16'],
                value='fp16',
            )
            self.save_precision = gr.Dropdown(
                label='Save precision',
                choices=['float', 'fp16', 'bf16'],
                value='fp16',
            )
            self.num_cpu_threads_per_process = gr.Slider(
                minimum=1,
                maximum=max_threads,
                step=1,
                label='Number of CPU threads per process',
                value=default_threads,
            )
            self.seed = gr.Textbox(
                label='Seed', placeholder='(Optional) eg:1234'
            )
            self.cache_latents = gr.Checkbox(label='Cache latents', value=True)
            self.cache_latents_to_disk = gr.Checkbox(
                label='Cache latents to disk', value=False
            )

        # Learning-rate schedule and optimizer selection.
        with gr.Row():
            self.learning_rate = gr.Number(
                label='Learning rate', value=learning_rate_value
            )
            self.lr_scheduler = gr.Dropdown(
                label='LR Scheduler',
                choices=list(self._LR_SCHEDULER_CHOICES),
                value=lr_scheduler_value,
            )
            self.lr_warmup = gr.Slider(
                label='LR warmup (% of steps)',
                value=lr_warmup_value,
                minimum=0,
                maximum=100,
                step=1,
            )
            self.optimizer = gr.Dropdown(
                label='Optimizer',
                choices=list(self._OPTIMIZER_CHOICES),
                value='AdamW8bit',
                interactive=True,
            )

        # Free-form optimizer arguments.
        with gr.Row():
            self.optimizer_args = gr.Textbox(
                label='Optimizer extra arguments',
                placeholder='(Optional) eg: relative_step=True scale_parameter=True warmup_init=True',
            )

        # Hidden when finetuning (see `visible=not finetuning`).
        with gr.Row(visible=not finetuning):
            self.max_resolution = gr.Textbox(
                label='Max resolution',
                value='512,512',
                placeholder='512,512',
            )
            self.stop_text_encoder_training = gr.Slider(
                minimum=-1,
                maximum=100,
                value=0,
                step=1,
                label='Stop text encoder training',
            )
            self.enable_bucket = gr.Checkbox(label='Enable buckets', value=True)