File size: 2,686 Bytes
69a6cef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os.path
from functools import partial

import click
from ditk import logging
from gchar.generic import import_generic
from gchar.utils import GLOBAL_CONTEXT_SETTINGS
from gchar.utils import print_version as _origin_print_version
from huggingface_hub import hf_hub_url
from tqdm.auto import tqdm

from cyberharem.dataset import save_recommended_tags
from cyberharem.publish import find_steps_in_workdir
from ..utils import get_hf_fs, download_file

# Eager callback for the ``-v/--version`` flag on ``cli`` below, bound via
# functools.partial to the 'cyberharem.train' label.
# NOTE(review): this module's group is about publishing/downloading models,
# yet reports 'cyberharem.train' — confirm this label is intended.
print_version = partial(_origin_print_version, 'cyberharem.train')

# Runs once at import time, before the CLI group is defined.
# NOTE(review): presumably registers gchar's generic character classes so the
# GAME_CHARS lookup in ``download`` can resolve characters — verify in gchar.
import_generic()


@click.group(
    help='Publish trained models',
    context_settings=dict(GLOBAL_CONTEXT_SETTINGS),
)
@click.option('-v', '--version', is_flag=True, is_eager=True,
              expose_value=False, callback=print_version)
def cli():
    """Root command group; subcommands attach themselves via ``@cli.command``."""
    pass  # pragma: no cover


@cli.command('download', context_settings={**GLOBAL_CONTEXT_SETTINGS}, help='Download trained ckpts from huggingface.')
@click.option('-r', '--repository', 'repository', type=str, required=True,
              help='Repository.', show_default=True)
@click.option('-w', '--workdir', 'workdir', type=str, default=None,
              help='Work directory', show_default=True)
@click.option('--no-tags', 'no_tags', is_flag=True, type=bool, default=False,
              help='Do not generate tags.', show_default=True)
def download(repository, workdir, no_tags):
    """
    Download the raw checkpoints of ``repository`` into ``<workdir>/ckpts``.

    Every Hub file matching ``<repository>/*/raw/*`` is fetched and stored by
    its basename in ``<workdir>/ckpts``. Unless ``no_tags`` is set, recommended
    tags are then regenerated for the work directory.

    :param repository: Hugging Face repository id, e.g. ``owner/name``.
    :param workdir: Local work directory; defaults to ``runs/<last segment of repository>``.
    :param no_tags: When ``True``, skip tag regeneration.
    """
    # NOTE: click takes the command help from ``help=`` above, so this
    # docstring does not change the ``--help`` output.
    import posixpath

    logging.try_init_root(logging.INFO)
    workdir = workdir or os.path.join('runs', repository.split('/')[-1])
    ckpt_dir = os.path.join(workdir, 'ckpts')

    logging.info(f'Downloading models for {workdir!r} ...')
    hf_fs = get_hf_fs()
    remote_files = hf_fs.glob(f'{repository}/*/raw/*')
    if remote_files:
        # Loop-invariant: every file lands in the same flat directory.
        os.makedirs(ckpt_dir, exist_ok=True)
    else:
        logging.warning(f'No raw checkpoints found in {repository!r}.')

    for f in tqdm(remote_files):
        # Hub filenames are always '/'-separated; os.path.relpath would use
        # backslashes on Windows and produce a broken download URL.
        rel_file = posixpath.relpath(f, repository)
        # NOTE(review): files from different step directories are flattened by
        # basename, so identically-named files would overwrite each other —
        # presumably raw ckpt names embed the step; confirm.
        local_file = os.path.join(ckpt_dir, posixpath.basename(rel_file))
        download_file(hf_hub_url(repository, filename=rel_file), local_file)

    if not no_tags:
        logging.info(f'Regenerating tags for {workdir!r} ...')
        pt_name, _ = find_steps_in_workdir(workdir)
        # The last '_'-separated token is treated as the game name and the rest
        # as the character name; with no '_' the name is '' and the game is the
        # whole string.
        name, _, game_name = pt_name.rpartition('_')

        from gchar.games.dispatch.access import GAME_CHARS
        ch = GAME_CHARS[game_name].get(name) if game_name in GAME_CHARS else None
        # Fall back to the repository id when no character can be resolved.
        source = repository if ch is None else ch

        logging.info(f'Regenerate tags for {source!r}, on {workdir!r}.')
        save_recommended_tags(source, name=pt_name, workdir=workdir)
        logging.info('Success!')


if __name__ == '__main__':
    # Run the click group when this module is executed directly.
    cli()