Spaces:
Runtime error
Runtime error
import logging | |
import os.path | |
import zipfile | |
from contextlib import contextmanager | |
from typing import ContextManager, Tuple, Optional, Union | |
from gchar.games import get_character | |
from gchar.games.base import Character | |
from hbutils.system import TemporaryDirectory, urlsplit | |
from huggingface_hub import hf_hub_url | |
from waifuc.utils import download_file | |
from ..utils import get_hf_fs, get_ch_name | |
@contextmanager
def load_dataset_for_character(source, size: Union[Tuple[int, int], str] = (512, 704)) \
        -> ContextManager[Tuple[Optional[Character], str]]:
    """
    Open a character dataset as a context manager yielding ``(character, directory)``.

    ``source`` is resolved in this order:

    1. An existing local directory: yielded as-is, as ``(None, source)``.
    2. An existing local file: opened as a zip archive and extracted into a
       temporary directory, yielded as ``(None, <temp dir>)``.
    3. Anything else is treated as a remote dataset reference:

       * a ``Character`` instance maps to the repository
         ``AppleHarem/<get_ch_name(source)>``;
       * any other value is passed to ``get_character``; on a match the
         resolved character is used as above, otherwise ``source`` itself
         is used as the repository id.

       The archive ``dataset-<ds_name>.zip`` of that dataset repository is
       downloaded and extracted into a temporary directory, and
       ``(source, <temp dir>)`` is yielded. Note that ``source`` is the
       resolved character when one was found, otherwise the value passed in.

    Temporary directories are created through ``TemporaryDirectory`` context
    managers that stay open for the duration of the ``with`` body.

    :param source: Local directory path, local zip file path, ``Character``
        object, or a value accepted by ``get_character`` / a repository id.
    :param size: Either a ``(width, height)`` tuple, turned into
        ``'<width>x<height>'``, or a string used directly as the dataset name.
    :raises OSError: If ``source`` exists locally but is neither a directory
        nor a regular file.
    :raises TypeError: If ``size`` is neither a tuple nor a string.
    :raises ValueError: If the expected archive does not exist in the remote
        dataset repository.
    """
    # The ``@contextmanager`` decorator is required: this function is a
    # generator, and callers use it in a ``with`` statement as the return
    # annotation promises.
    if isinstance(source, str) and os.path.exists(source):
        if os.path.isdir(source):
            logging.info(f'Dataset directory {source!r} loaded.')
            yield None, source
        elif os.path.isfile(source):
            # The yield sits inside the ``with`` so the extracted files stay
            # available for as long as the caller's ``with`` body runs.
            with zipfile.ZipFile(source, 'r') as zf, TemporaryDirectory() as td:
                zf.extractall(td)
                logging.info(f'Archive dataset {source!r} unzipped to {td!r} and loaded.')
                yield None, td
        else:
            # Exists, but is neither a directory nor a regular file.
            raise OSError(f'Unknown local source - {source!r}.')
    else:
        # Resolve the remote repository id.
        if isinstance(source, Character):
            repo = f'AppleHarem/{get_ch_name(source)}'
        else:
            try_ch = get_character(source)
            if try_ch is None:
                # No known character matched: use the value as a repo id.
                repo = source
            else:
                source = try_ch
                repo = f'AppleHarem/{get_ch_name(source)}'

        hf_fs = get_hf_fs()

        # Translate ``size`` into the dataset name used in the archive filename.
        if isinstance(size, tuple):
            width, height = size
            ds_name = f'{width}x{height}'
        elif isinstance(size, str):
            ds_name = size
        else:
            raise TypeError(f'Unknown dataset type - {size!r}.')

        if hf_fs.exists(f'datasets/{repo}/dataset-{ds_name}.zip'):
            logging.info(f'Online dataset {repo!r} founded.')
            zip_url = hf_hub_url(repo_id=repo, repo_type='dataset', filename=f'dataset-{ds_name}.zip')
            with TemporaryDirectory() as dltmp:
                zip_file = os.path.join(dltmp, 'dataset.zip')
                download_file(zip_url, zip_file, desc=f'{repo}/{urlsplit(zip_url).filename}')

                with zipfile.ZipFile(zip_file, 'r') as zf, TemporaryDirectory() as td:
                    zf.extractall(td)
                    logging.info(f'Online dataset {repo!r} loaded at {td!r}.')
                    yield source, td
        else:
            raise ValueError(f'Remote dataset {repo!r} not found for {source!r}.')