k4d3 commited on
Commit
8c1e4ca
·
2 Parent(s): 026c9c4 1442347

Merge branch 'main' of hf.co:/k4d3/toolkit

Browse files
__pycache__/e6db_reader.cpython-312.pyc ADDED
Binary file (16.5 kB). View file
 
data/tag2idx.json.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1b6d0566323e99297d88d9cb6d8f7403e0f5eebc65670a71f303753a97f9786b
3
+ size 3840505
data/tags.txt.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fc537f3afe6ae8c152670ba7aa871989286d9139d7be8b1c40cb53ea36cafe0f
3
+ size 2630619
data/tags_categories.bin.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c3ba9a95809e680f40a15a41d57a507630eca5f87ca8ee3518ed56aff662413e
3
+ size 109543
demo.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from e6db_reader import TagNormalizer, tag_categories, tag_category2id
2
+
3
+ tn = TagNormalizer('data')
4
+ tn.map_inputs(lambda tag, tid: tag.replace('_', ' '))
5
+
6
+ for tag in ['pokemon', 'pikachu', 'charizard', 'loona']:
7
+ print(tag, tn.get_category(tag))
e6db_reader.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ "Python only utils (no dependencies)"
2
+ import gzip
3
+ import json
4
+ import logging
5
+ import math
6
+ import warnings
7
+ from pathlib import Path
8
+ from typing import Callable, Iterable
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+ tag_categories = [
13
+ "general",
14
+ "artist",
15
+ None, # Invalid catid
16
+ "copyright",
17
+ "character",
18
+ "species",
19
+ "invalid",
20
+ "meta",
21
+ "lore",
22
+ "pool",
23
+ ]
24
+ tag_category2id = {v: k for k, v in enumerate(tag_categories) if v}
25
+ tag_categories_colors = [
26
+ "#b4c7d9",
27
+ "#f2ac08",
28
+ None, # Invalid catid
29
+ "#d0d",
30
+ "#0a0",
31
+ "#ed5d1f",
32
+ "#ff3d3d",
33
+ "#fff",
34
+ "#282",
35
+ "wheat",
36
+ ]
37
+ tag_categories_alt_colors = [
38
+ "#2e76b4",
39
+ "#fbd67f",
40
+ None, # Invalid catid
41
+ "#ff5eff",
42
+ "#2bff2b",
43
+ "#f6b295",
44
+ "#ffbdbd",
45
+ "#666",
46
+ "#5fdb5f",
47
+ "#d0b27a",
48
+ ]
49
+
50
+
51
+ def load_tags(data_dir):
52
+ """
53
+ Load tag data, returns a tuple `(tag2idx, idx2tag, tag_categories)`
54
+
55
+ * `tag2idx`: dict mapping tag and aliases to numerical ids
56
+ * `idx2tag`: list mapping numerical id to tag string
57
+ * `tag_categories`: byte string mapping numerical id to categories
58
+ """
59
+ data_dir = Path(data_dir)
60
+ with gzip.open(data_dir / "tags.txt.gz", "rt", encoding="utf-8") as fd:
61
+ idx2tag = fd.read().split("\n")
62
+ if not idx2tag[-1]:
63
+ idx2tag = idx2tag[:-1]
64
+ with gzip.open(data_dir / "tag2idx.json.gz", "rb") as fp:
65
+ tag2idx = json.load(fp)
66
+ with gzip.open(data_dir / "tags_categories.bin.gz", "rb") as fp:
67
+ tag_categories = fp.read()
68
+ logging.info(f"Loaded {len(idx2tag)} tags, {len(tag2idx)} tag2id mappings")
69
+ return tag2idx, idx2tag, tag_categories
70
+
71
+
72
+ def load_implications(data_dir):
73
+ """
74
+ Load implication mappings. Returns a tuple `(implications, implications_rej)`
75
+
76
+ * `implications`: dict mapping numerical ids to a list of implied numerical
77
+ ids. Contains transitive implications.
78
+ * `implications_rej`: dict mapping tag strings to a list of implied
79
+ numerical ids. keys in implications_rej are tags that have a very little
80
+ usage (less than 2 posts) and don't have numerical ids associated with
81
+ them.
82
+ """
83
+ with gzip.open(data_dir / "implications.json.gz", "rb") as fp:
84
+ implications = json.load(fp)
85
+ implications = {int(k): v for k, v in implications.items()}
86
+ with gzip.open(data_dir / "implications_rej.json.gz", "rb") as fp:
87
+ implications_rej = json.load(fp)
88
+ logger.info(
89
+ f"Loaded {len(implications)} implications + {len(implications_rej)} implication from tags without id"
90
+ )
91
+ return implications, implications_rej
92
+
93
+
94
+ def tag_rank_to_freq(rank: int) -> float:
95
+ """Approximate the frequency of a tag given its rank"""
96
+ return math.exp(26.4284 * math.tanh(2.93505 * rank ** (-0.136501)) - 11.492)
97
+
98
+
99
+ def tag_freq_to_rank(freq: int) -> float:
100
+ """Approximate the rank of a tag given its frequency"""
101
+ log_freq = math.log(freq)
102
+ return math.exp(
103
+ -7.57186
104
+ * (0.0465456 * log_freq - 1.24326)
105
+ * math.log(1.13045 - 0.0720383 * log_freq)
106
+ + 12.1903
107
+ )
108
+
109
+
110
+ InMapFun = Callable[[str, int | None], list[str]]
111
+ OutMapFun = Callable[[str], list[str]]
112
+
113
+
114
+ class TagNormalizer:
115
+ """
116
+ Map tag strings to numerical ids, and vice versa.
117
+
118
+ Multiple strings can be mapped to a single id, while each id maps to a
119
+ single string. As a result, the encode/decode process can be used to
120
+ normalize tags to canonical spelling.
121
+
122
+ See `add_input_mappings` for adding aliases, and `rename_output` for setting
123
+ the canonical spelling of a tag.
124
+ """
125
+
126
+ def __init__(self, path_or_data: str | Path | tuple[dict, list, bytes]):
127
+ if isinstance(path_or_data, (Path, str)):
128
+ data = load_tags(path_or_data)
129
+ else:
130
+ data = path_or_data
131
+ self.tag2idx, self.idx2tag, self.tag_categories = data
132
+
133
+ def get_category(self, tag: int | str, as_string=True) -> int:
134
+ if isinstance(tag, str):
135
+ tag = self.encode(tag)
136
+ cat = self.tag_categories[tag]
137
+ if as_string:
138
+ return tag_categories[cat]
139
+ return cat
140
+
141
+ def encode(self, tag: str, default=None):
142
+ "Convert tag string to numerical id"
143
+ return self.tag2idx.get(tag, default)
144
+
145
+ def decode(self, tag: int | str):
146
+ "Convert numerical id to tag string"
147
+ if isinstance(tag, str):
148
+ return tag
149
+ return self.idx2tag[tag]
150
+
151
+ def get_reverse_mapping(self):
152
+ """Return a list mapping id -> [ tag strings ]"""
153
+ res = [[] for i in range(len(self.idx2tag))]
154
+ for tag, tid in self.tag2idx.items():
155
+ res[tid].append(tag)
156
+ return res
157
+
158
+ def add_input_mappings(
159
+ self, tags: str | Iterable[str], to_tid: int | str, on_conflict="raise"
160
+ ):
161
+ """Associate tag strings to an id for recognition by `encode`
162
+
163
+ `on_conflict` defines what to do when the tag string is already mapped
164
+ to a different id:
165
+
166
+ * "raise": raise an ValueError (default)
167
+ * "warn": raise a warning
168
+ * "overwrite_rarest": make the tag point to the most frequently used tid
169
+ * "overwrite": silently overwrite the mapping
170
+ * "silent", or any other string: don't set the mapping
171
+ """
172
+ tag2idx = self.tag2idx
173
+ if not isinstance(to_tid, int):
174
+ to_tid = tag2idx[to_tid]
175
+ if isinstance(tags, str):
176
+ tags = (tags,)
177
+ for tag in tags:
178
+ conflict = tag2idx.get(tag, to_tid)
179
+ if conflict != to_tid:
180
+ msg = f"mapping {tag!r}->{self.idx2tag[to_tid]!r}({to_tid}) conflicts with previous mapping {tag!r}->{self.idx2tag[conflict]!r}({conflict})."
181
+ if on_conflict == "raise":
182
+ raise ValueError(msg)
183
+ elif on_conflict == "warn":
184
+ logger.warning(msg)
185
+ elif on_conflict == "overwrite_rarest" and to_tid > conflict:
186
+ continue
187
+ elif on_conflict != "overwrite":
188
+ continue
189
+ tag2idx[tag] = to_tid
190
+
191
+ def remove_input_mappings(self, tags: str | Iterable[str]):
192
+ """Remove tag strings from the mapping"""
193
+ if isinstance(tags, str):
194
+ tags = (tags,)
195
+ for tag in tags:
196
+ if tag in self.tag2idx:
197
+ del self.tag2idx[tag]
198
+ else:
199
+ logger.warning(f"tag {tag!r} is not a valid tag")
200
+
201
+ def rename_output(self, orig: int | str, dest: str):
202
+ """Change the tag string associated with an id. Used by `decode`."""
203
+ if not isinstance(orig, int):
204
+ orig = self.tag2idx[orig]
205
+ self.idx2tag[orig] = dest
206
+
207
+ def map_inputs(
208
+ self, mapfun: InMapFun, prepopulate=True, on_conflict="raise"
209
+ ) -> "TagNormalizer":
210
+ tag2idx = self.tag2idx.copy() if prepopulate else {}
211
+ res = type(self)((tag2idx, self.idx2tag, self.tag_categories))
212
+ for tag, tid in self.tag2idx.items():
213
+ res.add_input_mappings(mapfun(tag, tid), tid, on_conflict=on_conflict)
214
+ return res
215
+
216
+ def map_outputs(self, mapfun: OutMapFun) -> "TagNormalizer":
217
+ idx2tag = [mapfun(t, i) for i, t in enumerate(self.idx2tag)]
218
+ return type(self)((self.tag2idx, idx2tag, self.tag_categories))
219
+
220
+ def get(self, key: int | str, default=None):
221
+ """
222
+ Returns the string tag associated with a numerical id, or conversely,
223
+ the id associated with a tag.
224
+ """
225
+ if isinstance(key, int):
226
+ idx2tag = self.idx2tag
227
+ if key >= len(idx2tag):
228
+ return default
229
+ return idx2tag[key]
230
+ return self.tag2idx.get(key, default)
231
+
232
+
233
+ class TagSetNormalizer:
234
+ def __init__(self, path_or_data: str | Path | tuple[TagNormalizer, dict, dict]):
235
+ if isinstance(path_or_data, (Path, str)):
236
+ data = TagNormalizer(path_or_data), *load_implications(path_or_data)
237
+ else:
238
+ data = path_or_data
239
+ self.tag_normalizer, self.implications, self.implications_rej = data
240
+
241
+ def map_inputs(self, mapfun: InMapFun, on_conflict="raise") -> "TagSetNormalizer":
242
+ tag_normalizer = self.tag_normalizer.map_inputs(mapfun, on_conflict=on_conflict)
243
+
244
+ implications_rej: dict[str, list[str]] = {}
245
+ for tag_string, implied_ids in self.implications_rej.items():
246
+ for new_tag_string in mapfun(tag_string, None):
247
+ conflict = implications_rej.get(new_tag_string, implied_ids)
248
+ if conflict != implied_ids:
249
+ msg = f"mapping {tag_string!r}->{implied_ids} conflicts with previous mapping {tag_string!r}->{conflict}."
250
+ if on_conflict == "raise":
251
+ raise ValueError(msg)
252
+ elif on_conflict == "warn":
253
+ warnings.warn(msg)
254
+ elif on_conflict != "overwrite":
255
+ continue
256
+ implications_rej[new_tag_string] = implied_ids
257
+
258
+ res = type(self)((tag_normalizer, self.implications, implications_rej))
259
+ return res
260
+
261
+ def map_outputs(self, mapfun: OutMapFun) -> "TagSetNormalizer":
262
+ tag_normalizer = self.tag_normalizer.map_outputs(mapfun)
263
+ return type(self)((tag_normalizer, self.implications, self.implications_rej))
264
+
265
+ def get_implied(self, tag: int | str) -> list[int]:
266
+ if isinstance(tag, int):
267
+ return self.implications.get(tag, ())
268
+ else:
269
+ return self.implications_rej.get(tag, ())
270
+
271
+ def encode(
272
+ self,
273
+ tags: list[str],
274
+ keep_implied: bool | set[int] = False,
275
+ max_antecedent_rank: int | None = None,
276
+ drop_antecedent_rank: int | None = None,
277
+ ) -> tuple[list[int | str], set[int]]:
278
+ """
279
+ Encode a list of string as numerical ids and strip implied tags.
280
+
281
+ Unknown tags are returned as strings.
282
+
283
+ Returns :
284
+
285
+ * a list of tag ids and unknown tag strings,
286
+ * a list of implied tag ids.
287
+ """
288
+ tag2idx = self.tag_normalizer.tag2idx
289
+ N = len(tag2idx)
290
+ max_antecedent_rank = max_antecedent_rank or N + 1
291
+ drop_antecedent_rank = drop_antecedent_rank or N + 1
292
+ get_implied = self.implications.get
293
+ get_implied_rej = self.implications_rej.get
294
+
295
+ stack = [tag2idx.get(tag, tag) for tag in tags[::-1]]
296
+ implied = set()
297
+ res = dict() # dict as a cheap ordered set
298
+ while stack:
299
+ tag = stack.pop()
300
+ if isinstance(tag, int):
301
+ antecedent_rank = tag
302
+ consequents = get_implied(tag)
303
+ else:
304
+ # the tag might be a very rare antecedent (less than two posts)
305
+ # that doesn't have a tag id
306
+ antecedent_rank = N
307
+ consequents = get_implied_rej(tag)
308
+ if consequents:
309
+ if antecedent_rank < max_antecedent_rank:
310
+ implied.update(consequents)
311
+ else:
312
+ # The implied tags from low frequency antecedent (high rank)
313
+ # are added to the list and instead the antecedent may be
314
+ # dropped
315
+ stack.extend(consequents)
316
+ if antecedent_rank >= drop_antecedent_rank:
317
+ continue
318
+ res[tag] = None
319
+ res = res.keys()
320
+
321
+ if not keep_implied:
322
+ res = [t for t in res if t not in implied]
323
+ elif isinstance(keep_implied, set):
324
+ res = [t for t in res if t not in implied or t in keep_implied]
325
+ else:
326
+ res = list(res)
327
+ return res, implied
328
+
329
+ def decode(self, tags: Iterable[int | str]) -> list[str]:
330
+ idx2tag = self.tag_normalizer.idx2tag
331
+ return [idx2tag[t] if isinstance(t, int) else t for t in tags]
joy CHANGED
@@ -32,6 +32,7 @@ from transformers import (
32
  PreTrainedTokenizerFast,
33
  )
34
  from torch import nn
 
35
 
36
  CLIP_PATH = "google/siglip-so400m-patch14-384"
37
  MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B"
@@ -80,6 +81,8 @@ CAPTION_TYPE_MAP = {
80
 
81
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
82
 
 
 
83
  class ImageAdapter(nn.Module):
84
  """
85
  Custom image adapter module for processing CLIP vision outputs.
@@ -466,6 +469,9 @@ def main():
466
  if args.random_tags is not None and args.feed_from_tags is None:
467
  parser.error("--random-tags can only be used when --feed-from-tags is enabled")
468
 
 
 
 
469
  # Initialize and load models
470
  joy_caption_model = JoyCaptionModel()
471
  joy_caption_model.load_models()
 
32
  PreTrainedTokenizerFast,
33
  )
34
  from torch import nn
35
+ from e6db_reader import TagNormalizer
36
 
37
  CLIP_PATH = "google/siglip-so400m-patch14-384"
38
  MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B"
 
81
 
82
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
83
 
84
+ E6DB_DATA = Path(__file__).resolve().parent / "data"
85
+
86
  class ImageAdapter(nn.Module):
87
  """
88
  Custom image adapter module for processing CLIP vision outputs.
 
469
  if args.random_tags is not None and args.feed_from_tags is None:
470
  parser.error("--random-tags can only be used when --feed-from-tags is enabled")
471
 
472
+ print('Loading e621 tag data')
473
+ tag_normalizer = TagNormalizer(E6DB_DATA)
474
+
475
  # Initialize and load models
476
  joy_caption_model = JoyCaptionModel()
477
  joy_caption_model.load_models()