Spaces:
Runtime error
Runtime error
File size: 5,268 Bytes
69a6cef |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
import glob
import json
import logging
import math
import os.path
from typing import Optional, Union
from gchar.games.base import Character
from hbutils.string import plural_word
from hbutils.system import TemporaryDirectory
from hcpdiff.train_ac import Trainer
from hcpdiff.train_ac_single import TrainerSingleCard
from hcpdiff.utils import load_config_with_cli
from .embedding import create_embedding, _DEFAULT_TRAIN_MODEL
from ..dataset import load_dataset_for_character, save_recommended_tags
from ..utils import data_to_cli_args, get_ch_name
# Default hcpdiff YAML training config, used as the ``cfg_file`` default of ``train_plora``.
# NOTE(review): relative path — presumably resolved against the current working directory
# (or hcpdiff's config lookup); confirm before running from another directory.
_DEFAULT_TRAIN_CFG = 'cfgs/train/examples/lora_anime_character.yaml'
def _min_training_steps(dataset_size: int, unit: int = 20):
steps = 4000.9 + (720.9319 - 4000.9) / (1 + (dataset_size / 297.2281) ** 0.6543184)
return int(round(steps / unit)) * unit
def _plan_training_schedule(
        dataset_size: int, epochs: int, min_steps: Optional[int],
        save_for_times: int, no_min_steps: bool, unit: int = 20,
):
    """
    Derive ``(steps, epochs, save_per_steps)`` for a dataset of ``dataset_size`` images.

    The requested ``epochs * dataset_size`` steps are first raised to
    :func:`_min_training_steps` (unless ``no_min_steps``). They are then raised to
    ``min_steps`` when it is given. The checkpoint interval is sized so that about
    ``save_for_times`` checkpoints are written. It is rounded up to a multiple of ``unit``
    and is never below ``unit``. The total step count is rounded up to a whole number of
    checkpoint intervals.

    :param dataset_size: Number of training images; must be positive (it is a divisor).
    :param epochs: Requested number of epochs.
    :param min_steps: Optional explicit lower bound on the step count.
    :param save_for_times: Approximate number of checkpoints to save.
    :param no_min_steps: If true, skip the dataset-size-based lower bound.
    :param unit: Rounding granularity for the step lower bound and checkpoint interval.
    :return: ``(steps, epochs, save_per_steps)``; ``epochs`` is recomputed as
        ``ceil(steps / dataset_size)``.
    """
    actual_steps = epochs * dataset_size
    if not no_min_steps:
        actual_steps = max(actual_steps, _min_training_steps(dataset_size, unit))
    if min_steps is not None:
        actual_steps = max(actual_steps, min_steps)

    save_per_steps = max(int(math.ceil(actual_steps / save_for_times / unit) * unit), unit)
    steps = int(math.ceil(actual_steps / save_per_steps) * save_per_steps)
    # NOTE(review): batch_size is not part of this estimate, i.e. one "epoch" here is
    # dataset_size steps. The value is only logged, but confirm it matches how hcpdiff
    # counts train_steps.
    actual_epochs = int(math.ceil(steps / dataset_size))
    return steps, actual_epochs, save_per_steps


def _make_bucket_config(use_ratio: bool) -> dict:
    """
    Build the ``data.dataset1.bucket`` override passed to hcpdiff.

    :param use_ratio: If true, use ``RatioBucket.from_files`` with 5 buckets; otherwise use a
        single ``SizeBucket.from_files`` bucket.
    :return: A config dict whose ``_target_`` names the hcpdiff bucket factory.
    """
    if use_ratio:
        return {
            '_target_': 'hcpdiff.data.bucket.RatioBucket.from_files',
            # presumably a config resolver evaluating 512 * 512 — confirm against hcpdiff
            'target_area': '${times:512,512}',
            'num_bucket': 5,
        }
    return {
        '_target_': 'hcpdiff.data.bucket.SizeBucket.from_files',
        # NOTE(review): '---' is passed through verbatim; presumably overrides/clears the
        # target_area inherited from cfg_file — confirm.
        'target_area': '---',
        'num_bucket': 1,
    }


def train_plora(
        source: Union[str, Character], name: Optional[str] = None,
        epochs: int = 13, min_steps: Optional[int] = None,
        save_for_times: int = 15, no_min_steps: bool = False,
        batch_size: int = 4, pretrained_model: str = _DEFAULT_TRAIN_MODEL,
        workdir: Optional[str] = None, emb_n_words: int = 4, emb_init_text: str = '*[0.017, 1]',
        unet_rank: float = 8, text_encoder_rank: float = 4,
        cfg_file: str = _DEFAULT_TRAIN_CFG, single_card: bool = True,
        dataset_type: str = 'stage3-1200', use_ratio: bool = True,
):
    """
    Train a LoRA together with a newly created text embedding named ``name`` using hcpdiff.

    Steps, in order:

    1. Load the dataset via :func:`load_dataset_for_character`.
    2. Derive a step/checkpoint schedule from the number of ``*.png`` images.
    3. Write ``meta.json`` and the recommended tags into ``workdir``.
    4. Create the embedding in a temporary directory.
    5. Run the hcpdiff trainer with command-line overrides applied on top of ``cfg_file``.

    :param source: Character (or custom dataset source) passed to
        :func:`load_dataset_for_character`.
    :param name: Embedding / character name. Required when ``source`` does not resolve to a
        character; otherwise defaults to :func:`get_ch_name` of that character.
    :param epochs: Requested epochs; the step count starts from ``epochs * dataset_size``.
    :param min_steps: Optional explicit lower bound on the number of training steps.
    :param save_for_times: Approximate number of checkpoints to save.
    :param no_min_steps: If true, skip the dataset-size-based step lower bound.
    :param batch_size: Batch size for ``data.dataset1``.
    :param pretrained_model: Base model used both for embedding creation and for training.
    :param workdir: Output directory; defaults to ``runs/<name>``.
    :param emb_n_words: Passed to :func:`create_embedding`.
    :param emb_init_text: Passed to :func:`create_embedding`.
    :param unet_rank: Passed to the config as ``unet_rank``.
    :param text_encoder_rank: Passed to the config as ``text_encoder_rank``.
    :param cfg_file: hcpdiff YAML config to load.
    :param single_card: Use :class:`TrainerSingleCard` if true, else :class:`Trainer`.
    :param dataset_type: Dataset variant to load; also recorded in ``meta.json``.
    :param use_ratio: Use aspect-ratio bucketing if true, else a single size bucket.
    :raises ValueError: If ``name`` is missing for a custom source, or if the dataset
        directory contains no ``*.png`` images.
    """
    with load_dataset_for_character(source, dataset_type) as (ch, ds_dir):
        if ch is None:
            if name is None:
                raise ValueError(f'Name should be specified when using custom source - {source!r}.')
        else:
            name = name or get_ch_name(ch)

        # NOTE(review): only *.png files are counted; images in other formats are ignored here.
        dataset_size = len(glob.glob(os.path.join(ds_dir, '*.png')))
        logging.info(f'{plural_word(dataset_size, "image")} found in dataset.')
        if not dataset_size:
            # Fail early with a clear message; the schedule divides by dataset_size.
            raise ValueError(f'No images found in dataset directory {ds_dir!r}.')

        steps, epochs, save_per_steps = _plan_training_schedule(
            dataset_size, epochs, min_steps, save_for_times, no_min_steps,
        )
        logging.info(f'Training for {plural_word(steps, "step")}, {plural_word(epochs, "epoch")}, '
                     f'save per {plural_word(save_per_steps, "step")} ...')

        workdir = workdir or os.path.join('runs', name)
        os.makedirs(workdir, exist_ok=True)
        save_recommended_tags(ds_dir, name, workdir)
        with open(os.path.join(workdir, 'meta.json'), 'w', encoding='utf-8') as f:
            json.dump({
                'dataset': {
                    'size': dataset_size,
                    'type': dataset_type,
                },
            }, f, indent=4, sort_keys=True, ensure_ascii=False)

        with TemporaryDirectory() as embs_dir:
            logging.info(f'Creating embeddings {name!r} at {embs_dir!r}, '
                         f'n_words: {emb_n_words!r}, init_text: {emb_init_text!r}, '
                         f'pretrained_model: {pretrained_model!r}.')
            create_embedding(
                name, emb_n_words, emb_init_text,
                replace=True,
                pretrained_model=pretrained_model,
                embs_dir=embs_dir,
            )

            cli_args = data_to_cli_args({
                'train': {
                    'train_steps': steps,
                    'save_step': save_per_steps,
                    'scheduler': {
                        'num_training_steps': steps,
                    }
                },
                'model': {
                    'pretrained_model_name_or_path': pretrained_model,
                },
                'character_name': name,
                'dataset_dir': ds_dir,
                'exp_dir': workdir,
                'unet_rank': unet_rank,
                'text_encoder_rank': text_encoder_rank,
                'tokenizer_pt': {
                    'emb_dir': embs_dir,
                },
                'data': {
                    'dataset1': {
                        'batch_size': batch_size,
                        'bucket': _make_bucket_config(use_ratio),
                    },
                },
            })
            # cfg_file is given directly, so cli_args presumably must not carry a --cfg option.
            conf = load_config_with_cli(cfg_file, args_list=cli_args)
            logging.info(f'Training with {cfg_file!r}, args: {cli_args!r} ...')

            if single_card:
                logging.info('Training with single card ...')
                trainer = TrainerSingleCard(conf)
            else:
                logging.info('Training with non-single cards ...')
                trainer = Trainer(conf)
            # Train inside this context: the config points tokenizer_pt.emb_dir at embs_dir.
            trainer.train()
|