LittleApple-fp16's picture
Upload 88 files
4f8ad24
import logging
from enum import IntEnum
from typing import Iterator, Optional, List, Tuple
import numpy as np
from hbutils.string import plural_word
from hbutils.testing import disable_output
from imgutils.metrics import ccip_extract_feature, ccip_default_threshold, ccip_clustering, ccip_batch_differences
from .base import BaseAction
from ..model import ImageItem
class CCIPStatus(IntEnum):
    """
    States of the ``CCIPAction`` state machine.

    ``INIT`` buffers incoming items until the first clustering attempt, ``APPROACH``
    retries clustering periodically after a failed attempt, ``EVAL`` filters items
    against the accepted reference features, and ``INIT_WITH_SOURCE`` marks that an
    anchor source still has to be loaded before evaluation starts.
    """
    INIT = 1
    APPROACH = 2
    EVAL = 3
    INIT_WITH_SOURCE = 4
class CCIPAction(BaseAction):
    """
    Filter a stream of images down to one consistent character using CCIP features.

    Without an anchor (``init_source is None``), the first ``min_val_count`` items are
    buffered and clustered (OPTICS via ``ccip_clustering``). A cluster is accepted when
    it holds at least ``ratio_threshold`` of the labelled items (label ``-1`` is
    ignored) and at least ``min_clu_dump_ratio`` of its members match the rest of the
    cluster. If no cluster is accepted, clustering is retried every ``step`` additional
    items. With an anchor, every item of ``init_source`` is yielded and its features
    become the reference set directly.

    Once a reference set exists (``EVAL``), each new item is yielded only if it matches
    at least ``cmp_threshold`` of the reference features, and accepted features are
    appended to the reference set. Buffered items that were not released when
    clustering succeeded are re-checked every ``step`` accepted features.

    :param init_source: Optional iterable of ``ImageItem`` used as the anchor set.
    :param min_val_count: Number of buffered items before the first clustering attempt.
    :param step: Clustering retry interval (in buffered items) and buffer re-check
        interval (in accepted features).
    :param ratio_threshold: Minimum share of labelled items a cluster must hold.
    :param min_clu_dump_ratio: Minimum share of cluster members matching their own cluster.
    :param cmp_threshold: Minimum share of reference features an item must match.
    :param eps: Passed through to ``ccip_clustering``.
    :param min_samples: Passed through to ``ccip_clustering``.
    :param model: CCIP model name used for extraction, clustering and differences.
    :param threshold: Maximum CCIP difference counted as a match; ``None`` means
        ``ccip_default_threshold(model)``.
    """

    def __init__(self, init_source=None, *, min_val_count: int = 15, step: int = 5,
                 ratio_threshold: float = 0.6, min_clu_dump_ratio: float = 0.3, cmp_threshold: float = 0.5,
                 eps: Optional[float] = None, min_samples: Optional[int] = None,
                 model='ccip-caformer-24-randaug-pruned', threshold: Optional[float] = None):
        self.init_source = init_source

        self.min_val_count = min_val_count
        self.step = step
        self.ratio_threshold = ratio_threshold
        self.min_clu_dump_ratio = min_clu_dump_ratio
        self.cmp_threshold = cmp_threshold
        self.eps, self.min_samples = eps, min_samples
        self.model = model
        # ``is not None`` so an explicit ``threshold=0.0`` is honoured instead of being
        # silently replaced by the model default (as ``threshold or ...`` did).
        self.threshold = threshold if threshold is not None else ccip_default_threshold(self.model)

        # Buffered candidate items, whether each one has been yielded yet, and the
        # reference features. For i < len(items), feats[i] is the feature of items[i];
        # in EVAL, accepted features are appended after that prefix.
        self.items: List[ImageItem] = []
        self.item_released: List[bool] = []
        self.feats = []
        self.status = self._initial_status()

    def _initial_status(self) -> CCIPStatus:
        """Return the starting status; shared by ``__init__`` and ``reset`` so both agree."""
        if self.init_source is not None:
            return CCIPStatus.INIT_WITH_SOURCE
        return CCIPStatus.INIT

    def _extract_feature(self, item: ImageItem):
        """Return the item's CCIP feature, reusing ``item.meta['ccip_feature']`` when present."""
        if 'ccip_feature' in item.meta:
            return item.meta['ccip_feature']
        return ccip_extract_feature(item.image, model=self.model)

    def _try_cluster(self) -> bool:
        """
        Cluster the buffered features and, on success, keep only the chosen cluster.

        :return: ``True`` if a cluster was accepted, in which case ``self.items``,
            ``self.feats`` and ``self.item_released`` hold only its members. ``False``
            leaves all buffers untouched.
        """
        with disable_output():
            clu_ids = ccip_clustering(self.feats, method='optics', model=self.model,
                                      eps=self.eps, min_samples=self.min_samples)

        # Count members per cluster label; label -1 is excluded from counts and total.
        clu_counts = {}
        for id_ in clu_ids:
            if id_ != -1:
                clu_counts[id_] = clu_counts.get(id_, 0) + 1
        clu_total = sum(clu_counts.values())

        # First label (in first-seen order) holding at least ``ratio_threshold`` of the total.
        chosen_id = next(
            (id_ for id_, count in clu_counts.items() if count >= clu_total * self.ratio_threshold),
            None,
        )
        if chosen_id is None:
            return False

        keep = [i for i, id_ in enumerate(clu_ids) if id_ == chosen_id]
        feats = [self.feats[i] for i in keep]

        # Share of members that match enough of their own cluster; rejects loose clusters.
        clu_dump_ratio = np.mean([float(self._compare_to_exists(feat, base_set=feats)) for feat in feats])
        if clu_dump_ratio < self.min_clu_dump_ratio:
            return False

        self.items = [self.items[i] for i in keep]
        self.item_released = [False] * len(self.items)
        self.feats = feats
        return True

    def _compare_to_exists(self, feat, base_set=None) -> bool:
        """
        Decide whether ``feat`` matches a set of reference features.

        :param feat: Feature to test.
        :param base_set: Features to compare against; ``None`` means ``self.feats``.
        :return: ``True`` if the share of base features whose CCIP difference to
            ``feat`` is ``<= self.threshold`` is at least ``self.cmp_threshold``.
            An empty base set never matches.
        """
        base = self.feats if base_set is None else base_set
        if len(base) == 0:
            # Nothing to compare against. Averaging an empty array would give NaN (and a
            # RuntimeWarning), which also compares False, so return that directly.
            return False
        diffs = ccip_batch_differences([feat, *base], model=self.model)[0, 1:]
        matches = diffs <= self.threshold
        return bool(matches.astype(float).mean() >= self.cmp_threshold)

    def _dump_items(self) -> Iterator[ImageItem]:
        """Yield buffered items not yet released that now match the reference set."""
        # zip stops at len(self.items); self.feats only extends past that prefix.
        for i, (item, feat) in enumerate(zip(self.items, self.feats)):
            if not self.item_released[i] and self._compare_to_exists(feat):
                self.item_released[i] = True
                yield item

    def _eval_iter(self, item: ImageItem) -> Iterator[ImageItem]:
        """Yield ``item`` if it matches the reference set, then periodically re-check the buffer."""
        feat = self._extract_feature(item)
        if self._compare_to_exists(feat):
            self.feats.append(feat)
            yield item
            # Re-check buffered items every ``step`` accepted features. This is only done
            # after an acceptance: a rejected item leaves ``self.feats`` unchanged, so a
            # dump at that point could not release anything the previous dump did not.
            if (len(self.feats) - len(self.items)) % self.step == 0:
                yield from self._dump_items()

    def iter(self, item: ImageItem) -> Iterator[ImageItem]:
        """
        Process one incoming item and yield every item that becomes releasable.

        Depending on ``self.status`` this loads the anchor source first, buffers the
        item for clustering, or evaluates it against the reference set.

        :raises ValueError: If ``self.status`` is not a known ``CCIPStatus``.
        """
        if self.status == CCIPStatus.INIT_WITH_SOURCE:
            cnt = 0
            logging.info('Existing anchor detected.')
            for item_ in self.init_source:
                self.feats.append(self._extract_feature(item_))
                yield item_
                cnt += 1
            logging.info(f'{plural_word(cnt, "items")} loaded from anchor.')

            self.status = CCIPStatus.EVAL
            yield from self._eval_iter(item)

        elif self.status == CCIPStatus.INIT:
            self.items.append(item)
            self.feats.append(self._extract_feature(item))

            if len(self.items) >= self.min_val_count:
                if self._try_cluster():
                    self.status = CCIPStatus.EVAL
                    yield from self._dump_items()
                else:
                    self.status = CCIPStatus.APPROACH

        elif self.status == CCIPStatus.APPROACH:
            self.items.append(item)
            self.feats.append(self._extract_feature(item))

            # Retry clustering once every ``step`` additional buffered items.
            if (len(self.items) - self.min_val_count) % self.step == 0:
                if self._try_cluster():
                    self.status = CCIPStatus.EVAL
                    yield from self._dump_items()

        elif self.status == CCIPStatus.EVAL:
            yield from self._eval_iter(item)

        else:
            raise ValueError(f'Unknown status for {self.__class__.__name__} - {self.status!r}.')

    def reset(self):
        """Clear all buffered state and return to the initial status."""
        self.items.clear()
        self.item_released.clear()
        self.feats.clear()
        # Same ``is not None`` rule as __init__; the old truthiness test sent an empty but
        # non-None anchor to INIT after a reset, unlike a freshly constructed action.
        # NOTE(review): if init_source is a one-shot iterator, it yields nothing on the
        # second pass after a reset — confirm callers pass a re-iterable source.
        self.status = self._initial_status()