Spaces:
Runtime error
Runtime error
File size: 2,500 Bytes
69a6cef |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
import logging
import os.path
import zipfile
from contextlib import contextmanager
from typing import ContextManager, Tuple, Optional, Union
from gchar.games import get_character
from gchar.games.base import Character
from hbutils.system import TemporaryDirectory, urlsplit
from huggingface_hub import hf_hub_url
from waifuc.utils import download_file
from ..utils import get_hf_fs, get_ch_name
def _resolve_remote_source(source):
    """Return ``(source, repo)`` for an online dataset lookup.

    A ``Character`` instance, or a value that ``get_character`` resolves to
    one, maps to the ``AppleHarem/<get_ch_name(character)>`` repository and
    the character object is returned as the source. Anything else is used
    verbatim as the repository id.
    """
    if isinstance(source, Character):
        return source, f'AppleHarem/{get_ch_name(source)}'

    character = get_character(source)
    if character is None:
        # Not resolvable to a character: treat the value itself as a repo id.
        return source, source
    return character, f'AppleHarem/{get_ch_name(character)}'


def _dataset_name(size):
    """Return the dataset name used in the ``dataset-<name>.zip`` filename.

    :raises TypeError: if *size* is neither a tuple nor a string.
    """
    if isinstance(size, tuple):
        width, height = size
        return f'{width}x{height}'
    if isinstance(size, str):
        return size
    raise TypeError(f'Unknown dataset type - {size!r}.')


@contextmanager
def load_dataset_for_character(source, size: Union[Tuple[int, int], str] = (512, 704)) \
        -> ContextManager[Tuple[Optional[Character], str]]:
    """Provide a dataset directory for *source* for the duration of a ``with`` block.

    *source* is interpreted in this order:

    1. A string naming an existing local directory: yielded as-is.
    2. A string naming an existing local file: opened as a zip archive and
       extracted into a temporary directory.
    3. Anything else: resolved to a Hugging Face dataset repository (see
       ``_resolve_remote_source``), from which ``dataset-<name>.zip`` is
       downloaded and extracted into a temporary directory.

    Archives are closed (and the downloaded zip discarded) before control is
    handed to the caller, so only the extracted directory is kept around
    while the ``with`` body runs.

    :param source: Local path, ``Character`` instance, character identifier
        accepted by ``get_character``, or dataset repository id.
    :param size: ``(width, height)`` tuple, rendered as ``'{width}x{height}'``,
        or a string used verbatim as the dataset name. Only consulted for
        online datasets.
    :return: Context manager yielding ``(character, directory)``; ``character``
        is ``None`` for local sources, otherwise the resolved ``Character``
        (or the original *source* when it was used as a repository id).
    :raises OSError: if *source* exists locally but is neither a directory
        nor a regular file.
    :raises TypeError: if *size* is neither a tuple nor a string.
    :raises ValueError: if the expected archive does not exist in the
        remote repository.
    """
    if isinstance(source, str) and os.path.exists(source):
        if os.path.isdir(source):
            logging.info(f'Dataset directory {source!r} loaded.')
            yield None, source
        elif os.path.isfile(source):
            with TemporaryDirectory() as td:
                # Close the archive before yielding; only ``td`` must outlive it.
                with zipfile.ZipFile(source, 'r') as zf:
                    zf.extractall(td)
                logging.info(f'Archive dataset {source!r} unzipped to {td!r} and loaded.')
                yield None, td
        else:
            raise OSError(f'Unknown local source - {source!r}.')
        return

    source, repo = _resolve_remote_source(source)
    hf_fs = get_hf_fs()
    ds_name = _dataset_name(size)
    archive_name = f'dataset-{ds_name}.zip'

    if not hf_fs.exists(f'datasets/{repo}/{archive_name}'):
        raise ValueError(f'Remote dataset {repo!r} not found for {source!r}.')

    logging.info(f'Online dataset {repo!r} found.')
    zip_url = hf_hub_url(repo_id=repo, repo_type='dataset', filename=archive_name)
    with TemporaryDirectory() as td:
        # The download directory is scoped tightly so the zip file is removed
        # as soon as its contents have been extracted.
        with TemporaryDirectory() as dltmp:
            zip_file = os.path.join(dltmp, 'dataset.zip')
            download_file(zip_url, zip_file, desc=f'{repo}/{urlsplit(zip_url).filename}')
            with zipfile.ZipFile(zip_file, 'r') as zf:
                zf.extractall(td)

        logging.info(f'Online dataset {repo!r} loaded at {td!r}.')
        yield source, td
|