"""Character lookup against the ``deepghs/character_index`` dataset.

A query image is embedded with CCIP and searched in an index built from the
dataset's precomputed embeddings. Each candidate character returned by the
index is then re-scored by comparing the query against that character's
stored CCIP features.
"""
import json
from functools import lru_cache

import numpy as np
import pandas as pd
from PIL import Image
from autofaiss import build_index
from hfutils.operate import get_hf_fs
from huggingface_hub import hf_hub_download
from imgutils.data import load_image
from imgutils.metrics import ccip_batch_extract_features, ccip_batch_differences, ccip_default_threshold

SRC_REPO = 'deepghs/character_index'

hf_fs = get_hf_fs()

# Columns of the DataFrame returned by ``query_character``. Passing them
# explicitly keeps an empty result well-formed (sortable and filterable).
_RECORD_COLUMNS = ['id', 'tag', 'gender', 'copyright', 'index_score', 'mean_diff', 'same_ratio']

# Score columns for which a *smaller* value means a better match.
# ``mean_diff`` is a CCIP difference: values below ``ccip_default_threshold()``
# are counted as "same character" when computing ``same_ratio``.
_LOWER_IS_BETTER = {'mean_diff'}


@lru_cache()
def _make_index():
    """Build the search index and load the per-character metadata.

    The result is cached by ``lru_cache``, so the download and index build
    happen at most once per process.

    :return: ``((index, index_infos), tag_infos)``. ``index`` and
        ``index_infos`` are what ``autofaiss.build_index`` returns.
        ``tag_infos`` is an array of metadata dicts, looked up by the row
        ids that ``index.search`` returns.
    """
    tag_infos = np.array(json.loads(hf_fs.read_text(f'datasets/{SRC_REPO}/index/tag_infos.json')))
    embeddings = np.load(hf_hub_download(
        repo_id=SRC_REPO, repo_type='dataset',
        filename='index/embeddings.npy',
    ))
    index, index_infos = build_index(embeddings, save_on_disk=False)
    return (index, index_infos), tag_infos


def gender_predict(p):
    """Map a ``{'boy': ..., 'girl': ...}`` score dict to a gender label.

    :return: ``'male'`` or ``'female'`` when one score exceeds the other by
        at least 0.1, otherwise ``'not_sure'``.
    """
    if p['boy'] - p['girl'] >= 0.1:
        return 'male'
    elif p['girl'] - p['boy'] >= 0.1:
        return 'female'
    else:
        return 'not_sure'


def _download_character_file(info, filename):
    """Download ``filename`` from one character's dataset folder; return the local path."""
    return hf_hub_download(
        repo_id=SRC_REPO, repo_type='dataset',
        filename=f'{info["hprefix"]}/{info["short_tag"]}/{filename}',
    )


def _score_candidate(query_feat, info, dist):
    """Re-score a single index hit against the character's stored CCIP features.

    :param query_feat: The (normalised) CCIP feature of the query image.
    :param info: Metadata dict for the candidate, taken from ``tag_infos``.
    :param dist: The score reported by the index search for this hit.
    :return: ``(preview_image, record)``, where ``record`` has the keys
        listed in ``_RECORD_COLUMNS``.
    """
    current_image = load_image(_download_character_file(info, '1.webp'))
    feats = np.load(_download_character_file(info, 'feat.npy'))
    # Row 0 of the difference matrix compares the query against every feature.
    # Column 0 is the query against itself, so it is dropped.
    diffs = ccip_batch_differences([query_feat, *feats])[0, 1:]
    record = {
        'id': info['id'],
        'tag': info['tag'],
        'gender': gender_predict(info['gender']),
        'copyright': info['copyright'],
        'index_score': dist,
        'mean_diff': diffs.mean(),
        'same_ratio': (diffs < ccip_default_threshold()).mean(),
    }
    return current_image, record


def query_character(image: Image.Image, count: int = 5, order_by: str = 'same_ratio', threshold: float = 0.7):
    """Find the indexed characters that best match ``image``.

    :param image: Query image, embedded with CCIP.
    :param count: Number of nearest neighbours to fetch from the index.
        Non-positive values yield an empty result.
    :param order_by: Score column to rank and filter on: ``'same_ratio'``,
        ``'mean_diff'`` or ``'index_score'``.

        - ``mean_diff`` is a difference (smaller is better). It is sorted
          ascending and kept when ``<= threshold``.
        - The other columns are sorted descending and kept when
          ``>= threshold``.

        Ties are broken by ``index_score``, descending.
    :param threshold: Cut-off applied to the ``order_by`` column.
    :return: ``(gallery, df_records)``.

        - ``gallery`` is a list of ``(preview_image, caption)`` pairs in
          ranking order.
        - ``df_records`` is the ranked, filtered DataFrame with the columns
          in ``_RECORD_COLUMNS``.
    """
    if count <= 0:
        return [], pd.DataFrame(columns=_RECORD_COLUMNS)

    (index, index_infos), tag_infos = _make_index()
    query = ccip_batch_extract_features([image])
    assert query.shape == (1, 768)
    query = query / np.linalg.norm(query)

    all_dists, all_indices = index.search(query, k=count)
    images, records = {}, []
    for dist, idx in zip(all_dists[0], all_indices[0]):
        if idx < 0:
            # faiss pads with -1 when fewer than `count` results exist.
            # tag_infos[-1] would otherwise silently return the last character.
            continue
        info = tag_infos[idx]
        current_image, record = _score_candidate(query[0], info, dist)
        images[info['tag']] = current_image
        records.append(record)

    # Explicit columns so an empty result can still be sorted and filtered.
    df_records = pd.DataFrame(records, columns=_RECORD_COLUMNS)

    lower_is_better = order_by in _LOWER_IS_BETTER
    if order_by == 'index_score':
        # NOTE(review): assumes a higher index score is a better match
        # (inner-product similarity, autofaiss' default metric).
        sort_by, ascending = ['index_score'], [False]
    else:
        sort_by, ascending = [order_by, 'index_score'], [lower_is_better, False]
    df_records = df_records.sort_values(by=sort_by, ascending=ascending)

    if lower_is_better:
        df_records = df_records[df_records[order_by] <= threshold]
    else:
        df_records = df_records[df_records[order_by] >= threshold]

    ret_images = [
        (images[row_item['tag']], f'{row_item["tag"]} ({row_item[order_by]:.3f})')
        for row_item in df_records.to_dict('records')
    ]
    return ret_images, df_records