import logging
from enum import IntEnum
from typing import Iterator, Optional, List, Tuple

import numpy as np
from hbutils.string import plural_word
from hbutils.testing import disable_output
from imgutils.metrics import ccip_extract_feature, ccip_default_threshold, ccip_clustering, ccip_batch_differences

from .base import BaseAction
from ..model import ImageItem


class CCIPStatus(IntEnum):
    """States of the :class:`CCIPAction` state machine."""

    INIT = 0x1  # buffering items until ``min_val_count`` of them are collected
    APPROACH = 0x2  # first clustering attempt failed; retried every ``step`` items
    EVAL = 0x3  # a reference feature set exists; new items are compared against it
    INIT_WITH_SOURCE = 0x4  # reference features still have to be loaded from ``init_source``


class CCIPAction(BaseAction):
    """
    Streaming filter that keeps the items whose CCIP feature matches a reference set.

    The reference set is either loaded from ``init_source`` on the first call of
    :meth:`iter`, or bootstrapped by clustering the first buffered items and keeping
    the dominant cluster.

    :param init_source: Optional iterable of items used as the reference set. When it
        is not ``None`` the clustering bootstrap is skipped entirely.
    :param min_val_count: Number of buffered items required before the first
        clustering attempt.
    :param step: Interval (in items) between clustering retries while approaching,
        and between re-checks of held-back items while evaluating.
    :param ratio_threshold: Minimum share of the clustered (non ``-1``) items a
        cluster must hold to be chosen.
    :param min_clu_dump_ratio: Minimum fraction of the chosen cluster's members that
        must match the cluster itself for the clustering to be accepted.
    :param cmp_threshold: Minimum fraction of reference features an item must match
        to be accepted.
    :param eps: Forwarded to ``ccip_clustering``.
    :param min_samples: Forwarded to ``ccip_clustering``.
    :param model: CCIP model name forwarded to every ``imgutils.metrics`` call.
    :param threshold: Maximum difference for two features to count as a match. Any
        falsy value (``None``, ``0``) selects ``ccip_default_threshold(model)``.
    """

    def __init__(self, init_source=None, *, min_val_count: int = 15, step: int = 5,
                 ratio_threshold: float = 0.6, min_clu_dump_ratio: float = 0.3, cmp_threshold: float = 0.5,
                 eps: Optional[float] = None, min_samples: Optional[int] = None,
                 model='ccip-caformer-24-randaug-pruned', threshold: Optional[float] = None):
        self.init_source = init_source

        self.min_val_count = min_val_count
        self.step = step
        self.ratio_threshold = ratio_threshold
        self.min_clu_dump_ratio = min_clu_dump_ratio
        self.cmp_threshold = cmp_threshold
        self.eps, self.min_samples = eps, min_samples
        self.model = model
        # NOTE: ``or`` is kept on purpose, so a falsy threshold still means "use the default".
        self.threshold = threshold or ccip_default_threshold(self.model)

        # ``items[i]`` owns ``feats[i]`` and ``item_released[i]``; ``feats`` may hold
        # additional trailing/leading entries that have no buffered item.
        self.items: List[ImageItem] = []
        self.item_released: List[bool] = []
        self.feats: list = []
        self.status = self._initial_status()

    def _initial_status(self) -> CCIPStatus:
        """Return the state a fresh (or reset) action starts in."""
        if self.init_source is not None:
            return CCIPStatus.INIT_WITH_SOURCE
        return CCIPStatus.INIT

    def _extract_feature(self, item: ImageItem):
        """Return the CCIP feature of ``item``, preferring the one cached in ``item.meta``."""
        if 'ccip_feature' in item.meta:
            return item.meta['ccip_feature']
        return ccip_extract_feature(item.image, model=self.model)

    def _try_cluster(self) -> bool:
        """
        Cluster the buffered features and, on success, keep only the dominant cluster.

        On success ``items``, ``feats`` and ``item_released`` are replaced so that they
        describe the chosen cluster only (nothing released yet).

        :return: ``True`` when a cluster was chosen and accepted, ``False`` otherwise
            (in which case no state is modified).
        """
        with disable_output():
            clu_ids = ccip_clustering(self.feats, method='optics', model=self.model,
                                      eps=self.eps, min_samples=self.min_samples)

        # Count the members of every real cluster; id -1 is excluded from the counts.
        clu_counts = {}
        for id_ in clu_ids:
            if id_ != -1:
                clu_counts[id_] = clu_counts.get(id_, 0) + 1
        clu_total = sum(clu_counts.values())

        # First cluster (in order of appearance) that holds a large enough share wins.
        chosen_id = None
        for id_, count in clu_counts.items():
            if count >= clu_total * self.ratio_threshold:
                chosen_id = id_
                break
        if chosen_id is None:
            return False

        feats = [feat for i, feat in enumerate(self.feats) if clu_ids[i] == chosen_id]
        # Fraction of the cluster's members that match the cluster as a whole.
        clu_dump_ratio = np.array([
            self._compare_to_exists(feat, base_set=feats)
            for feat in feats
        ]).astype(float).mean()
        if clu_dump_ratio < self.min_clu_dump_ratio:
            return False

        self.items = [item for i, item in enumerate(self.items) if clu_ids[i] == chosen_id]
        self.item_released = [False] * len(self.items)
        self.feats = feats
        return True

    def _compare_to_exists(self, feat, base_set=None) -> bool:
        """
        Tell whether ``feat`` matches enough of the features in a reference set.

        :param feat: Feature to test.
        :param base_set: Features to compare against; ``self.feats`` when ``None``.
        :return: ``True`` when the fraction of reference features whose difference to
            ``feat`` is at most ``self.threshold`` reaches ``self.cmp_threshold``.
            ``False`` when there is nothing to compare against.
        """
        # Only ``None`` selects the default: an explicit (even empty or array-like)
        # base set is used as given.
        base = self.feats if base_set is None else base_set
        # Row 0 without its first column: presumably the differences between
        # ``feat`` and every entry of ``base`` - verify against imgutils docs.
        diffs = ccip_batch_differences([feat, *base], model=self.model)[0, 1:]
        matches = diffs <= self.threshold
        if matches.size == 0:
            # The mean of an empty array is NaN (plus a RuntimeWarning), which never
            # reaches the threshold; return that outcome explicitly instead.
            return False
        return bool(matches.astype(float).mean() >= self.cmp_threshold)

    def _dump_items(self) -> Iterator[ImageItem]:
        """Yield every buffered, not yet released item that matches the current reference set."""
        for i, item in enumerate(self.items):
            if self.item_released[i]:
                continue
            if self._compare_to_exists(self.feats[i]):
                self.item_released[i] = True
                yield item

    def _eval_iter(self, item: ImageItem) -> Iterator[ImageItem]:
        """
        Evaluate one incoming item against the reference set.

        An accepted item is yielded and its feature is added to the reference set.
        Whenever the number of features beyond the buffered items is a multiple of
        ``step``, the held-back buffered items are re-checked as well.
        """
        feat = self._extract_feature(item)
        if self._compare_to_exists(feat):
            self.feats.append(feat)
            yield item

            if (len(self.feats) - len(self.items)) % self.step == 0:
                yield from self._dump_items()

    def iter(self, item: ImageItem) -> Iterator[ImageItem]:
        """
        Feed one item into the state machine and yield whatever becomes releasable.

        :param item: Incoming item.
        :return: Iterator of accepted items (possibly empty, possibly several).
        :raises ValueError: If ``self.status`` is not a known :class:`CCIPStatus`.
        """
        if self.status == CCIPStatus.INIT_WITH_SOURCE:
            # Load the reference features once; the source items are passed through.
            cnt = 0
            logging.info('Existing anchor detected.')
            for item_ in self.init_source:
                self.feats.append(self._extract_feature(item_))
                yield item_
                cnt += 1
            logging.info(f'{plural_word(cnt, "items")} loaded from anchor.')

            self.status = CCIPStatus.EVAL
            yield from self._eval_iter(item)

        elif self.status == CCIPStatus.INIT:
            self.items.append(item)
            self.feats.append(self._extract_feature(item))

            if len(self.items) >= self.min_val_count:
                if self._try_cluster():
                    self.status = CCIPStatus.EVAL
                    yield from self._dump_items()
                else:
                    self.status = CCIPStatus.APPROACH

        elif self.status == CCIPStatus.APPROACH:
            self.items.append(item)
            self.feats.append(self._extract_feature(item))

            # Retry the clustering only every ``step`` items.
            if (len(self.items) - self.min_val_count) % self.step == 0:
                if self._try_cluster():
                    self.status = CCIPStatus.EVAL
                    yield from self._dump_items()

        elif self.status == CCIPStatus.EVAL:
            yield from self._eval_iter(item)

        else:
            raise ValueError(f'Unknown status for {self.__class__.__name__} - {self.status!r}.')

    def reset(self):
        """Drop all buffered items and features and return to the initial state."""
        self.items.clear()
        self.item_released.clear()
        self.feats.clear()
        # Same rule as in ``__init__``: any non-None source counts, even a falsy one.
        self.status = self._initial_status()