Spaces:
Runtime error
Runtime error
import os.path | |
from functools import partial | |
import click | |
from ditk import logging | |
from gchar.generic import import_generic | |
from gchar.utils import GLOBAL_CONTEXT_SETTINGS | |
from gchar.utils import print_version as _origin_print_version | |
from huggingface_hub import hf_hub_url | |
from tqdm.auto import tqdm | |
from cyberharem.dataset import save_recommended_tags | |
from cyberharem.publish import find_steps_in_workdir | |
from ..utils import get_hf_fs, download_file | |
# Bind this package's name as the first positional argument of gchar's print_version,
# so the resulting callable only needs the remaining arguments (presumably used as a
# click `--version` option callback — confirm at the call site).
print_version = partial(_origin_print_version, 'cyberharem.train')
# NOTE(review): executed for its side effect at import time; presumably registers gchar's
# generic character support before any character lookups happen — confirm against gchar.
import_generic()
@click.group(context_settings={**GLOBAL_CONTEXT_SETTINGS}, help='Utilities for trained cyberharem models.')
@click.option('-v', '--version', is_flag=True, callback=print_version, expose_value=False, is_eager=True,
              help='Show version information.')
def cli():
    """Root click group; sub-commands such as ``download`` attach to it via ``@cli.command``."""
    pass  # pragma: no cover


@cli.command('download', context_settings={**GLOBAL_CONTEXT_SETTINGS},
             help='Download trained checkpoints from a HuggingFace repository.')
@click.option('-r', '--repository', 'repository', type=str, required=True,
              help='HuggingFace repository id to download from, e.g. "user/name".')
@click.option('-w', '--workdir', 'workdir', type=str, default=None,
              help='Local work directory. Defaults to runs/<repository name>.')
@click.option('--no_tags', 'no_tags', is_flag=True, type=bool, default=False,
              help='Skip regenerating recommended tags after downloading.', show_default=True)
def download(repository, workdir, no_tags):
    """
    Download the checkpoints of a trained model and optionally regenerate its tags.

    Every file matching ``<repository>/*/raw/*`` on the HuggingFace filesystem is fetched
    into ``<workdir>/ckpts``. Only each file's basename is kept, so same-named files found
    under different sub-directories overwrite each other.

    Unless ``no_tags`` is set, the checkpoint name returned by ``find_steps_in_workdir`` is
    split at its last ``_`` into a character name and a game name. If the game name is a
    key of gchar's ``GAME_CHARS`` and the character lookup returns a result, tags are
    generated for that character. Otherwise the repository id is used as the tag source.

    :param repository: HuggingFace repository id, e.g. ``user/name``.
    :param workdir: Local work directory; defaults to ``runs/<last segment of repository>``.
    :param no_tags: When true, skip tag regeneration.
    """
    # Remote HuggingFace paths are always '/'-separated. os.path would produce '\\' on
    # Windows and corrupt the filename handed to hf_hub_url.
    import posixpath

    logging.try_init_root(logging.INFO)
    workdir = workdir or os.path.join('runs', repository.split('/')[-1])
    logging.info(f'Downloading models for {workdir!r} ...')

    ckpt_dir = os.path.join(workdir, 'ckpts')
    hf_fs = get_hf_fs()
    remote_files = hf_fs.glob(f'{repository}/*/raw/*')
    if remote_files:
        # Loop-invariant: the target directory is the same for every file.
        os.makedirs(ckpt_dir, exist_ok=True)
    else:
        logging.warning(f'No files matching {repository}/*/raw/* were found.')

    for f in tqdm(remote_files):
        rel_file = posixpath.relpath(f, repository)  # '<dir>/raw/<file>' relative to the repo root
        download_file(
            hf_hub_url(repository, filename=rel_file),
            os.path.join(ckpt_dir, posixpath.basename(rel_file)),
        )

    if not no_tags:
        logging.info(f'Regenerating tags for {workdir!r} ...')
        pt_name, _ = find_steps_in_workdir(workdir)
        # Everything after the last '_' is treated as the game name, the rest as the
        # character name (for a name without '_', name == '' and game_name == pt_name).
        name, _, game_name = pt_name.rpartition('_')

        from gchar.games.dispatch.access import GAME_CHARS
        ch = GAME_CHARS[game_name].get(name) if game_name in GAME_CHARS else None
        # Fall back to the repository id when the character cannot be resolved.
        source = repository if ch is None else ch
        logging.info(f'Regenerate tags for {source!r}, on {workdir!r}.')
        save_recommended_tags(source, name=pt_name, workdir=workdir)

    logging.info('Success!')
# Script entry point: invoke `cli` only when this module is executed directly, not on import.
if __name__ == '__main__':
    cli()