#!/usr/bin/env python
"""Tag images with a WD v3 tagger model from the Hugging Face Hub.

For a single image, or for every image in a directory, the script runs the
selected model and writes the resulting comma-separated tag caption to a
``.wd`` file next to the image.
"""
import functools
import gettext
import locale
import multiprocessing
from dataclasses import dataclass
from itertools import islice
from multiprocessing import Pool, cpu_count
from pathlib import Path
from types import SimpleNamespace
from typing import Optional

import numpy as np
import pandas as pd
import timm
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import HfHubHTTPError
from PIL import Image
# Imported only for its import-time side effect; presumably registers JPEG XL
# support with Pillow -- TODO confirm against the pillow_jxl documentation.
import pillow_jxl  # type: ignore
from simple_parsing import field, parse_known_args
from timm.data import create_transform, resolve_data_config
from torch import Tensor, nn
from torch.nn import functional as F

# Set start method to spawn for CUDA compatibility
multiprocessing.set_start_method('spawn', force=True)

torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

MODEL_REPO_MAP = {
    "vit": "SmilingWolf/wd-vit-tagger-v3",
    "swinv2": "SmilingWolf/wd-swinv2-tagger-v3",
    "convnext": "SmilingWolf/wd-convnext-tagger-v3",
}


def setup_i18n():
    """Set up internationalization and return the translation function.

    Looks for a ``wdv3`` catalog for the current locale in the ``locales``
    directory next to this file. This is deliberately best-effort: if the
    catalog cannot be loaded for any reason, the untranslated
    ``gettext.gettext`` is returned instead.
    """
    try:
        current_locale = locale.getlocale()[0]
        if current_locale is None:
            # Fallback if no locale is set
            current_locale = 'en_US'

        locale_path = Path(__file__).parent / 'locales'
        trans = gettext.translation('wdv3', locale_path, languages=[current_locale])
        trans.install()
        return trans.gettext
    except Exception:
        # Not a bare ``except:`` so KeyboardInterrupt/SystemExit still propagate.
        return gettext.gettext


# Initialize translation
_ = setup_i18n()


def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    """Return ``image`` as an RGB image, flattening alpha onto white."""
    # convert to RGB/RGBA if not already (deals with palette images etc.)
    if image.mode not in ["RGB", "RGBA"]:
        image = image.convert("RGBA") if "transparency" in image.info else image.convert("RGB")
    # convert RGBA to RGB with white background
    if image.mode == "RGBA":
        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")
    return image


def pil_pad_square(image: Image.Image) -> Image.Image:
    """Return a new square RGB image with ``image`` centered on white."""
    w, h = image.size
    # get the largest dimension so we can pad to a square
    px = max(image.size)
    # pad to square with white background
    canvas = Image.new("RGB", (px, px), (255, 255, 255))
    canvas.paste(image, ((px - w) // 2, (px - h) // 2))
    return canvas


@dataclass
class LabelData:
    """Tag names plus the row indices belonging to each tag category."""

    # Tag names, in the row order of selected_tags.csv.
    names: list[str]
    # Row indices whose CSV category is 9.
    rating: list[np.int64]
    # Row indices whose CSV category is 0.
    general: list[np.int64]
    # Row indices whose CSV category is 4.
    character: list[np.int64]


def load_labels_hf(
    repo_id: str,
    revision: Optional[str] = None,
    token: Optional[str] = None,
) -> LabelData:
    """Download ``selected_tags.csv`` from ``repo_id`` and parse it.

    Args:
        repo_id: Hugging Face Hub repository to download from.
        revision: Optional revision passed through to ``hf_hub_download``.
        token: Optional access token passed through to ``hf_hub_download``.

    Returns:
        The tag names and per-category row indices.

    Raises:
        FileNotFoundError: If the download fails with ``HfHubHTTPError``.
    """
    try:
        csv_path = hf_hub_download(
            repo_id=repo_id, filename="selected_tags.csv", revision=revision, token=token
        )
        csv_path = Path(csv_path).resolve()
    except HfHubHTTPError as e:
        raise FileNotFoundError(f"selected_tags.csv failed to download from {repo_id}") from e

    df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
    tag_data = LabelData(
        names=df["name"].tolist(),
        rating=list(np.where(df["category"] == 9)[0]),
        general=list(np.where(df["category"] == 0)[0]),
        character=list(np.where(df["category"] == 4)[0]),
    )

    return tag_data


def _select_tags(scored, indices, threshold):
    """Return ``{name: prob}`` for ``indices`` above ``threshold``, best first."""
    above = {name: prob for name, prob in (scored[i] for i in indices) if prob > threshold}
    return dict(sorted(above.items(), key=lambda item: item[1], reverse=True))


def get_tags(
    probs: Tensor,
    labels: LabelData,
    gen_threshold: float,
    char_threshold: float,
):
    """Turn per-label probabilities into captions and per-category mappings.

    Args:
        probs: One probability per entry of ``labels.names``. It is converted
            with ``Tensor.numpy()``, so it must already live on the CPU.
        labels: Tag names and category indices.
        gen_threshold: General tags are kept when their probability exceeds this.
        char_threshold: Character tags are kept when their probability exceeds this.

    Returns:
        A tuple ``(caption, taglist, rating_labels, char_labels, gen_labels)``.
        ``caption`` joins the kept general tags followed by the kept character
        tags with ``", "``; ``taglist`` is ``caption`` with underscores turned
        into spaces and parentheses backslash-escaped; the three dicts map tag
        name to probability (character and general sorted by descending
        probability, ratings unfiltered).
    """
    # Pair every label name with its probability
    scored = list(zip(labels.names, probs.numpy()))

    # Rating labels are returned as-is, without thresholding
    rating_labels = dict(scored[i] for i in labels.rating)

    # General and character labels, pick any where prediction confidence > threshold
    gen_labels = _select_tags(scored, labels.general, gen_threshold)
    char_labels = _select_tags(scored, labels.character, char_threshold)

    # General tags first, then character tags; each group is already sorted
    combined_names = [*gen_labels, *char_labels]

    # Convert to a string suitable for use as a training caption
    caption = ", ".join(combined_names)
    taglist = caption.replace("_", " ").replace("(", r"\(").replace(")", r"\)")

    return caption, taglist, rating_labels, char_labels, gen_labels


@dataclass
class ScriptOptions:
    """Command-line options for the tagger script."""

    image_file: Path = field(positional=True, help=_("Image file or directory to process"))
    model: str = field(default="vit", help=_("Model architecture to use (vit, swinv2, convnext)"))
    gen_threshold: float = field(default=0.35, help=_("General threshold for tagging"))
    char_threshold: float = field(default=0.75, help=_("Character threshold for tagging"))
    recursive: bool = field(default=False, help=_("Process subdirectories recursively"))


def process_image(image_path: Path, model: nn.Module, labels: LabelData, transform, opts: ScriptOptions) -> None:
    """Tag one image and write its caption to a sibling ``.wd`` file.

    Images whose ``.wd`` file already exists are skipped.

    Args:
        image_path: Image to tag.
        model: Tagger model to run.
        labels: Label metadata matching the model's outputs.
        transform: Callable turning a PIL image into the model's input tensor.
        opts: Object exposing ``gen_threshold`` and ``char_threshold``.
    """
    # Skip if text file already exists
    output_path = image_path.with_suffix('.wd')
    if output_path.exists():
        print(f"Skipping {image_path.name} - caption file already exists")
        return

    print(f"Processing {image_path.name}...")

    # Open inside a context manager so the file handle is released promptly;
    # pil_pad_square pastes into a fresh canvas, so the result stays usable
    # after the source image is closed.
    with Image.open(image_path) as img_file:
        # ensure image is RGB, then pad to square with white background
        img_input: Image.Image = pil_pad_square(pil_ensure_rgb(img_file))

    # run the model's input transform to convert to tensor and rescale
    inputs: Tensor = transform(img_input).unsqueeze(0)
    # NCHW image RGB to BGR
    inputs = inputs[:, [2, 1, 0]]

    with torch.inference_mode():
        # move model to GPU, if available
        if torch_device.type != "cpu":
            model = model.to(torch_device)
            inputs = inputs.to(torch_device)
        # run the model
        outputs = model(inputs)
        # apply the final activation function (timm doesn't support doing this internally)
        outputs = F.sigmoid(outputs)
        # move inputs, outputs, and model back to cpu if we were on GPU
        # NOTE(review): this per-image round trip is slow, but keeping the model
        # resident would hold GPU memory in every worker process at once.
        if torch_device.type != "cpu":
            inputs = inputs.to("cpu")
            outputs = outputs.to("cpu")
            model = model.to("cpu")

    print("Processing results...")
    # Only the caption (first element) is written out.
    caption = get_tags(
        probs=outputs.squeeze(0),
        labels=labels,
        gen_threshold=opts.gen_threshold,
        char_threshold=opts.char_threshold,
    )[0]

    # Save tags to a text file with the same base name as the input image
    output_path.write_text(caption.replace('_', ' '), encoding='utf-8')


@functools.cache
def create_model_resources(model_name):
    """Load the model, labels and input transform for ``model_name``.

    Results are cached per process, so a worker that handles several batches
    loads the model only once.

    Args:
        model_name: A key of ``MODEL_REPO_MAP``.

    Returns:
        A tuple ``(model, labels, transform)``.

    Raises:
        ValueError: If ``model_name`` is not a key of ``MODEL_REPO_MAP``.
    """
    try:
        repo_id = MODEL_REPO_MAP[model_name]
    except KeyError:
        raise ValueError(f"Unknown model name '{model_name}'") from None

    print(f"Loading model '{model_name}' from '{repo_id}'...")
    model: nn.Module = timm.create_model("hf-hub:" + repo_id).eval()
    state_dict = timm.models.load_state_dict_from_hf(repo_id)
    model.load_state_dict(state_dict)

    print("Loading tag list...")
    labels: LabelData = load_labels_hf(repo_id=repo_id)

    print("Creating data transform...")
    transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))

    return model, labels, transform


def process_batch(args):
    """Process one batch of images inside a worker process.

    Args:
        args: Tuple ``(batch_paths, model_name, gen_threshold, char_threshold)``.

    A failure on one image is printed and does not stop the rest of the batch.
    """
    batch_paths, model_name, gen_threshold, char_threshold = args

    # Create model resources within the process
    model, labels, transform = create_model_resources(model_name)

    # process_image only reads the two thresholds from its options object
    opts = SimpleNamespace(gen_threshold=gen_threshold, char_threshold=char_threshold)

    for image_path in batch_paths:
        try:
            process_image(image_path, model, labels, transform, opts)
        except Exception as e:
            print(f"Error processing {image_path.name}: {e}")


def batch_iterator(iterable, batch_size):
    """Yield successive lists of up to ``batch_size`` items from ``iterable``."""
    iterator = iter(iterable)
    while batch := list(islice(iterator, batch_size)):
        yield batch


def main(opts: ScriptOptions):
    """Tag a single image, or every image in a directory.

    Directories are split into batches of 16 images that are processed by a
    pool of worker processes.

    Raises:
        FileNotFoundError: If the target is neither a directory nor a file.
    """
    target_path = Path(opts.image_file).resolve()

    # Handle directory processing
    if target_path.is_dir():
        image_extensions = {'.png', '.jpg', '.jpeg', '.webp', '.jxl'}
        pattern = '**/*' if opts.recursive else '*'

        # Collect all valid image paths
        image_paths = [
            path for path in target_path.glob(pattern)
            if path.is_file() and path.suffix.lower() in image_extensions
        ]

        # Pool(processes=0) raises ValueError, so stop early on an empty directory
        if not image_paths:
            print(f"No images found in {target_path}")
            return

        # Create batches of 16 images
        batches = list(batch_iterator(image_paths, 16))

        # Prepare arguments for multiprocessing
        num_processes = min(cpu_count(), len(batches))
        process_args = [
            (batch, opts.model, opts.gen_threshold, opts.char_threshold)
            for batch in batches
        ]

        print(f"Processing {len(image_paths)} images using {num_processes} processes...")

        # Process batches in parallel
        with Pool(processes=num_processes) as pool:
            pool.map(process_batch, process_args)
    else:
        # Process single image file
        if not target_path.is_file():
            raise FileNotFoundError(f"Image file not found: {target_path}")

        # Create model resources for single file processing
        model, labels, transform = create_model_resources(opts.model)
        process_image(target_path, model, labels, transform, opts)


if __name__ == "__main__":
    # Do not bind the unknown args to ``_``: that name is the gettext alias.
    opts, _unknown_args = parse_known_args(ScriptOptions)
    if opts.model not in MODEL_REPO_MAP:
        print(f"Available models: {list(MODEL_REPO_MAP.keys())}")
        raise ValueError(f"Unknown model name '{opts.model}'")
    main(opts)