"""Helpers for rendering preview images from a trained model.

``draw_to_directory`` renders from a local work directory, while
``draw_with_repo`` first downloads the checkpoint files of a HuggingFace
repository into a temporary work directory and then renders from it.
"""
import json
import logging
import os
import posixpath
from typing import Optional

from hbutils.system import TemporaryDirectory
from huggingface_hub import hf_hub_url
from tqdm.auto import tqdm

from .draw import _DEFAULT_INFER_MODEL, draw_with_workdir
from ..dataset import save_recommended_tags
from ..utils import get_hf_fs, download_file


def draw_to_directory(workdir: str, export_dir: str, step: int, n_repeats: int = 2,
                      pretrained_model: str = _DEFAULT_INFER_MODEL, clip_skip: int = 2,
                      image_width: int = 512, image_height: int = 768, infer_steps: int = 30,
                      lora_alpha: float = 0.85, sample_method: str = 'DPM++ 2M Karras',
                      model_hash: Optional[str] = None):
    """
    Draw images with the model in ``workdir`` and save them into ``export_dir``.

    For every drawing returned by ``draw_with_workdir``, two files are written:
    ``<name>.png`` (saved with the drawing's ``pnginfo``) and ``<name>_info.txt``
    (containing the drawing's ``preview_info``).

    :param workdir: Work directory passed through to ``draw_with_workdir``.
    :param export_dir: Directory to write the images to; created when missing.
    :param step: Training step to draw with (passed as ``model_steps``).
    :param n_repeats: Initial repeat count; incremented by one after every
        ``RuntimeError`` raised by ``draw_with_workdir``.
    :param pretrained_model: Base model used for inference.
    :param clip_skip: Passed through to ``draw_with_workdir``.
    :param image_width: Passed through as ``width``.
    :param image_height: Passed through as ``height``.
    :param infer_steps: Passed through to ``draw_with_workdir``.
    :param lora_alpha: Passed through to ``draw_with_workdir``.
    :param sample_method: Passed through to ``draw_with_workdir``.
    :param model_hash: Hash of the base model. When falsy, it is looked up in
        ``KNOWN_MODEL_HASHES`` by ``pretrained_model`` (``None`` if unknown).
    """
    # Imported lazily, as in the original code (presumably to avoid a circular
    # import with the publish package -- TODO confirm).
    from ..publish.export import KNOWN_MODEL_HASHES

    model_hash = model_hash or KNOWN_MODEL_HASHES.get(pretrained_model)
    os.makedirs(export_dir, exist_ok=True)

    # Retry with one more repeat each time the drawing raises RuntimeError.
    # NOTE(review): this loop is unbounded; a RuntimeError that does not depend
    # on n_repeats will retry forever. Every failure is logged so that such a
    # situation is at least visible to the operator.
    while True:
        try:
            drawings = draw_with_workdir(
                workdir,
                model_steps=step,
                n_repeats=n_repeats,
                pretrained_model=pretrained_model,
                width=image_width,
                height=image_height,
                infer_steps=infer_steps,
                lora_alpha=lora_alpha,
                clip_skip=clip_skip,
                sample_method=sample_method,
                model_hash=model_hash,
            )
        except RuntimeError as err:
            n_repeats += 1
            logging.warning(f'Drawing failed ({err!r}), retrying with n_repeats={n_repeats} ...')
        else:
            break

    for draw in drawings:
        img_file = os.path.join(export_dir, f'{draw.name}.png')
        draw.image.save(img_file, pnginfo=draw.pnginfo)

        info_file = os.path.join(export_dir, f'{draw.name}_info.txt')
        with open(info_file, 'w', encoding='utf-8') as f:
            print(draw.preview_info, file=f)


def draw_with_repo(repository: str, export_dir: str, step: Optional[int] = None, n_repeats: int = 2,
                   pretrained_model: str = _DEFAULT_INFER_MODEL, clip_skip: int = 2,
                   image_width: int = 512, image_height: int = 768, infer_steps: int = 30,
                   lora_alpha: float = 0.85, sample_method: str = 'DPM++ 2M Karras',
                   model_hash: Optional[str] = None):
    """
    Download a model from a HuggingFace repository and draw images with it.

    The files matching ``<repository>/<step>/raw/*`` are downloaded into the
    ``ckpts`` sub-directory of a temporary work directory, the recommended tags
    are regenerated for that work directory, and ``draw_to_directory`` is then
    called on it.

    :param repository: HuggingFace repository; it must contain a ``meta.json``.
    :param export_dir: Directory to write the images to.
    :param step: Training step to use. When falsy, ``best_step`` from the
        repository's ``meta.json`` is used.
    :param n_repeats: Passed through to ``draw_to_directory``.
    :param pretrained_model: Passed through to ``draw_to_directory``.
    :param clip_skip: Passed through to ``draw_to_directory``.
    :param image_width: Passed through to ``draw_to_directory``.
    :param image_height: Passed through to ``draw_to_directory``.
    :param infer_steps: Passed through to ``draw_to_directory``.
    :param lora_alpha: Passed through to ``draw_to_directory``.
    :param sample_method: Passed through to ``draw_to_directory``.
    :param model_hash: Passed through to ``draw_to_directory``.
    :raises ValueError: If ``<repository>/meta.json`` does not exist.
    """
    from ..publish import find_steps_in_workdir

    hf_fs = get_hf_fs()
    if not hf_fs.exists(f'{repository}/meta.json'):
        raise ValueError(f'Invalid repository or no model found - {repository!r}.')

    logging.info(f'Model repository {repository!r} found.')
    meta = json.loads(hf_fs.read_text(f'{repository}/meta.json'))
    step = step or meta['best_step']
    logging.info(f'Using step {step} ...')

    with TemporaryDirectory() as workdir:
        logging.info('Downloading models ...')
        for f in tqdm(hf_fs.glob(f'{repository}/{step}/raw/*')):
            # Repository paths are '/'-separated on every platform, so the
            # repo-relative filename must be computed with posixpath; the
            # platform-specific os.path.relpath would emit backslashes on
            # Windows and produce an invalid download URL.
            rel_file = posixpath.relpath(f, repository)
            local_file = os.path.join(workdir, 'ckpts', os.path.basename(rel_file))
            # local_file always lives under <workdir>/ckpts, so its parent
            # directory is never empty.
            os.makedirs(os.path.dirname(local_file), exist_ok=True)
            download_file(
                hf_hub_url(repository, filename=rel_file),
                local_file
            )

        logging.info(f'Regenerating tags for {workdir!r} ...')
        pt_name, _ = find_steps_in_workdir(workdir)
        # The last '_'-separated segment of the name is treated as the game
        # name, everything before it as the character name.
        game_name = pt_name.split('_')[-1]
        name = '_'.join(pt_name.split('_')[:-1])

        from gchar.games.dispatch.access import GAME_CHARS
        if game_name in GAME_CHARS:
            ch_cls = GAME_CHARS[game_name]
            ch = ch_cls.get(name)
        else:
            ch = None

        # Fall back to the repository name when no character could be resolved.
        if ch is None:
            source = repository
        else:
            source = ch

        logging.info(f'Regenerate tags for {source!r}, on {workdir!r}.')
        save_recommended_tags(source, name=pt_name, workdir=workdir, ds_size=meta["dataset"]['type'])

        logging.info('Drawing ...')
        draw_to_directory(
            workdir=workdir,
            export_dir=export_dir,
            step=step,
            n_repeats=n_repeats,
            pretrained_model=pretrained_model,
            clip_skip=clip_skip,
            image_width=image_width,
            image_height=image_height,
            infer_steps=infer_steps,
            lora_alpha=lora_alpha,
            sample_method=sample_method,
            model_hash=model_hash,
        )