Spaces:
Runtime error
Runtime error
from functools import partial | |
from typing import Iterator, Union, List, Mapping, Literal | |
from PIL import Image | |
from imgutils.tagging import get_deepdanbooru_tags, get_wd14_tags, get_mldanbooru_tags | |
from .base import ProcessAction, BaseAction | |
from ..model import ImageItem | |
def _deepdanbooru_tagging(image: Image.Image, use_real_name: bool = False, | |
general_threshold: float = 0.5, character_threshold: float = 0.5, **kwargs): | |
_ = kwargs | |
_, features, characters = get_deepdanbooru_tags(image, use_real_name, general_threshold, character_threshold) | |
return {**features, **characters} | |
def _wd14_tagging(image: Image.Image, model_name: str, | |
general_threshold: float = 0.35, character_threshold: float = 0.85, **kwargs): | |
_ = kwargs | |
_, features, characters = get_wd14_tags(image, model_name, general_threshold, character_threshold) | |
return {**features, **characters} | |
def _mldanbooru_tagging(image: Image.Image, use_real_name: bool = False, general_threshold: float = 0.7, **kwargs): | |
_ = kwargs | |
features = get_mldanbooru_tags(image, use_real_name, general_threshold) | |
return features | |
_TAGGING_METHODS = { | |
'deepdanbooru': _deepdanbooru_tagging, | |
'wd14_vit': partial(_wd14_tagging, model_name='ViT'), | |
'wd14_convnext': partial(_wd14_tagging, model_name='ConvNext'), | |
'wd14_convnextv2': partial(_wd14_tagging, model_name='ConvNextV2'), | |
'wd14_swinv2': partial(_wd14_tagging, model_name='SwinV2'), | |
'mldanbooru': _mldanbooru_tagging, | |
} | |
TaggingMethodTyping = Literal[ | |
'deepdanbooru', 'wd14_vit', 'wd14_convnext', 'wd14_convnextv2', 'wd14_swinv2', 'mldanbooru'] | |
class TaggingAction(ProcessAction): | |
def __init__(self, method: TaggingMethodTyping = 'wd14_convnextv2', force: bool = False, **kwargs): | |
self.method = _TAGGING_METHODS[method] | |
self.force = force | |
self.kwargs = kwargs | |
def process(self, item: ImageItem) -> ImageItem: | |
if 'tags' in item.meta and not self.force: | |
return item | |
else: | |
tags = self.method(image=item.image, **self.kwargs) | |
return ImageItem(item.image, {**item.meta, 'tags': tags}) | |
class TagFilterAction(BaseAction): | |
def __init__(self, tags: Union[List[str], Mapping[str, float]], | |
method: TaggingMethodTyping = 'wd14_convnextv2', **kwargs): | |
if isinstance(tags, (list, tuple)): | |
self.tags = {tag: 1e-6 for tag in tags} | |
elif isinstance(tags, dict): | |
self.tags = dict(tags) | |
else: | |
raise TypeError(f'Unknown type of tags - {tags!r}.') | |
self.tagger = TaggingAction(method, force=False, **kwargs) | |
def iter(self, item: ImageItem) -> Iterator[ImageItem]: | |
item = self.tagger(item) | |
tags = item.meta['tags'] | |
valid = True | |
for tag, min_score in self.tags.items(): | |
if tags[tag] < min_score: | |
valid = False | |
break | |
if valid: | |
yield item | |
def reset(self): | |
self.tagger.reset() | |