# Hosting-page residue (not part of the module): uploaded by LittleApple-fp16,
# commit message "Upload 88 files", revision 4f8ad24.
from functools import partial
from typing import Iterator, Union, List, Mapping, Literal
from PIL import Image
from imgutils.tagging import get_deepdanbooru_tags, get_wd14_tags, get_mldanbooru_tags
from .base import ProcessAction, BaseAction
from ..model import ImageItem
def _deepdanbooru_tagging(image: Image.Image, use_real_name: bool = False,
                          general_threshold: float = 0.5, character_threshold: float = 0.5, **kwargs):
    """Tag ``image`` via ``get_deepdanbooru_tags`` and merge the results.

    :param image: Image to tag.
    :param use_real_name: Forwarded unchanged to ``get_deepdanbooru_tags``.
    :param general_threshold: Forwarded unchanged (defaults to 0.5).
    :param character_threshold: Forwarded unchanged (defaults to 0.5).
    :param kwargs: Accepted so every tagging backend can be called with the
        same keyword set; ignored here.
    :return: One dict holding the feature entries plus the character entries;
        on a key clash the character entry wins.
    """
    del kwargs  # only present for signature compatibility between backends

    result = get_deepdanbooru_tags(image, use_real_name, general_threshold, character_threshold)
    # The tagger yields three values; the first one is not used here.
    _unused, general, character = result

    merged = dict(general)
    merged.update(character)
    return merged
def _wd14_tagging(image: Image.Image, model_name: str,
                  general_threshold: float = 0.35, character_threshold: float = 0.85, **kwargs):
    """Tag ``image`` via ``get_wd14_tags`` and merge the results.

    :param image: Image to tag.
    :param model_name: WD14 model variant, forwarded to ``get_wd14_tags``.
    :param general_threshold: Threshold forwarded for general tags
        (defaults to 0.35).
    :param character_threshold: Threshold forwarded for character tags
        (defaults to 0.85).
    :param kwargs: Accepted so every tagging backend can be called with the
        same keyword set; ignored here.
    :return: One dict holding the feature entries plus the character entries;
        on a key clash the character entry wins.
    """
    _ = kwargs
    # Thresholds are passed by keyword on purpose: binding them by position
    # would silently hit the wrong parameter if the tagger's signature gains
    # arguments between the two thresholds.
    _, features, characters = get_wd14_tags(
        image, model_name,
        general_threshold=general_threshold,
        character_threshold=character_threshold,
    )
    return {**features, **characters}
def _mldanbooru_tagging(image: Image.Image, use_real_name: bool = False, general_threshold: float = 0.7, **kwargs):
    """Tag ``image`` via ``get_mldanbooru_tags``.

    :param image: Image to tag.
    :param use_real_name: Forwarded unchanged to ``get_mldanbooru_tags``.
    :param general_threshold: Forwarded unchanged as the third positional
        argument (defaults to 0.7).
    :param kwargs: Accepted so every tagging backend can be called with the
        same keyword set; ignored here.
    :return: Whatever ``get_mldanbooru_tags`` returns, unmodified.
    """
    del kwargs  # only present for signature compatibility between backends
    return get_mldanbooru_tags(image, use_real_name, general_threshold)
# Public method key -> model name handed to ``_wd14_tagging``.
_WD14_MODEL_NAMES = {
    'wd14_vit': 'ViT',
    'wd14_convnext': 'ConvNext',
    'wd14_convnextv2': 'ConvNextV2',
    'wd14_swinv2': 'SwinV2',
}

# Registry of tagging backends, keyed by the names listed in
# ``TaggingMethodTyping``. Every value is callable as ``fn(image=..., **kwargs)``.
_TAGGING_METHODS = {
    'deepdanbooru': _deepdanbooru_tagging,
    **{
        method_key: partial(_wd14_tagging, model_name=wd14_model)
        for method_key, wd14_model in _WD14_MODEL_NAMES.items()
    },
    'mldanbooru': _mldanbooru_tagging,
}

# Keep this literal in sync with the keys of ``_TAGGING_METHODS``.
TaggingMethodTyping = Literal[
    'deepdanbooru',
    'wd14_vit',
    'wd14_convnext',
    'wd14_convnextv2',
    'wd14_swinv2',
    'mldanbooru',
]
class TaggingAction(ProcessAction):
    """Attach a ``'tags'`` entry to an item's metadata using a tagging backend."""

    def __init__(self, method: TaggingMethodTyping = 'wd14_convnextv2', force: bool = False, **kwargs):
        """Select the backend and remember the call options.

        :param method: Key into ``_TAGGING_METHODS``; an unknown key raises
            ``KeyError``.
        :param force: When true, re-tag items that already carry ``'tags'``.
        :param kwargs: Extra keyword arguments passed to the backend on every call.
        """
        self.kwargs = kwargs
        self.force = force
        self.method = _TAGGING_METHODS[method]

    def process(self, item: ImageItem) -> ImageItem:
        """Return ``item`` with tags in its metadata.

        Items that already have a ``'tags'`` entry are returned as-is unless
        ``force`` is set; otherwise a new ``ImageItem`` is built from the same
        image and a copy of the metadata with ``'tags'`` set.
        """
        if 'tags' not in item.meta or self.force:
            computed = self.method(image=item.image, **self.kwargs)
            new_meta = dict(item.meta)
            new_meta['tags'] = computed
            return ImageItem(item.image, new_meta)

        return item
class TagFilterAction(BaseAction):
    """Keep only items whose tag scores reach the required minimums."""

    def __init__(self, tags: Union[List[str], Mapping[str, float]],
                 method: TaggingMethodTyping = 'wd14_convnextv2', **kwargs):
        """Build the filter.

        :param tags: Either a list/tuple of required tag names (each gets a
            minimum score of ``1e-6``) or a mapping of tag name to minimum score.
        :param method: Tagging backend key, forwarded to ``TaggingAction``.
        :param kwargs: Extra keyword arguments forwarded to ``TaggingAction``.
        :raises TypeError: If ``tags`` is neither a list/tuple nor a mapping.
        """
        if isinstance(tags, (list, tuple)):
            self.tags = {tag: 1e-6 for tag in tags}
        elif isinstance(tags, Mapping):
            # Any mapping is accepted, matching the annotation (not just dict).
            self.tags = dict(tags)
        else:
            raise TypeError(f'Unknown type of tags - {tags!r}.')

        # force=False: items that already carry tags are not re-tagged.
        self.tagger = TaggingAction(method, force=False, **kwargs)

    def iter(self, item: ImageItem) -> Iterator[ImageItem]:
        """Yield the tagged ``item`` if it satisfies every tag requirement.

        A required tag that is missing from the item's tags is scored as
        ``0.0`` instead of raising ``KeyError``, so such an item is simply
        filtered out whenever the minimum score is positive.
        """
        item = self.tagger(item)
        tags = item.meta['tags']

        below_minimum = any(
            tags.get(tag, 0.0) < min_score
            for tag, min_score in self.tags.items()
        )
        if not below_minimum:
            yield item

    def reset(self):
        """Reset the wrapped tagger."""
        self.tagger.reset()