k4d3 committed on
Commit
3821594
1 Parent(s): a7f9c9a

joy: use tag categories for building prompts

Browse files
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ __pycache__
2
+
__pycache__/e6db_reader.cpython-312.pyc DELETED
Binary file (16.5 kB)
 
data/implications.json.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6240e9f23bacc42bcfccdfa3e86a439a4b0a489ee142c9ac855f12027c0657e7
3
+ size 228337
data/implications_rej.json.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:33c9a2d1ed8f60d6f6122a6b2a816043de2ad029ca0fc2ad39f1fbd705e6beaa
3
+ size 94416
demo.py DELETED
@@ -1,7 +0,0 @@
1
- from e6db_reader import TagNormalizer, tag_categories, tag_category2id
2
-
3
- tn = TagNormalizer('data')
4
- tn.map_inputs(lambda tag, tid: tag.replace('_', ' '))
5
-
6
- for tag in ['pokemon', 'pikachu', 'charizard', 'loona']:
7
- print(tag, tn.get_category(tag))
 
 
 
 
 
 
 
 
e6db_reader.py CHANGED
@@ -93,7 +93,7 @@ def load_implications(data_dir):
93
 
94
  def tag_rank_to_freq(rank: int) -> float:
95
  """Approximate the frequency of a tag given its rank"""
96
- return math.exp(26.4284 * math.tanh(2.93505 * rank ** (-0.136501)) - 11.492)
97
 
98
 
99
  def tag_freq_to_rank(freq: int) -> float:
 
93
 
94
def tag_rank_to_freq(rank: int) -> float:
    """Approximate the frequency of a tag given its rank.

    Uses an empirical fit; ranks below 1 are clamped to 1 so the
    negative fractional power is always defined (0 ** -x would raise).
    """
    clamped = rank if rank > 1 else 1
    exponent = 26.4284 * math.tanh(2.93505 * clamped ** -0.136501) - 11.492
    return math.exp(exponent)
97
 
98
 
99
  def tag_freq_to_rank(freq: int) -> float:
joy CHANGED
@@ -18,6 +18,7 @@ import os
18
  import argparse
19
  import re
20
  import random
 
21
  from pathlib import Path
22
  from PIL import Image
23
  import pillow_jxl
@@ -32,7 +33,7 @@ from transformers import (
32
  PreTrainedTokenizerFast,
33
  )
34
  from torch import nn
35
- from e6db_reader import TagNormalizer
36
 
37
  CLIP_PATH = "google/siglip-so400m-patch14-384"
38
  MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B"
@@ -81,8 +82,6 @@ CAPTION_TYPE_MAP = {
81
 
82
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
83
 
84
- E6DB_DATA = Path(__file__).resolve().parent / "data"
85
-
86
  class ImageAdapter(nn.Module):
87
  """
88
  Custom image adapter module for processing CLIP vision outputs.
@@ -275,13 +274,13 @@ class JoyCaptionModel:
275
  caption_type: str,
276
  caption_tone: str,
277
  caption_length: str | int,
278
- custom_prompt: str = None) -> str:
279
  """
280
  Process an input image and generate a caption based on specified parameters.
281
  """
282
  torch.cuda.empty_cache()
283
 
284
- if caption_type == "custom" and custom_prompt:
285
  prompt_str = custom_prompt
286
  else:
287
  prompt_str = self._get_prompt_string(caption_type, caption_tone, caption_length)
@@ -470,7 +469,7 @@ def main():
470
  parser.error("--random-tags can only be used when --feed-from-tags is enabled")
471
 
472
  print('Loading e621 tag data')
473
- tag_normalizer = TagNormalizer(E6DB_DATA)
474
 
475
  # Initialize and load models
476
  joy_caption_model = JoyCaptionModel()
@@ -495,54 +494,22 @@ def main():
495
  input_image = Image.open(image_path).convert("RGB")
496
 
497
  # Use custom prompt if specified
 
498
  if args.caption_type == "custom":
499
- caption = joy_caption_model.process_image(
500
- input_image,
501
- "custom",
502
- args.caption_tone,
503
- args.caption_length,
504
- custom_prompt=args.custom_prompt
505
- )
506
- else:
507
- # Check for --feed-from-tags
508
- if args.feed_from_tags is not None:
509
- tag_file = find_tag_file(image_path)
510
- if tag_file:
511
- with open(tag_file, 'r', encoding='utf-8') as f:
512
- tags = f.read().strip().split(',')
513
-
514
- if args.random_tags is not None:
515
- # Randomly select tags if --random-tags is specified
516
- num_tags = min(args.random_tags, len(tags))
517
- tags = random.sample(tags, num_tags)
518
- elif args.feed_from_tags > 0:
519
- # Use specified number of tags if --feed-from-tags has a positive value
520
- tags = tags[:args.feed_from_tags]
521
-
522
- tag_string = ', '.join(tags)
523
- custom_prompt = f"Write a descriptive caption for this image in a formal tone. Use these tags as context clues to construct your caption: {tag_string}"
524
-
525
- caption = joy_caption_model.process_image(
526
- input_image,
527
- "custom",
528
- args.caption_tone,
529
- args.caption_length,
530
- custom_prompt=custom_prompt
531
- )
532
- else:
533
- caption = joy_caption_model.process_image(
534
- input_image,
535
- args.caption_type,
536
- args.caption_tone,
537
- args.caption_length
538
- )
539
- else:
540
- caption = joy_caption_model.process_image(
541
- input_image,
542
- args.caption_type,
543
- args.caption_tone,
544
- args.caption_length
545
- )
546
 
547
  # Strip commas if the --dont-strip-commas flag is not set
548
  if not args.dont_strip_commas:
@@ -560,6 +527,155 @@ def main():
560
  f.write(caption)
561
  print(f"Caption saved to {caption_file}")
562
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
563
  def find_tag_file(image_path):
564
  """
565
  Find the corresponding .txt file for the given image path.
 
18
  import argparse
19
  import re
20
  import random
21
+ from collections import Counter
22
  from pathlib import Path
23
  from PIL import Image
24
  import pillow_jxl
 
33
  PreTrainedTokenizerFast,
34
  )
35
  from torch import nn
36
+ from e6db_reader import TagNormalizer, TagSetNormalizer, tag_category2id, tag_rank_to_freq
37
 
38
  CLIP_PATH = "google/siglip-so400m-patch14-384"
39
  MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B"
 
82
 
83
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
84
 
 
 
85
  class ImageAdapter(nn.Module):
86
  """
87
  Custom image adapter module for processing CLIP vision outputs.
 
274
  caption_type: str,
275
  caption_tone: str,
276
  caption_length: str | int,
277
+ custom_prompt: str | None = None) -> str:
278
  """
279
  Process an input image and generate a caption based on specified parameters.
280
  """
281
  torch.cuda.empty_cache()
282
 
283
+ if custom_prompt is not None:
284
  prompt_str = custom_prompt
285
  else:
286
  prompt_str = self._get_prompt_string(caption_type, caption_tone, caption_length)
 
469
  parser.error("--random-tags can only be used when --feed-from-tags is enabled")
470
 
471
  print('Loading e621 tag data')
472
+ tagset_normalizer = make_tagset_normalizer()
473
 
474
  # Initialize and load models
475
  joy_caption_model = JoyCaptionModel()
 
494
  input_image = Image.open(image_path).convert("RGB")
495
 
496
  # Use custom prompt if specified
497
+ custom_prompt = None
498
  if args.caption_type == "custom":
499
+ custom_prompt = args.custom_prompt
500
+ elif args.feed_from_tags is not None:
501
+ custom_prompt = prompt_from_tags(args, image_path, tagset_normalizer)
502
+
503
+ print(f"Custom prompt: {custom_prompt}")
504
+ continue
505
+
506
+ caption = joy_caption_model.process_image(
507
+ input_image,
508
+ args.caption_type,
509
+ args.caption_tone,
510
+ args.caption_length,
511
+ custom_prompt=custom_prompt
512
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
513
 
514
  # Strip commas if the --dont-strip-commas flag is not set
515
  if not args.dont_strip_commas:
 
527
  f.write(caption)
528
  print(f"Caption saved to {caption_file}")
529
 
530
+
531
# Matches a trailing parenthesized disambiguation suffix, e.g. "_(artist)".
RE_PARENS_SUFFIX = re.compile(r"_\([^)]+\)$")
# Tag database directory, resolved relative to this file.
E6DB_DATA = Path(__file__).resolve().parent / "data"

def make_tagset_normalizer():
    """
    Create a TagSetNormalizer for encoding/decoding tags to and from integers.
    Configures it based on the provided config.

    Registers extra input aliases for every known tag (suffix-stripped
    forms, artist "by_" variants, ':' -> '_' forms) so looser tag
    spellings still resolve to the canonical tag id.
    """
    # This loads all the aliases and implications
    tagset_normalizer = TagSetNormalizer(E6DB_DATA)

    tagid2cat = tagset_normalizer.tag_normalizer.tag_categories
    cat_artist = tag_category2id["artist"]
    cat2suffix = {
        tag_category2id["character"]: "_(character)",
        tag_category2id["lore"]: "_(lore)",
        tag_category2id["species"]: "_(species)",
        tag_category2id["copyright"]: "_(copyright)",
    }

    # Create additional aliases for tags using simple rules.
    # NOTE: the order of the yields below determines which alias wins
    # under on_conflict="ignore" — do not reorder.
    def input_map(tag, tid):
        # Make an alias without parentheses, it might conflict but we'll handle
        # it depending on `on_alias_conflict` config value.
        without_suffix = RE_PARENS_SUFFIX.sub("", tag)
        had_suffix = tag != without_suffix
        if had_suffix:
            yield without_suffix

        # Add an alias with the suffix (special case for artist).
        # `tid` can be None (input alias with no canonical tag); -1 is a
        # sentinel category that matches nothing below.
        cat = tagid2cat[tid] if tid is not None else -1
        if cat == cat_artist:
            artist = without_suffix.removeprefix("by_")
            if artist != without_suffix:
                # Tag had a "by_" prefix: also accept the bare name.
                yield artist
                if not had_suffix:
                    yield f"{artist}_(artist)"
            else:
                # Tag had no "by_" prefix: also accept the prefixed form.
                yield f"by_{artist}"
                if not had_suffix:
                    yield f"by_{artist}_(artist)"
        elif not had_suffix:
            suffix = cat2suffix.get(cat)
            if suffix is not None:
                yield f"{without_suffix}{suffix}"

        # Recognize tags where ':' were replaced by a space (aspect ratio)
        if ":" in tag:
            yield tag.replace(":", "_")

    return tagset_normalizer.map_inputs(input_map, on_conflict="ignore")
582
+
583
+
584
+
585
def format_nl_list(l):
    """Join a list of strings into a natural-language enumeration.

    ``['a'] -> 'a'``; ``['a', 'b'] -> 'a and b'``;
    ``['a', 'b', 'c'] -> 'a, b, and c'`` (Oxford comma).

    Raises:
        ValueError: if *l* is empty.
    """
    if not l:
        # Explicit error instead of `assert`, which is stripped under -O.
        raise ValueError("cannot format an empty list")
    n = len(l)
    if n == 1:
        return l[0]
    if n == 2:
        return f"{l[0]} and {l[1]}"
    # n > 2
    *head, last = l
    return ', '.join(head) + ', and ' + last
595
+
596
TAG_SPECIES = tag_category2id['species']
TAG_CHARACTER = tag_category2id['character']
TAG_ARTIST = tag_category2id['artist']
TAG_COPYRIGHT = tag_category2id['copyright']
TAG_META = tag_category2id['meta']
# Tags rarer than this approximate frequency are dropped (0 keeps all).
TAG_FREQ_THRESH = 0

def prompt_from_tags(args, image_path: Path, tagset_normalizer: TagSetNormalizer):
    """
    Build a caption prompt from the image's e621 tag file.

    Reads the ``.txt`` tag file next to *image_path*, buckets the tags by
    category (artist/character/species/copyright vs. everything else),
    drops meta tags and tags implied by other tags, then formats a
    natural-language prompt. Returns None when no tag file is found.
    """
    tag_file = find_tag_file(image_path)
    if tag_file is None:
        return None

    with open(tag_file, 'r', encoding='utf-8') as f:
        tags = f.read().lower().split(',')

    tag_id_to_cat_id = tagset_normalizer.tag_normalizer.tag_categories
    encode = tagset_normalizer.tag_normalizer.encode

    # These lists contain tuples (freq, tag, tag_id)
    tag_by_category = {cat: [] for cat in [TAG_ARTIST, TAG_CHARACTER, TAG_COPYRIGHT, TAG_SPECIES]}
    other_tags = []
    implied = set()
    for tag in tags:
        tag = tag.strip()
        # Encode the tag into a numerical id
        tag_underscores = tag.replace(' ', '_')
        tag_id = encode(tag_underscores)
        if tag_id is None:
            other_tags.append((0, tag, None))
            # NOTE(review): the original looked this up with `tag_id`,
            # which is always None on this branch and so never matched
            # anything; `implications_rej` is assumed to be keyed by the
            # underscore form of rejected tag strings — confirm against
            # e6db_reader.
            implied.update(tagset_normalizer.implications_rej.get(tag_underscores, ()))
            continue
        # Get the category of the tag
        cat_id = tag_id_to_cat_id[tag_id]
        # Skip meta tags
        if cat_id == TAG_META:
            continue
        implied.update(tagset_normalizer.implications.get(tag_id, ()))
        # Get the frequency of the tag
        # (assumes e6db tag ids are assigned in rank order — TODO confirm)
        freq = tag_rank_to_freq(tag_id)
        if freq < TAG_FREQ_THRESH:
            continue
        # Categories without their own bucket fall through to other_tags.
        tag_by_category.get(cat_id, other_tags).append((freq, tag, tag_id))

    # Drop tags implied by other tags, sort each bucket by frequency.
    other_tags = sorted(
        (freq, tag) for freq, tag, tag_id in other_tags
        if tag_id not in implied
    )
    for cat_id, cat_list in tag_by_category.items():
        tag_by_category[cat_id] = sorted(
            (freq, tag) for freq, tag, tag_id in cat_list
            if tag_id not in implied
        )

    if args.random_tags is not None:
        # Randomly select tags if --random-tags is specified.
        # FIX: sample from `other_tags` (the (freq, tag) tuples consumed
        # below); the original sampled raw strings from `tags`, which
        # broke the `for _, tag in other_tags` unpacking later.
        num_tags = min(args.random_tags, len(other_tags))
        other_tags = random.sample(other_tags, num_tags)
    elif args.feed_from_tags > 0:
        # Use specified number of tags if --feed-from-tags has a positive value
        other_tags = other_tags[:args.feed_from_tags]

    # Prepare sentence pieces
    artist_tag = tag_by_category[TAG_ARTIST]
    if artist_tag:
        # Hoisted out of the f-string: reusing the same quote character
        # inside an f-string is only legal on Python 3.12+ (PEP 701).
        artist_names = [tag.removeprefix('by ') for _, tag in artist_tag[:4]]
        artist_txt = f" by {format_nl_list(artist_names)}"
    else:
        artist_txt = ''
    character_tag = tag_by_category[TAG_CHARACTER]
    if character_tag:
        character_txt = f" named {format_nl_list([tag for _, tag in character_tag[:4]])}"
    else:
        character_txt = ''
    species_tag = tag_by_category[TAG_SPECIES]
    if species_tag:
        species_txt = f' of{" a" if len(character_tag) <= 1 and len(species_tag) <= 1 else ""} {format_nl_list([tag for _, tag in species_tag[:4]])}'
    elif character_tag:
        species_txt = ' of a character' if len(character_tag) <= 1 else ' of characters'
    else:
        species_txt = ''
    copyright_tag = tag_by_category[TAG_COPYRIGHT]
    if copyright_tag:
        copyright_txt = f" from {format_nl_list([tag for _, tag in copyright_tag[:4]])}"
    else:
        copyright_txt = ''

    tag_string = ', '.join(tag for _, tag in other_tags)
    custom_prompt = (
        f"Write a descriptive caption for this image{artist_txt}"
        f"{species_txt}{character_txt}{copyright_txt} in a formal tone. "
        f"Use these tags as context clues to construct your caption: {tag_string}"
    )
    return custom_prompt
678
+
679
  def find_tag_file(image_path):
680
  """
681
  Find the corresponding .txt file for the given image path.