joy: use tag categories for building prompts
Browse files- .gitignore +2 -0
- __pycache__/e6db_reader.cpython-312.pyc +0 -0
- data/implications.json.gz +3 -0
- data/implications_rej.json.gz +3 -0
- demo.py +0 -7
- e6db_reader.py +1 -1
- joy +169 -53
.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
__pycache__
|
2 |
+
|
__pycache__/e6db_reader.cpython-312.pyc
DELETED
Binary file (16.5 kB)
|
|
data/implications.json.gz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6240e9f23bacc42bcfccdfa3e86a439a4b0a489ee142c9ac855f12027c0657e7
|
3 |
+
size 228337
|
data/implications_rej.json.gz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:33c9a2d1ed8f60d6f6122a6b2a816043de2ad029ca0fc2ad39f1fbd705e6beaa
|
3 |
+
size 94416
|
demo.py
DELETED
@@ -1,7 +0,0 @@
|
|
1 |
-
from e6db_reader import TagNormalizer, tag_categories, tag_category2id
|
2 |
-
|
3 |
-
tn = TagNormalizer('data')
|
4 |
-
tn.map_inputs(lambda tag, tid: tag.replace('_', ' '))
|
5 |
-
|
6 |
-
for tag in ['pokemon', 'pikachu', 'charizard', 'loona']:
|
7 |
-
print(tag, tn.get_category(tag))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
e6db_reader.py
CHANGED
@@ -93,7 +93,7 @@ def load_implications(data_dir):
|
|
93 |
|
94 |
def tag_rank_to_freq(rank: int) -> float:
|
95 |
"""Approximate the frequency of a tag given its rank"""
|
96 |
-
return math.exp(26.4284 * math.tanh(2.93505 * rank ** (-0.136501)) - 11.492)
|
97 |
|
98 |
|
99 |
def tag_freq_to_rank(freq: int) -> float:
|
|
|
93 |
|
94 |
def tag_rank_to_freq(rank: int) -> float:
    """
    Approximate the frequency of a tag given its rank.

    Ranks below 1 are clamped to 1 before the fitted curve is evaluated, so
    a rank of 0 never reaches the negative power.
    """
    clamped_rank = max(1, rank)
    squashed = math.tanh(2.93505 * clamped_rank ** (-0.136501))
    return math.exp(26.4284 * squashed - 11.492)
|
97 |
|
98 |
|
99 |
def tag_freq_to_rank(freq: int) -> float:
|
joy
CHANGED
@@ -18,6 +18,7 @@ import os
|
|
18 |
import argparse
|
19 |
import re
|
20 |
import random
|
|
|
21 |
from pathlib import Path
|
22 |
from PIL import Image
|
23 |
import pillow_jxl
|
@@ -32,7 +33,7 @@ from transformers import (
|
|
32 |
PreTrainedTokenizerFast,
|
33 |
)
|
34 |
from torch import nn
|
35 |
-
from e6db_reader import TagNormalizer
|
36 |
|
37 |
CLIP_PATH = "google/siglip-so400m-patch14-384"
|
38 |
MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B"
|
@@ -81,8 +82,6 @@ CAPTION_TYPE_MAP = {
|
|
81 |
|
82 |
HF_TOKEN = os.environ.get("HF_TOKEN", None)
|
83 |
|
84 |
-
E6DB_DATA = Path(__file__).resolve().parent / "data"
|
85 |
-
|
86 |
class ImageAdapter(nn.Module):
|
87 |
"""
|
88 |
Custom image adapter module for processing CLIP vision outputs.
|
@@ -275,13 +274,13 @@ class JoyCaptionModel:
|
|
275 |
caption_type: str,
|
276 |
caption_tone: str,
|
277 |
caption_length: str | int,
|
278 |
-
custom_prompt: str = None) -> str:
|
279 |
"""
|
280 |
Process an input image and generate a caption based on specified parameters.
|
281 |
"""
|
282 |
torch.cuda.empty_cache()
|
283 |
|
284 |
-
if
|
285 |
prompt_str = custom_prompt
|
286 |
else:
|
287 |
prompt_str = self._get_prompt_string(caption_type, caption_tone, caption_length)
|
@@ -470,7 +469,7 @@ def main():
|
|
470 |
parser.error("--random-tags can only be used when --feed-from-tags is enabled")
|
471 |
|
472 |
print('Loading e621 tag data')
|
473 |
-
|
474 |
|
475 |
# Initialize and load models
|
476 |
joy_caption_model = JoyCaptionModel()
|
@@ -495,54 +494,22 @@ def main():
|
|
495 |
input_image = Image.open(image_path).convert("RGB")
|
496 |
|
497 |
# Use custom prompt if specified
|
|
|
498 |
if args.caption_type == "custom":
|
499 |
-
|
500 |
-
|
501 |
-
|
502 |
-
|
503 |
-
|
504 |
-
|
505 |
-
|
506 |
-
|
507 |
-
|
508 |
-
|
509 |
-
|
510 |
-
|
511 |
-
|
512 |
-
|
513 |
-
|
514 |
-
if args.random_tags is not None:
|
515 |
-
# Randomly select tags if --random-tags is specified
|
516 |
-
num_tags = min(args.random_tags, len(tags))
|
517 |
-
tags = random.sample(tags, num_tags)
|
518 |
-
elif args.feed_from_tags > 0:
|
519 |
-
# Use specified number of tags if --feed-from-tags has a positive value
|
520 |
-
tags = tags[:args.feed_from_tags]
|
521 |
-
|
522 |
-
tag_string = ', '.join(tags)
|
523 |
-
custom_prompt = f"Write a descriptive caption for this image in a formal tone. Use these tags as context clues to construct your caption: {tag_string}"
|
524 |
-
|
525 |
-
caption = joy_caption_model.process_image(
|
526 |
-
input_image,
|
527 |
-
"custom",
|
528 |
-
args.caption_tone,
|
529 |
-
args.caption_length,
|
530 |
-
custom_prompt=custom_prompt
|
531 |
-
)
|
532 |
-
else:
|
533 |
-
caption = joy_caption_model.process_image(
|
534 |
-
input_image,
|
535 |
-
args.caption_type,
|
536 |
-
args.caption_tone,
|
537 |
-
args.caption_length
|
538 |
-
)
|
539 |
-
else:
|
540 |
-
caption = joy_caption_model.process_image(
|
541 |
-
input_image,
|
542 |
-
args.caption_type,
|
543 |
-
args.caption_tone,
|
544 |
-
args.caption_length
|
545 |
-
)
|
546 |
|
547 |
# Strip commas if the --dont-strip-commas flag is not set
|
548 |
if not args.dont_strip_commas:
|
@@ -560,6 +527,155 @@ def main():
|
|
560 |
f.write(caption)
|
561 |
print(f"Caption saved to {caption_file}")
|
562 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
563 |
def find_tag_file(image_path):
|
564 |
"""
|
565 |
Find the corresponding .txt file for the given image path.
|
|
|
18 |
import argparse
|
19 |
import re
|
20 |
import random
|
21 |
+
from collections import Counter
|
22 |
from pathlib import Path
|
23 |
from PIL import Image
|
24 |
import pillow_jxl
|
|
|
33 |
PreTrainedTokenizerFast,
|
34 |
)
|
35 |
from torch import nn
|
36 |
+
from e6db_reader import TagNormalizer, TagSetNormalizer, tag_category2id, tag_rank_to_freq
|
37 |
|
38 |
CLIP_PATH = "google/siglip-so400m-patch14-384"
|
39 |
MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B"
|
|
|
82 |
|
83 |
HF_TOKEN = os.environ.get("HF_TOKEN", None)
|
84 |
|
|
|
|
|
85 |
class ImageAdapter(nn.Module):
|
86 |
"""
|
87 |
Custom image adapter module for processing CLIP vision outputs.
|
|
|
274 |
caption_type: str,
|
275 |
caption_tone: str,
|
276 |
caption_length: str | int,
|
277 |
+
custom_prompt: str | None = None) -> str:
|
278 |
"""
|
279 |
Process an input image and generate a caption based on specified parameters.
|
280 |
"""
|
281 |
torch.cuda.empty_cache()
|
282 |
|
283 |
+
if custom_prompt is not None:
|
284 |
prompt_str = custom_prompt
|
285 |
else:
|
286 |
prompt_str = self._get_prompt_string(caption_type, caption_tone, caption_length)
|
|
|
469 |
parser.error("--random-tags can only be used when --feed-from-tags is enabled")
|
470 |
|
471 |
print('Loading e621 tag data')
|
472 |
+
tagset_normalizer = make_tagset_normalizer()
|
473 |
|
474 |
# Initialize and load models
|
475 |
joy_caption_model = JoyCaptionModel()
|
|
|
494 |
input_image = Image.open(image_path).convert("RGB")
|
495 |
|
496 |
# Use custom prompt if specified
|
497 |
+
custom_prompt = None
|
498 |
if args.caption_type == "custom":
|
499 |
+
custom_prompt = args.custom_prompt
|
500 |
+
elif args.feed_from_tags is not None:
|
501 |
+
custom_prompt = prompt_from_tags(args, image_path, tagset_normalizer)
|
502 |
+
|
503 |
+
print(f"Custom prompt: {custom_prompt}")
|
504 |
+
continue
|
505 |
+
|
506 |
+
caption = joy_caption_model.process_image(
|
507 |
+
input_image,
|
508 |
+
args.caption_type,
|
509 |
+
args.caption_tone,
|
510 |
+
args.caption_length,
|
511 |
+
custom_prompt=custom_prompt
|
512 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
513 |
|
514 |
# Strip commas if the --dont-strip-commas flag is not set
|
515 |
if not args.dont_strip_commas:
|
|
|
527 |
f.write(caption)
|
528 |
print(f"Caption saved to {caption_file}")
|
529 |
|
530 |
+
|
531 |
+
# Matches a trailing parenthesised qualifier such as "_(artist)" or "_(species)".
RE_PARENS_SUFFIX = re.compile(r"_\([^)]+\)$")
# Directory holding the e6db data files, located next to this script.
E6DB_DATA = Path(__file__).resolve().parent / "data"


def make_tagset_normalizer():
    """
    Create a TagSetNormalizer backed by the data in E6DB_DATA and register
    extra input spellings for its tags.

    The extra spellings are produced per tag by `input_map` below and are
    registered through `map_inputs(..., on_conflict="ignore")`; the value
    returned by that call is what this function returns.
    """
    # Per the original author's note, this loads all the aliases and implications
    normalizer = TagSetNormalizer(E6DB_DATA)

    category_of = normalizer.tag_normalizer.tag_categories
    artist_category = tag_category2id["artist"]
    suffix_for_category = {
        tag_category2id["character"]: "_(character)",
        tag_category2id["lore"]: "_(lore)",
        tag_category2id["species"]: "_(species)",
        tag_category2id["copyright"]: "_(copyright)",
    }

    def input_map(tag, tid):
        """Yield alternative spellings for `tag` using simple rewrite rules."""
        # Spelling without the trailing "_(...)" qualifier. It may collide
        # with another tag; collisions are governed by the `on_conflict`
        # argument passed to `map_inputs` below.
        bare = RE_PARENS_SUFFIX.sub("", tag)
        has_qualifier = bare != tag
        if has_qualifier:
            yield bare

        # Tags without an id get a category that matches nothing below
        category = -1 if tid is None else category_of[tid]
        if category == artist_category:
            # Artists: offer both the "by_" prefixed and unprefixed forms
            artist = bare.removeprefix("by_")
            if artist == bare:
                yield f"by_{artist}"
                if not has_qualifier:
                    yield f"by_{artist}_(artist)"
            else:
                yield artist
                if not has_qualifier:
                    yield f"{artist}_(artist)"
        elif not has_qualifier:
            # Other categories: offer the spelling with the category qualifier
            qualifier = suffix_for_category.get(category)
            if qualifier is not None:
                yield f"{bare}{qualifier}"

        # Also accept the tag with every ':' written as '_' (e.g. aspect ratios)
        if ":" in tag:
            yield tag.replace(":", "_")

    return normalizer.map_inputs(input_map, on_conflict="ignore")
|
582 |
+
|
583 |
+
|
584 |
+
|
585 |
+
def format_nl_list(l):
    """
    Join strings into a natural-language enumeration.

    One item is returned unchanged, two items are joined with " and ", and
    three or more are comma separated with ", and " before the last one
    (e.g. "a, b, and c").

    Args:
        l: A non-empty iterable of strings.

    Returns:
        The formatted enumeration.

    Raises:
        ValueError: If `l` contains no items.
    """
    # Materialize once so tuples and generators work as well as lists
    items = list(l)
    if not items:
        # Explicit check instead of `assert`, which is stripped under `python -O`
        raise ValueError("format_nl_list() requires at least one item")
    if len(items) == 1:
        return items[0]
    if len(items) == 2:
        return f"{items[0]} and {items[1]}"
    *head, last = items
    return ', '.join(head) + ', and ' + last
|
595 |
+
|
596 |
+
# Category ids used to route tags into the sentence pieces of the prompt
TAG_SPECIES = tag_category2id['species']
TAG_CHARACTER = tag_category2id['character']
TAG_ARTIST = tag_category2id['artist']
TAG_COPYRIGHT = tag_category2id['copyright']
TAG_META = tag_category2id['meta']
# Known tags whose estimated frequency is below this value are dropped
TAG_FREQ_THRESH = 0


def prompt_from_tags(args, image_path: Path, tagset_normalizer: TagSetNormalizer) -> str | None:
    """
    Build a captioning prompt from the tag file that accompanies an image.

    Tags are read from the comma separated file returned by `find_tag_file`.
    Artist, character, copyright and species tags are woven into the opening
    sentence (at most four of each); every other tag is listed as a context
    clue. Meta tags and tags implied by another tag of the file are left out.

    Args:
        args: Parsed command line arguments; `random_tags` and
            `feed_from_tags` are read to limit the number of context tags.
        image_path: Path of the image being captioned.
        tagset_normalizer: Normalizer used to encode tags and to look up
            their categories and implications.

    Returns:
        The prompt string, or None when no tag file exists for the image.
    """
    tag_file = find_tag_file(image_path)
    if tag_file is None:
        return None

    with open(tag_file, 'r', encoding='utf-8') as f:
        tags = f.read().lower().split(',')

    tag_id_to_cat_id = tagset_normalizer.tag_normalizer.tag_categories
    encode = tagset_normalizer.tag_normalizer.encode

    # These lists contain tuples (freq, tag, tag_id)
    tag_by_category = {cat: [] for cat in [TAG_ARTIST, TAG_CHARACTER, TAG_COPYRIGHT, TAG_SPECIES]}
    other_tags = []
    implied = set()
    for tag in tags:
        tag = tag.strip()
        if not tag:
            # Trailing commas or an empty file produce empty fragments
            continue
        # Encode the tag into a numerical id
        tag_id = encode(tag.replace(' ', '_'))
        if tag_id is None:
            # Unknown tags are kept as plain context clues with frequency 0
            other_tags.append((0, tag, None))
            # NOTE(review): tag_id is always None here, so this lookup can only
            # match a None key. Presumably `implications_rej` was meant to be
            # queried with the tag string; confirm against e6db_reader.
            implied.update(tagset_normalizer.implications_rej.get(tag_id, ()))
            continue
        # Get the category of the tag
        cat_id = tag_id_to_cat_id[tag_id]
        # Skip meta tags
        if cat_id == TAG_META:
            continue
        implied.update(tagset_normalizer.implications.get(tag_id, ()))
        # Get the frequency of the tag (the tag id is passed as its rank)
        freq = tag_rank_to_freq(tag_id)
        if freq < TAG_FREQ_THRESH:
            continue
        tag_by_category.get(cat_id, other_tags).append((freq, tag, tag_id))

    # Drop tags implied by another tag and order the rest by ascending frequency
    other_tags = sorted((freq, tag) for freq, tag, tag_id in other_tags if tag_id not in implied)
    for cat_id, cat_list in tag_by_category.items():
        tag_by_category[cat_id] = sorted((freq, tag) for freq, tag, tag_id in cat_list if tag_id not in implied)

    if args.random_tags is not None:
        # Randomly select tags if --random-tags is specified. The sample is
        # drawn from the filtered (freq, tag) tuples, not from the raw tag
        # strings, so filtered-out tags stay out and the tuples can still be
        # unpacked when the tag string is built below.
        pool = other_tags[:round(args.random_tags * 1.5)]
        num_tags = max(0, min(args.random_tags, len(pool)))
        other_tags = random.sample(pool, num_tags)
    elif args.feed_from_tags > 0:
        # Use specified number of tags if --feed-from-tags has a positive value
        other_tags = other_tags[:args.feed_from_tags]

    # Prepare sentence pieces
    artist_tag = tag_by_category[TAG_ARTIST]
    if artist_tag:
        # Built outside the f-string: reusing the quote type inside an
        # f-string only parses on Python 3.12+
        artist_names = [tag.removeprefix('by ') for _, tag in artist_tag[:4]]
        artist_txt = f' by {format_nl_list(artist_names)}'
    else:
        artist_txt = ''

    character_tag = tag_by_category[TAG_CHARACTER]
    if character_tag:
        character_names = [tag for _, tag in character_tag[:4]]
        character_txt = f' named {format_nl_list(character_names)}'
    else:
        character_txt = ''

    species_tag = tag_by_category[TAG_SPECIES]
    if species_tag:
        # Use the article only for a single species and at most one character
        article = ' a' if len(character_tag) <= 1 and len(species_tag) <= 1 else ''
        species_names = [tag for _, tag in species_tag[:4]]
        species_txt = f' of{article} {format_nl_list(species_names)}'
    elif character_tag:
        species_txt = ' of a character' if len(character_tag) <= 1 else ' of characters'
    else:
        species_txt = ''

    copyright_tag = tag_by_category[TAG_COPYRIGHT]
    if copyright_tag:
        copyright_names = [tag for _, tag in copyright_tag[:4]]
        copyright_txt = f' from {format_nl_list(copyright_names)}'
    else:
        copyright_txt = ''

    tag_string = ', '.join(tag for _, tag in other_tags)
    custom_prompt = f"Write a descriptive caption for this image{artist_txt}{species_txt}{character_txt}{copyright_txt} in a formal tone. Use these tags as context clues to construct your caption: {tag_string}"
    return custom_prompt
|
678 |
+
|
679 |
def find_tag_file(image_path):
|
680 |
"""
|
681 |
Find the corresponding .txt file for the given image path.
|