LittleApple_fp16
upload
69a6cef
raw
history blame
2.69 kB
import os.path
from functools import partial
import click
from ditk import logging
from gchar.generic import import_generic
from gchar.utils import GLOBAL_CONTEXT_SETTINGS
from gchar.utils import print_version as _origin_print_version
from huggingface_hub import hf_hub_url
from tqdm.auto import tqdm
from cyberharem.dataset import save_recommended_tags
from cyberharem.publish import find_steps_in_workdir
from ..utils import get_hf_fs, download_file
print_version = partial(_origin_print_version, 'cyberharem.train')
import_generic()
@click.group(context_settings={**GLOBAL_CONTEXT_SETTINGS}, help='Publish trained models')
@click.option('-v', '--version', is_flag=True, callback=print_version, expose_value=False, is_eager=True)
def cli():
pass # pragma: no cover
@cli.command('download', context_settings={**GLOBAL_CONTEXT_SETTINGS}, help='Download trained ckpts from huggingface.')
@click.option('-r', '--repository', 'repository', type=str, required=True,
help='Repository.', show_default=True)
@click.option('-w', '--workdir', 'workdir', type=str, default=None,
help='Work directory', show_default=True)
@click.option('--no-tags', 'no_tags', is_flag=True, type=bool, default=False,
help='Do not generate tags.', show_default=True)
def download(repository, workdir, no_tags):
logging.try_init_root(logging.INFO)
workdir = workdir or os.path.join('runs', repository.split('/')[-1])
logging.info(f'Downloading models for {workdir!r} ...')
hf_fs = get_hf_fs()
for f in tqdm(hf_fs.glob(f'{repository}/*/raw/*')):
rel_file = os.path.relpath(f, repository)
local_file = os.path.join(workdir, 'ckpts', os.path.basename(rel_file))
if os.path.dirname(local_file):
os.makedirs(os.path.dirname(local_file), exist_ok=True)
download_file(
hf_hub_url(repository, filename=rel_file),
local_file
)
if not no_tags:
logging.info(f'Regenerating tags for {workdir!r} ...')
pt_name, _ = find_steps_in_workdir(workdir)
game_name = pt_name.split('_')[-1]
name = '_'.join(pt_name.split('_')[:-1])
from gchar.games.dispatch.access import GAME_CHARS
if game_name in GAME_CHARS:
ch_cls = GAME_CHARS[game_name]
ch = ch_cls.get(name)
else:
ch = None
if ch is None:
source = repository
else:
source = ch
logging.info(f'Regenerate tags for {source!r}, on {workdir!r}.')
save_recommended_tags(source, name=pt_name, workdir=workdir)
logging.info('Success!')
if __name__ == '__main__':
cli()