|
"Python only utils (no dependencies)" |
|
import gzip |
|
import json |
|
import logging |
|
import math |
|
import warnings |
|
from pathlib import Path |
|
from typing import Callable, Iterable |
|
|
|
logger = logging.getLogger(__name__)

# One row per category id, in id order: (name, primary colour, alternate
# colour). Index 2 has no category and is kept as a None placeholder so that
# list positions keep matching the numerical category ids.
_CATEGORY_TABLE = [
    ("general", "#b4c7d9", "#2e76b4"),
    ("artist", "#f2ac08", "#fbd67f"),
    (None, None, None),
    ("copyright", "#d0d", "#ff5eff"),
    ("character", "#0a0", "#2bff2b"),
    ("species", "#ed5d1f", "#f6b295"),
    ("invalid", "#ff3d3d", "#ffbdbd"),
    ("meta", "#fff", "#666"),
    ("lore", "#282", "#5fdb5f"),
    ("pool", "wheat", "#d0b27a"),
]

# Category id -> category name (None for the unused slot).
tag_categories = [row[0] for row in _CATEGORY_TABLE]
# Category name -> category id; the None placeholder is skipped.
tag_category2id = {
    name: cat_id for cat_id, name in enumerate(tag_categories) if name
}
# Category id -> CSS colour, primary and alternate palettes.
tag_categories_colors = [row[1] for row in _CATEGORY_TABLE]
tag_categories_alt_colors = [row[2] for row in _CATEGORY_TABLE]
|
|
|
|
|
def load_tags(data_dir):
    """
    Load tag data, returns a tuple `(tag2idx, idx2tag, tag_categories)`

    * `tag2idx`: dict mapping tag and aliases to numerical ids
    * `idx2tag`: list mapping numerical id to tag string
    * `tag_categories`: byte string mapping numerical id to categories

    `data_dir` may be a `str` or a `Path`. It must contain `tags.txt.gz`
    (one tag per line), `tag2idx.json.gz` and `tags_categories.bin.gz`.
    """
    data_dir = Path(data_dir)
    with gzip.open(data_dir / "tags.txt.gz", "rt", encoding="utf-8") as fd:
        idx2tag = fd.read().split("\n")
    # A trailing newline produces one empty entry at the end; drop it so
    # that list positions match the tag ids.
    if not idx2tag[-1]:
        idx2tag = idx2tag[:-1]
    with gzip.open(data_dir / "tag2idx.json.gz", "rb") as fp:
        tag2idx = json.load(fp)
    with gzip.open(data_dir / "tags_categories.bin.gz", "rb") as fp:
        # Named `categories` to avoid shadowing the module-level
        # `tag_categories` list of category names.
        categories = fp.read()
    # Use the module logger rather than `logging.info`: the module-level
    # helper logs on the root logger and may configure it as a side effect.
    logger.info(
        "Loaded %d tags, %d tag2id mappings", len(idx2tag), len(tag2idx)
    )
    return tag2idx, idx2tag, categories
|
|
|
|
|
def load_implications(data_dir):
    """
    Load implication mappings. Returns a tuple `(implications, implications_rej)`

    * `implications`: dict mapping numerical ids to a list of implied numerical
      ids. Contains transitive implications.
    * `implications_rej`: dict mapping tag strings to a list of implied
      numerical ids. Keys in implications_rej are tags that have very little
      usage (less than 2 posts) and don't have numerical ids associated with
      them.

    `data_dir` may be a `str` or a `Path`. It must contain
    `implications.json.gz` and `implications_rej.json.gz`.
    """
    # Accept plain strings too: the `/` joins below need a Path.
    data_dir = Path(data_dir)
    with gzip.open(data_dir / "implications.json.gz", "rb") as fp:
        implications = json.load(fp)
    # JSON object keys are always strings; restore the integer tag ids.
    implications = {int(k): v for k, v in implications.items()}
    with gzip.open(data_dir / "implications_rej.json.gz", "rb") as fp:
        implications_rej = json.load(fp)
    logger.info(
        f"Loaded {len(implications)} implications + {len(implications_rej)} implication from tags without id"
    )
    return implications, implications_rej
|
|
|
|
|
def tag_rank_to_freq(rank: int) -> float:
    """Estimate the frequency of the tag found at a given rank.

    Closed-form approximation with fixed coefficients; `rank` must be
    strictly positive because it is raised to a negative power.
    """
    damped_rank = rank ** (-0.136501)
    log_frequency = 26.4284 * math.tanh(2.93505 * damped_rank) - 11.492
    return math.exp(log_frequency)
|
|
|
|
|
def tag_freq_to_rank(freq: int) -> float:
    """Estimate the rank of a tag from its frequency.

    Closed-form approximation with fixed coefficients. `freq` must be
    strictly positive, and small enough that the inner logarithm's argument
    stays positive; otherwise `math.log` raises `ValueError`.
    """
    ln_freq = math.log(freq)
    linear_part = 0.0465456 * ln_freq - 1.24326
    log_part = math.log(1.13045 - 0.0720383 * ln_freq)
    log_rank = -7.57186 * linear_part * log_part + 12.1903
    return math.exp(log_rank)
|
|
|
|
|
InMapFun = Callable[[str, int | None], list[str]] |
|
OutMapFun = Callable[[str], list[str]] |
|
|
|
|
|
class TagNormalizer:
    """
    Map tag strings to numerical ids, and vice versa.

    Multiple strings can be mapped to a single id, while each id maps to a
    single string. As a result, the encode/decode process can be used to
    normalize tags to canonical spelling.

    See `add_input_mappings` for adding aliases, and `rename_output` for setting
    the canonical spelling of a tag.
    """

    def __init__(self, path_or_data: str | Path | tuple[dict, list, bytes]):
        """Build a normalizer from a data directory or from loaded data.

        `path_or_data` is either a directory readable by `load_tags`, or an
        already loaded `(tag2idx, idx2tag, tag_categories)` triple, which is
        stored as-is (no copy).
        """
        if isinstance(path_or_data, (Path, str)):
            data = load_tags(path_or_data)
        else:
            data = path_or_data
        self.tag2idx, self.idx2tag, self.tag_categories = data

    def get_category(self, tag: int | str, as_string=True) -> int | str | None:
        """Return the category of a tag given as an id or as a string.

        With `as_string` (the default) the category name is returned, which
        is None for the unnamed category slot; otherwise the numerical
        category id is returned. An unknown tag string encodes to None, which
        is not a valid index, so the lookup fails.
        """
        if isinstance(tag, str):
            tag = self.encode(tag)
        cat = self.tag_categories[tag]
        if as_string:
            # Module-level list of category names, not the per-tag table.
            return tag_categories[cat]
        return cat

    def encode(self, tag: str, default=None):
        "Convert tag string to numerical id, or `default` if unknown"
        return self.tag2idx.get(tag, default)

    def decode(self, tag: int | str):
        "Convert numerical id to tag string; strings are returned unchanged"
        if isinstance(tag, str):
            return tag
        return self.idx2tag[tag]

    def get_reverse_mapping(self) -> list[list[str]]:
        """Return a list mapping id -> [ tag strings ]"""
        res = [[] for _ in self.idx2tag]
        for tag, tid in self.tag2idx.items():
            res[tid].append(tag)
        return res

    def add_input_mappings(
        self, tags: str | Iterable[str], to_tid: int | str, on_conflict="raise"
    ):
        """Associate tag strings to an id for recognition by `encode`

        `to_tid` is the target id, or a known tag string resolved to its id
        (`KeyError` if unknown).

        `on_conflict` defines what to do when the tag string is already mapped
        to a different id:

        * "raise": raise a ValueError (default)
        * "warn": log a warning, then overwrite the mapping
        * "overwrite_rarest": make the tag point to the most frequently used
          tid, i.e. the smaller of the two ids
        * "overwrite": silently overwrite the mapping
        * "silent", or any other string: don't set the mapping
        """
        tag2idx = self.tag2idx
        if not isinstance(to_tid, int):
            to_tid = tag2idx[to_tid]
        if isinstance(tags, str):
            tags = (tags,)
        for tag in tags:
            conflict = tag2idx.get(tag, to_tid)
            if conflict != to_tid:
                msg = f"mapping {tag!r}->{self.idx2tag[to_tid]!r}({to_tid}) conflicts with previous mapping {tag!r}->{self.idx2tag[conflict]!r}({conflict})."
                if on_conflict == "raise":
                    raise ValueError(msg)
                elif on_conflict == "warn":
                    logger.warning(msg)
                elif on_conflict == "overwrite_rarest":
                    # Smaller ids are treated as more frequent tags: keep the
                    # existing mapping only when it is the more frequent one.
                    # Otherwise fall through and overwrite.
                    if to_tid > conflict:
                        continue
                elif on_conflict != "overwrite":
                    continue
            tag2idx[tag] = to_tid

    def remove_input_mappings(self, tags: str | Iterable[str]):
        """Remove tag strings from the mapping

        Strings that are not mapped are reported with a logged warning.
        """
        if isinstance(tags, str):
            tags = (tags,)
        for tag in tags:
            if tag in self.tag2idx:
                del self.tag2idx[tag]
            else:
                logger.warning(f"tag {tag!r} is not a valid tag")

    def rename_output(self, orig: int | str, dest: str):
        """Change the tag string associated with an id. Used by `decode`.

        `orig` is an id, or a known tag string resolved to its id.
        """
        if not isinstance(orig, int):
            orig = self.tag2idx[orig]
        self.idx2tag[orig] = dest

    def map_inputs(
        self, mapfun: InMapFun, prepopulate=True, on_conflict="raise"
    ) -> "TagNormalizer":
        """Return a new normalizer with extra input spellings.

        `mapfun(tag, tid)` is called for every existing mapping and its
        result is registered through `add_input_mappings` with the given
        `on_conflict` policy. With `prepopulate` the new normalizer starts
        from a copy of the current mappings, otherwise from an empty dict.
        `idx2tag` and `tag_categories` are shared with `self`, not copied.
        """
        tag2idx = self.tag2idx.copy() if prepopulate else {}
        res = type(self)((tag2idx, self.idx2tag, self.tag_categories))
        for tag, tid in self.tag2idx.items():
            res.add_input_mappings(mapfun(tag, tid), tid, on_conflict=on_conflict)
        return res

    def map_outputs(self, mapfun: OutMapFun) -> "TagNormalizer":
        """Return a new normalizer with rewritten canonical spellings.

        `mapfun(tag, tid)` gives the new output string for each id.
        `tag2idx` and `tag_categories` are shared with `self`, not copied.
        """
        idx2tag = [mapfun(t, i) for i, t in enumerate(self.idx2tag)]
        return type(self)((self.tag2idx, idx2tag, self.tag_categories))

    def get(self, key: int | str, default=None):
        """
        Returns the string tag associated with a numerical id, or conversely,
        the id associated with a tag.

        `default` is returned for unknown strings and for ids outside
        `0 <= id < len(idx2tag)`.
        """
        if isinstance(key, int):
            idx2tag = self.idx2tag
            # Reject negative ids as well: they would otherwise wrap around
            # and return a tag counted from the end of the list.
            if not 0 <= key < len(idx2tag):
                return default
            return idx2tag[key]
        return self.tag2idx.get(key, default)
|
|
|
|
|
class TagSetNormalizer:
    """
    Encode and decode whole tag lists, using tag implications to strip tags
    that are already implied by another tag of the list.

    Holds a `TagNormalizer` plus two implication tables: `implications`
    (tag id -> implied ids) and `implications_rej` (tag string without an
    id -> implied ids).
    """

    def __init__(self, path_or_data: str | Path | tuple[TagNormalizer, dict, dict]):
        """Build from a data directory or from already loaded data.

        `path_or_data` is either a directory (read through `TagNormalizer`
        and `load_implications`), or a ready-made
        `(tag_normalizer, implications, implications_rej)` triple stored
        as-is.
        """
        if isinstance(path_or_data, (Path, str)):
            # Normalise to Path once so that both loaders accept a plain str.
            data_dir = Path(path_or_data)
            data = TagNormalizer(data_dir), *load_implications(data_dir)
        else:
            data = path_or_data
        self.tag_normalizer, self.implications, self.implications_rej = data

    def map_inputs(self, mapfun: InMapFun, on_conflict="raise") -> "TagSetNormalizer":
        """Return a new set normalizer with remapped input spellings.

        The wrapped tag normalizer is remapped with its own `map_inputs`.
        Keys of `implications_rej` are replaced by the strings returned by
        `mapfun(tag_string, None)`. When two keys map to the same new string
        with different implied ids, `on_conflict` decides:

        * "raise": raise a ValueError (default)
        * "warn": emit a warning with `warnings.warn`, then overwrite
        * "overwrite": silently overwrite
        * any other string: keep the first mapping

        `implications` is shared with `self`, not copied.
        """
        tag_normalizer = self.tag_normalizer.map_inputs(mapfun, on_conflict=on_conflict)

        implications_rej: dict[str, list[int]] = {}
        for tag_string, implied_ids in self.implications_rej.items():
            for new_tag_string in mapfun(tag_string, None):
                conflict = implications_rej.get(new_tag_string, implied_ids)
                if conflict != implied_ids:
                    # The clash is on the *new* key, so that is the one named.
                    msg = f"mapping {new_tag_string!r}->{implied_ids} conflicts with previous mapping {new_tag_string!r}->{conflict}."
                    if on_conflict == "raise":
                        raise ValueError(msg)
                    elif on_conflict == "warn":
                        warnings.warn(msg)
                    elif on_conflict != "overwrite":
                        continue
                implications_rej[new_tag_string] = implied_ids

        res = type(self)((tag_normalizer, self.implications, implications_rej))
        return res

    def map_outputs(self, mapfun: OutMapFun) -> "TagSetNormalizer":
        """Return a new set normalizer with remapped output spellings.

        Only the wrapped tag normalizer changes; both implication tables are
        shared with `self`.
        """
        tag_normalizer = self.tag_normalizer.map_outputs(mapfun)
        return type(self)((tag_normalizer, self.implications, self.implications_rej))

    def get_implied(self, tag: int | str) -> list[int] | tuple[()]:
        """Return the ids implied by a tag id or by an id-less tag string.

        An empty tuple is returned when nothing is implied.
        """
        if isinstance(tag, int):
            return self.implications.get(tag, ())
        else:
            return self.implications_rej.get(tag, ())

    def encode(
        self,
        tags: list[str],
        keep_implied: bool | set[int] = False,
        max_antecedent_rank: int | None = None,
        drop_antecedent_rank: int | None = None,
    ) -> tuple[list[int | str], set[int]]:
        """
        Encode a list of strings as numerical ids and strip implied tags.

        Unknown tags are returned as strings. Duplicates are removed and the
        order of first appearance is kept.

        * `keep_implied`: falsy strips every implied tag; a non-empty set of
          ids strips implied tags except those ids; any other truthy value
          keeps everything.
        * `max_antecedent_rank`: a tag only marks its consequents as implied
          when its rank is below this value. The rank of a known tag is its
          id; an unknown tag ranks after every known one. For a tag at or
          above the limit, the consequents are instead processed as if they
          had been listed explicitly. None means no limit.
        * `drop_antecedent_rank`: a tag whose consequents were expanded this
          way is itself dropped when its rank is at or above this value.
          None means never.

        Returns:

        * a list of tag ids and unknown tag strings,
        * the set of implied tag ids.
        """
        tag2idx = self.tag_normalizer.tag2idx
        N = len(tag2idx)
        # Test against None explicitly: 0 is a legitimate threshold.
        if max_antecedent_rank is None:
            max_antecedent_rank = N + 1
        if drop_antecedent_rank is None:
            drop_antecedent_rank = N + 1
        get_implied = self.implications.get
        get_implied_rej = self.implications_rej.get

        # Reversed so that popping from the end visits tags in input order.
        stack = [tag2idx.get(tag, tag) for tag in tags[::-1]]
        implied = set()
        # A dict is used as an insertion-ordered set.
        res = dict()
        while stack:
            tag = stack.pop()
            if isinstance(tag, int):
                antecedent_rank = tag
                consequents = get_implied(tag)
            else:
                # Unknown tag: it has no id but may still have implications.
                antecedent_rank = N
                consequents = get_implied_rej(tag)
            if consequents:
                if antecedent_rank < max_antecedent_rank:
                    implied.update(consequents)
                else:
                    # Rank at or above the limit: process the consequents as
                    # explicit tags instead of marking them implied.
                    stack.extend(consequents)
                    if antecedent_rank >= drop_antecedent_rank:
                        continue
            res[tag] = None
        res = res.keys()

        if not keep_implied:
            res = [t for t in res if t not in implied]
        elif isinstance(keep_implied, set):
            res = [t for t in res if t not in implied or t in keep_implied]
        else:
            res = list(res)
        return res, implied

    def decode(self, tags: Iterable[int | str]) -> list[str]:
        """Convert ids back to tag strings; strings pass through unchanged."""
        idx2tag = self.tag_normalizer.idx2tag
        return [idx2tag[t] if isinstance(t, int) else t for t in tags]
|
|