"""Label loading and image preprocessing helpers."""

import json
from dataclasses import asdict, dataclass
from functools import lru_cache
from os import PathLike
from pathlib import Path
from typing import Any, Optional

import numpy as np
import pandas as pd
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import HfHubHTTPError
from PIL import Image


def _json_default(value: Any) -> Any:
    """Convert NumPy scalars/arrays to native Python values for ``json.dumps``."""
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


class DictJsonMixin:
    """Mixin adding dict / JSON export helpers to dataclasses."""

    def asdict(self, *args, **kwargs) -> dict[str, Any]:
        """Return the dataclass as a dict; arguments are forwarded to ``dataclasses.asdict``."""
        return asdict(self, *args, **kwargs)

    def asjson(self, *args, **kwargs) -> str:
        """Return the dataclass serialized as a JSON string.

        Arguments are forwarded to ``dataclasses.asdict``. NumPy scalars (such
        as the ``np.int64`` indices held by ``LabelData``) are converted to
        native Python numbers, because ``json.dumps`` cannot serialize them.
        """
        return json.dumps(asdict(self, *args, **kwargs), default=_json_default)


@dataclass
class LabelData(DictJsonMixin):
    """Tag names plus the row indices of each tag category."""

    names: list[str]  # tag names, in CSV row order
    rating: list[np.int64]  # row indices whose category is 9
    general: list[np.int64]  # row indices whose category is 0
    character: list[np.int64]  # row indices whose category is 4


@dataclass
class ImageLabels(DictJsonMixin):
    """Labels for a single image, grouped by category."""

    caption: str
    booru: str
    rating: dict[str, float]
    general: dict[str, float]
    character: dict[str, float]


def _labels_from_csv(csv_path: Path) -> LabelData:
    """Read a tag CSV (``name`` and ``category`` columns) into a ``LabelData``."""
    df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
    category = df["category"]
    return LabelData(
        names=df["name"].tolist(),
        rating=list(np.where(category == 9)[0]),
        general=list(np.where(category == 0)[0]),
        character=list(np.where(category == 4)[0]),
    )


@lru_cache(maxsize=5)
def load_labels(version: str = "v3", data_dir: PathLike = "./data") -> LabelData:
    """Load tag labels from ``selected_tags_{version}.csv`` inside ``data_dir``.

    Results are cached per ``(version, data_dir)`` argument pair, so repeated
    calls return the same ``LabelData`` object; treat it as read-only.

    Raises:
        FileNotFoundError: if the CSV file does not exist in ``data_dir``.
    """
    data_dir = Path(data_dir).resolve()
    csv_path = data_dir.joinpath(f"selected_tags_{version}.csv")
    if not csv_path.is_file():
        raise FileNotFoundError(f"{csv_path.name} not found in {data_dir}")

    return _labels_from_csv(csv_path)


@lru_cache(maxsize=5)
def load_labels_hf(
    repo_id: str,
    revision: Optional[str] = None,
    token: Optional[str] = None,
) -> LabelData:
    """Download ``selected_tags.csv`` from a Hugging Face repo and load its labels.

    Results are cached per argument combination, so repeated calls return the
    same ``LabelData`` object; treat it as read-only.

    Raises:
        FileNotFoundError: if the download fails with an ``HfHubHTTPError``.
    """
    try:
        downloaded = hf_hub_download(
            repo_id=repo_id, filename="selected_tags.csv", revision=revision, token=token
        )
    except HfHubHTTPError as e:
        raise FileNotFoundError(f"selected_tags.csv failed to download from {repo_id}") from e

    return _labels_from_csv(Path(downloaded).resolve())


def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    """Return ``image`` as RGB, compositing any transparency onto white."""
    # convert to RGB/RGBA if not already (deals with palette images etc.)
    if image.mode not in ["RGB", "RGBA"]:
        image = image.convert("RGBA") if "transparency" in image.info else image.convert("RGB")
    # convert RGBA to RGB with white background
    if image.mode == "RGBA":
        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")
    return image


def pil_pad_square(
    image: Image.Image,
    fill: tuple[int, int, int] = (255, 255, 255),
) -> Image.Image:
    """Center ``image`` on a square RGB canvas filled with ``fill`` (white by default)."""
    w, h = image.size
    # the largest dimension becomes the side length of the square
    px = max(image.size)
    # paste the image centered on the filled canvas
    canvas = Image.new("RGB", (px, px), fill)
    canvas.paste(image, ((px - w) // 2, (px - h) // 2))
    return canvas


def preprocess_image(
    image: Image.Image,
    size_px: int | tuple[int, int],
    upscale: bool = True,
) -> Image.Image:
    """
    Preprocess an image to be square and centered on a white background.

    The image is converted to RGB, padded to a square, then scaled towards
    ``size_px`` (an int means a square target). Images smaller than the target
    in either dimension are resized to exactly ``size_px`` with LANCZOS; larger
    images are shrunk in place with ``thumbnail`` (BICUBIC), which preserves
    aspect ratio, so with a non-square target the result may not equal
    ``size_px``.

    Raises:
        ValueError: if the padded image is smaller than the target and
            ``upscale`` is ``False``.
    """
    if isinstance(size_px, int):
        size_px = (size_px, size_px)

    # ensure RGB and pad to square
    image = pil_ensure_rgb(image)
    image = pil_pad_square(image)

    # resize to target size
    if image.size[0] < size_px[0] or image.size[1] < size_px[1]:
        # explicit ``is False`` check kept so other falsy values behave as before
        if upscale is False:
            raise ValueError("Image is smaller than target size, and upscaling is disabled")
        image = image.resize(size_px, Image.LANCZOS)
    elif image.size[0] > size_px[0] or image.size[1] > size_px[1]:
        image.thumbnail(size_px, Image.BICUBIC)

    return image


# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
# NOTE(review): the bare "_" entry looks out of place among the kaomoji; check it
# against the upstream list linked above (possibly "<o>_<o>" with markup stripped).
kaomojis = [
    "0_0",
    "(o)_(o)",
    "+_+",
    "+_-",
    "._.",
    "_",
    "<|>_<|>",
    "=_=",
    ">_<",
    "3_3",
    "6_9",
    ">_o",
    "@_@",
    "^_^",
    "o_o",
    "u_u",
    "x_x",
    "|_|",
    "||_||",
]