joy: some fixes and niceties
joy CHANGED
@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-

 """
-JoyCaption Alpha
+JoyCaption Alpha Two

 This module provides functionality for generating captions for images using a
 combination of CLIP, LLM, and custom image adapters. It supports various
@@ -34,52 +34,53 @@ from transformers import (
 )
 from torch import nn
 from e6db_reader import TagSetNormalizer, tag_category2id, tag_rank_to_freq
+import logging

 CLIP_PATH = "google/siglip-so400m-patch14-384"
 MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B"
 CHECKPOINT_PATH = Path(__file__).resolve().parent / "cgrkzexw-599808"
 CAPTION_TYPE_MAP = {
-    "
+    "descriptive": [
         "Write a descriptive caption for this image in a formal tone.",
         "Write a descriptive caption for this image in a formal tone within {word_count} words.",
         "Write a {length} descriptive caption for this image in a formal tone.",
     ],
-    "
+    "descriptive (informal)": [
         "Write a descriptive caption for this image in a casual tone.",
         "Write a descriptive caption for this image in a casual tone within {word_count} words.",
         "Write a {length} descriptive caption for this image in a casual tone.",
     ],
-    "
+    "training prompt": [
         "Write a stable diffusion prompt for this image.",
         "Write a stable diffusion prompt for this image within {word_count} words.",
         "Write a {length} stable diffusion prompt for this image.",
     ],
-    "
+    "midjourney": [
         "Write a MidJourney prompt for this image.",
         "Write a MidJourney prompt for this image within {word_count} words.",
         "Write a {length} MidJourney prompt for this image.",
     ],
-    "
+    "booru tag list": [
         "Write a list of Booru tags for this image.",
         "Write a list of Booru tags for this image within {word_count} words.",
         "Write a {length} list of Booru tags for this image.",
     ],
-    "
+    "booru-like tag list": [
         "Write a list of Booru-like tags for this image.",
         "Write a list of Booru-like tags for this image within {word_count} words.",
         "Write a {length} list of Booru-like tags for this image.",
     ],
-    "
+    "art critic": [
         "Analyze this image like an art critic would with information about its composition, style, symbolism, the use of color, light, any artistic movement it might belong to, etc.",
         "Analyze this image like an art critic would with information about its composition, style, symbolism, the use of color, light, any artistic movement it might belong to, etc. Keep it within {word_count} words.",
         "Analyze this image like an art critic would with information about its composition, style, symbolism, the use of color, light, any artistic movement it might belong to, etc. Keep it {length}.",
     ],
-    "
+    "product listing": [
         "Write a caption for this image as though it were a product listing.",
         "Write a caption for this image as though it were a product listing. Keep it under {word_count} words.",
         "Write a {length} caption for this image as though it were a product listing.",
     ],
-    "
+    "social media post": [
         "Write a caption for this image as if it were being used for a social media post.",
         "Write a caption for this image as if it were being used for a social media post. Limit the caption to {word_count} words.",
         "Write a {length} caption for this image as if it were being used for a social media post.",
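
The map is now keyed by plain caption-type strings, with each value holding three templates: no length, a {word_count} variant, and a {length} variant. A minimal lookup sketch, assuming the templates keep the order shown above:

    templates = CAPTION_TYPE_MAP["descriptive"]
    prompt = templates[1].format(word_count=60)
    # "Write a descriptive caption for this image in a formal tone within 60 words."
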
@@ -208,202 +209,11 @@ class ImageAdapter(nn.Module):
         ).squeeze(0)


-STOP_WORDS: set[str] = {
-    "the", "or", "but", "in", "on", "at", "to", "for", "of", "with", "by",
-    "from", "up", "down", "is", "are", "was", "were", "be", "been", "being",
-    "have", "has", "had", "do", "does", "did", "will", "would", "shall",
-    "should", "can", "could", "may", "might", "must", "ought", "i", "you",
-    "he", "she", "it", "we", "they", "them", "their", "this", "that",
-    "these", "those", "am", "is", "are", "was", "were", "be", "been",
-    "being", "have", "has", "had", "do", "does", "did", "will", "would",
-    "shall", "should", "can", "could", "may", "might", "must", "ought",
-    "i'm", "you're", "he's", "she's", "it's", "we're", "they're", "i've",
-    "you've", "we've", "they've", "i'd", "you'd", "he'd", "she'd", "we'd",
-    "they'd", "i'll", "you'll", "he'll", "she'll", "we'll", "they'll",
-    "isn't", "aren't", "wasn't", "weren't", "hasn't", "haven't", "hadn't",
-    "doesn't", "don't", "didn't", "won't", "wouldn't", "shan't",
-    "shouldn't", "can't", "cannot", "couldn't", "mustn't", "let's",
-    "that's", "who's", "what's", "here's", "there's", "when's", "where's",
-    "why's", "how's", "a", "an", "the", "and", "but", "if", "or",
-    "because", "as", "until", "while", "of", "at", "by", "for", "with",
-    "about", "against", "between", "into", "through", "during", "before",
-    "after", "above", "below", "to", "from", "up", "down", "in", "out",
-    "on", "off", "over", "under", "again", "further", "then", "once",
-    "here", "there", "when", "where", "why", "how", "all", "any", "both",
-    "each", "few", "more", "most", "other", "some", "such", "no", "nor",
-    "not", "only", "own", "same", "so", "than", "too", "very",
-}
+STOP_WORDS: set[str] = set(
+    "i'll if we'd can't you'd shouldn't i'd only doesn't further isn't didn't has more aren't during do than were he's too here you against could few for ought won't we until weren't i've they're same up she but are how here's their over can under mustn't while on by had and an each he'd he about she'd am was she'll where's did out or that's it they'd a let's shall what's the to don't when below no any some from is hadn't all they i'm must in before who's own where you've that very them this not because it's shan't wasn't you'll when's most off i at other hasn't nor been such again we'll down above will so should into she's once have these why's be we've as being why those then with after may you're would haven't both wouldn't there cannot they've couldn't how's between does we're through he'll of there's they'll might".split(
+        " "
+    )
+)


 class JoyCaptionModel:
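
The new form builds the same kind of membership set from one space-separated string; an illustrative check of the idiom:

    words = set("the a an and or but".split(" "))
    assert "the" in words and "fox" not in words
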
@@ -440,12 +250,12 @@ class JoyCaptionModel:
         """
         Load and initialize all required models (CLIP, LLM, image adapter).
         """
-        print("Loading CLIP")
+        logging.info("Loading CLIP")
         self.clip_model = AutoModel.from_pretrained(CLIP_PATH)
         self.clip_model = self.clip_model.vision_model

         if (CHECKPOINT_PATH / "clip_model.pt").exists():
-            print("Loading VLM's custom vision model")
+            logging.info("Loading VLM's custom vision model")
             checkpoint = torch.load(
                 CHECKPOINT_PATH / "clip_model.pt", map_location="cpu"
             )
@@ -459,15 +269,15 @@ class JoyCaptionModel:
         self.clip_model.requires_grad_(False)
         self.clip_model.to("cuda")

-        print("Loading tokenizer")
+        logging.info("Loading tokenizer")
         self.tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False)
         assert isinstance(
             self.tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)
         )

-        print("Loading LLM")
+        logging.info("Loading LLM")
         if (CHECKPOINT_PATH / "text_model").exists():
-            print("Loading VLM's custom text model")
+            logging.info("Loading VLM's custom text model")
             self.text_model = AutoModelForCausalLM.from_pretrained(
                 CHECKPOINT_PATH / "text_model", device_map=0, torch_dtype=torch.bfloat16
             )
@@ -478,7 +288,7 @@ class JoyCaptionModel:

         self.text_model.eval()

-        print("Loading image adapter")
+        logging.info("Loading image adapter")
         self.image_adapter = ImageAdapter(
             self.clip_model.config.hidden_size,
             self.text_model.config.hidden_size,
@@ -497,9 +307,7 @@ class JoyCaptionModel:
     def process_image(
         self,
         input_image: Image.Image,
-        caption_type: str,
-        caption_length: str | int,
-        custom_prompt: str | None = None,
+        prompt_str: str,
     ) -> Tuple[str, float]:
         """
         Process an input image and generate a caption based on specified parameters.
@@ -509,12 +317,7 @@ class JoyCaptionModel:
             Tuple[str, float]: The generated caption and its entropy.
         """
         torch.cuda.empty_cache()
-
-        if custom_prompt is not None:
-            prompt_str = custom_prompt
-        else:
-            prompt_str = self._get_prompt_string(caption_type, caption_length)
-        print(f"Prompt: {prompt_str}")
+        logging.info(f"Prompt: {prompt_str}")

         pixel_values = self._preprocess_image(input_image)

@@ -535,9 +338,7 @@ class JoyCaptionModel:
     def generate_valid_caption(
         self,
         input_image: Image.Image,
-        caption_type: str,
-        caption_length: str | int,
-        custom_prompt: str | None = None,
+        prompt: str,
         *,
         limited_words: Dict[str, int] = {"fluffy": 2},
         min_sentence_count: int = 3,
@@ -550,9 +351,7 @@ class JoyCaptionModel:

         Args:
             input_image (Image.Image): The input image to caption.
-            caption_type (str): The type of caption to generate.
-            caption_length (str | int): The desired length of the caption.
-            custom_prompt (str | None): A custom prompt for caption generation.
+            prompt (str | None): Prompt for caption generation.
             limited_words (Dict[str, int]): Dictionary of words with their maximum allowed occurrences. Default is {"fluffy": 1}.
             min_sentence_count (int): Minimum required number of sentences. Default is 3.
             max_word_repetitions (int): Maximum allowed repetitions for words longer than 4 characters. Default is 15.
@@ -570,9 +369,7 @@ class JoyCaptionModel:
            - The entropy of the caption is below min_entropy
         """
         while True:
-            caption, entropy = self.process_image(
-                input_image, caption_type, caption_length, custom_prompt
-            )
+            caption, entropy = self.process_image(input_image, prompt)
             words = re.findall(r"\b\w+\b", caption.lower())
             word_counts = {
                 word: words.count(word) for word in set(words) if word not in stop_words
@@ -580,11 +377,11 @@ class JoyCaptionModel:
             sentence_count = len(re.findall(r"[.!?]", caption))

             if not re.search(r"\w", caption):
-                print(
+                logging.info(
                     f"Retrying: Caption contains only special characters.\nCaption: {caption!r}"
                 )
             elif caption[-1] not in {".", "!", "?"}:
-                print(
+                logging.info(
                     f"Retrying: Caption does not end with proper punctuation.\nCaption: {caption!r}"
                 )
             elif any(
@@ -596,7 +393,7 @@ class JoyCaptionModel:
                     for word, max_count in limited_words.items()
                     if caption.lower().count(word) > max_count
                 ]
-                print(
+                logging.info(
                     f"Retrying: Limited words exceeded: {', '.join(exceeded_words)}.\nCaption: {caption!r}"
                 )
             elif any(
@@ -609,21 +406,22 @@ class JoyCaptionModel:
                     for word, count in word_counts.items()
                     if count > max_word_repetitions and len(word) > 4
                 ]
-                print(
+                logging.info(
                     f"Retrying: Words repeated more than {max_word_repetitions} times: {', '.join(repeated_words)}.\nCaption: {caption!r}"
                 )
             elif sentence_count < min_sentence_count:
-                print(
+                logging.info(
                     f"Retrying: Only {sentence_count} sentences (min: {min_sentence_count}).\nCaption: {caption!r}"
                 )
             elif entropy < min_entropy:
-                print(
+                logging.info(
                     f"Retrying: Low entropy ({entropy:.2f} < {min_entropy}).\nCaption: {caption!r}"
                 )
             else:
                 return caption

-    def _get_prompt_string(self, caption_type, caption_length):
+    @staticmethod
+    def get_prompt_string(caption_type, caption_length):
         length = None if caption_length == "any" else caption_length

         if isinstance(length, str):
@@ -642,13 +440,16 @@ class JoyCaptionModel:
         else:
             raise ValueError(f"Invalid caption length: {length}")

+        caption_type = caption_type.lower()
         if caption_type not in CAPTION_TYPE_MAP:
             raise ValueError(f"Invalid caption type: {caption_type}")

         prompt_str = CAPTION_TYPE_MAP[caption_type][map_idx]
+        prompt_str = prompt_str.format(length=caption_length, word_count=caption_length)
         return prompt_str

-    def _preprocess_image(self, input_image: Image.Image) -> torch.Tensor:
+    @staticmethod
+    def _preprocess_image(input_image: Image.Image) -> torch.Tensor:
         """
         Preprocess the input image for the CLIP model.

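
Because the helper now formats the template itself, its result can be handed straight to process_image. A small sketch, assuming map_idx selects the plain template when no length is requested:

    prompt = JoyCaptionModel.get_prompt_string("descriptive", "any")
    # "Write a descriptive caption for this image in a formal tone."
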
@@ -703,6 +504,7 @@ class JoyCaptionModel:
         convo_string = self.tokenizer.apply_chat_template(
             convo, tokenize=False, add_generation_prompt=True
         )
+        logging.debug(f"Convo:\n{convo_string}")
         convo_tokens = self.tokenizer.encode(
             convo_string,
             return_tensors="pt",
@@ -756,7 +558,7 @@ class JoyCaptionModel:
             input_ids,
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
-            max_new_tokens=
+            max_new_tokens=512,
             do_sample=True,
             suppress_tokens=None,
             repetition_penalty=1.2,
@@ -800,6 +602,38 @@ class JoyCaptionModel:
         return entropy


+class ColoredFormatter(logging.Formatter):
+    COLORS = {
+        "DEBUG": "\033[36m",  # Cyan
+        "INFO": "\033[32m",  # Green
+        "WARNING": "\033[33m",  # Yellow
+        "ERROR": "\033[31m",  # Red
+        "CRITICAL": "\033[31;1m",  # Bright Red
+    }
+    RESET = "\033[0m"
+
+    def format(self, record):
+        log_message = super().format(record)
+        return f"{self.COLORS.get(record.levelname, '')}{log_message}{self.RESET}"
+
+
+def setup_logging(verbosity):
+    if verbosity == 0:
+        log_level = logging.INFO
+    elif verbosity == 1:
+        log_level = logging.DEBUG
+
+    handler = logging.StreamHandler()
+    formatter = ColoredFormatter(
+        fmt="%(asctime)s | %(levelname)-8s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
+    )
+    handler.setFormatter(formatter)
+
+    logger = logging.getLogger()
+    logger.setLevel(log_level)
+    logger.addHandler(handler)
+
+
 def main():
     """
     Generate captions for images in a directory
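
Note that setup_logging only assigns log_level for verbosity counts of 0 and 1, so a count of 2 or more (e.g. -vv) would hit an unbound local. A more forgiving variant, offered only as a sketch rather than what the commit does:

    def setup_logging(verbosity: int) -> None:
        # Any count >= 1 maps to DEBUG so repeated -v flags stay valid.
        log_level = logging.DEBUG if verbosity >= 1 else logging.INFO
        handler = logging.StreamHandler()
        handler.setFormatter(ColoredFormatter("%(asctime)s | %(levelname)-8s | %(message)s"))
        root = logging.getLogger()
        root.setLevel(log_level)
        root.addHandler(handler)
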
@@ -818,7 +652,7 @@ def main():
         "--caption_type",
         type=str,
         default="descriptive",
-        choices=
+        choices=CAPTION_TYPE_MAP.keys(),
         help="Type of caption to generate.",
     )
     parser.add_argument(
@@ -858,25 +692,37 @@ def main():
             "Only works if --feed-from-tags is enabled."
         ),
     )
+    parser.add_argument(
+        "--dry-run",
+        action="store_true",
+        help="Run in dry-run mode without loading models or generating captions.",
+    )
+    parser.add_argument(
+        "-v",
+        "--verbose",
+        action="count",
+        default=0,
+        help="Increase output verbosity (can be repeated)",
+    )

     args = parser.parse_args()

+    setup_logging(args.verbose)
+
     # Validate random-tags usage
     if args.random_tags is not None and args.feed_from_tags is None:
         parser.error("--random-tags can only be used when --feed-from-tags is enabled")

-
-
+    if args.feed_from_tags is not None:
+        logging.info("Loading e621 tag data")
+        tagset_normalizer = make_tagset_normalizer()

-    # Initialize and load models
-    joy_caption_model = JoyCaptionModel()
-    joy_caption_model.load_models()
-
-
-    if args.caption_type == "custom" and not args.custom_prompt:
-        parser.error("--custom_prompt is required when using --caption_type custom")
-    elif args.caption_type != "custom" and args.custom_prompt:
-        parser.error("--custom_prompt can only be used with --caption_type custom")
+    # Initialize and load models only if not in dry-run mode
+    if not args.dry_run:
+        joy_caption_model = JoyCaptionModel()
+        joy_caption_model.load_models()
+    else:
+        logging.info("Running in dry-run mode. Models will not be loaded.")

     image_extensions = {".webp", ".png", ".jpeg", ".jpg", ".jxl"}
     for image_path in Path(args.directory).rglob("*"):
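
With the new flags a run can be previewed without loading any model; an example invocation (directory path and tag count invented for illustration, assuming the script is run as joy):

    python joy /path/to/images --caption_type "booru tag list" --feed-from-tags 15 --dry-run -v
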
@@ -885,31 +731,27 @@ def main():

         # Skip if the caption file already exists
         if caption_file.exists():
-            print(f"Skipping {image_path}: Caption file already exists.")
+            logging.info(f"Skipping {image_path}: Caption file already exists.")
             continue

-        input_image = Image.open(image_path).convert("RGB")
+        if not args.dry_run:
+            input_image = Image.open(image_path).convert("RGB")

         # Use custom prompt if specified
-
-
-
-        elif args.feed_from_tags is not None:
-            base_prompt = joy_caption_model._get_prompt_string(
-                args.caption_type, args.caption_length
-            )
-            custom_prompt = prompt_from_tags(
-                args, image_path, tagset_normalizer, base_prompt
-            )
+        prompt = args.custom_prompt or JoyCaptionModel.get_prompt_string(
+            args.caption_type, args.caption_length
+        )

-
+        if args.feed_from_tags is not None:
+            prompt = prompt_from_tags(args, image_path, tagset_normalizer, prompt)

-
-
-
-
-
-
+        if args.dry_run:
+            logging.info(
+                f"Dry run: Skipping caption generation for {image_path} with prompt:\n\t{prompt}"
+            )
+            continue
+
+        caption = joy_caption_model.generate_valid_caption(input_image, prompt)

         # Strip commas if the --dont-strip-commas flag is not set
         if not args.dont_strip_commas:
@@ -923,12 +765,12 @@ def main():
         # Remove all newline characters
         caption = caption.replace("\n", " ")

-        print(f"Caption for {image_path}:\n\t{caption}\n\n")
+        logging.info(f"Caption for {image_path}:\n\t{caption}\n\n")

         # Save the caption to a .caption file
         with open(caption_file, "w", encoding="utf-8") as f:
             f.write(caption)
-        print(f"Caption saved to {caption_file}")
+        logging.info(f"Caption saved to {caption_file}")


 RE_PARENS_SUFFIX = re.compile(r"_\([^)]+\)$")
@@ -1005,7 +847,6 @@ TAG_CHARACTER = tag_category2id["character"]
 TAG_ARTIST = tag_category2id["artist"]
 TAG_COPYRIGHT = tag_category2id["copyright"]
 TAG_META = tag_category2id["meta"]
-TAG_FREQ_THRESH = 0


 def prompt_from_tags(
@@ -1013,6 +854,8 @@ def prompt_from_tags(
     image_path: Path,
     tagset_normalizer: TagSetNormalizer,
     base_prompt: str = "Write a descriptive caption for this image in a formal tone.",
+    tag_freq_threshold: int = 0,
+    tag_string_prefix: str = "Use these tags to construct your caption:",
 ):
     """
     Generates a prompt from tags associated with the given image.
@@ -1023,31 +866,35 @@ def prompt_from_tags(
             The path to the image file.
         tagset_normalizer (TagSetNormalizer):
             An instance to normalize the tag set.
-
-    Returns:
-        None
     """
+    # Find and read the corresponding tag file
     tag_file = find_tag_file(image_path)
     if tag_file is None:
-
+        logging.warning(f"No tag file found for {image_path}")
+        return base_prompt

     with open(tag_file, "r", encoding="utf-8") as f:
         tags = f.read().lower().split(",")

+    # Get helper functions from the tagset_normalizer
     tag_id_to_cat_id = tagset_normalizer.tag_normalizer.tag_categories
     encode = tagset_normalizer.tag_normalizer.encode

-    #
+    # Initialize dictionaries and lists to store categorized tags
+    # These lists will contain tuples (freq, tag, tag_id)
     tag_by_category: Dict[int, List[Tuple[int, str, int]]] = {
         cat: [] for cat in [TAG_ARTIST, TAG_CHARACTER, TAG_COPYRIGHT, TAG_SPECIES]
     }
     other_tags: List[Tuple[int, str, int]] = []
     implied: set = set()
+
+    # Process each tag
     for tag in tags:
         tag = tag.strip()
         # Encode the tag into a numerical id
         tag_id = encode(tag.replace(" ", "_"))
         if tag_id is None:
+            # If tag is not recognized, add it to other_tags
             other_tags.append((0, tag, 0))
             implied.update(tagset_normalizer.implications_rej.get(0, ()))
             continue
@@ -1056,26 +903,29 @@ def prompt_from_tags(
         # Skip meta tags
         if cat_id == TAG_META:
             continue
+        # Update implied tags
         implied.update(tagset_normalizer.implications.get(tag_id, ()))
         # Get the frequency of the tag
         freq = tag_rank_to_freq(tag_id)
-        if freq < TAG_FREQ_THRESH:
+        if freq < tag_freq_threshold:
             continue
+        # Add the tag to its category, or other_tags
         tag_by_category.get(cat_id, other_tags).append((int(freq), tag, tag_id))

+    # Sort other_tags by frequency (descending) and filter out implied tags
     other_tags = sorted(
-        (
+        (-freq, tag, tag_id)
         for freq, tag, tag_id in other_tags
         if tag_id not in implied
     )

+    # Sort tags within each category, preferring non implied tags
     for cat_id, cat_list in tag_by_category.items():
         tag_by_category[cat_id] = sorted(
-            (
-            for freq, tag, tag_id in cat_list
-            if tag_id not in implied
+            ((tag_id in implied, -freq), tag, tag_id) for freq, tag, tag_id in cat_list
         )

+    # Handle random tag selection or tag limit if specified
     if args.random_tags is not None:
         # Randomly select tags if --random-tags is specified
         num_tags = min(args.random_tags, len(other_tags))
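
The per-category sort key (tag_id in implied, -freq) puts non-implied tags first (False sorts before True) and higher-frequency tags earlier within each group; an illustrative run with made-up frequencies and ids:

    cat_list = [(120, "fox", 11), (500, "mammal", 12)]
    implied = {12}
    ordered = sorted(((tid in implied, -freq), tag, tid) for freq, tag, tid in cat_list)
    # [((False, -120), 'fox', 11), ((True, -500), 'mammal', 12)]
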
@@ -1090,7 +940,7 @@ def prompt_from_tags(
         # Use specified number of tags if --feed-from-tags has a positive value
         other_tags = other_tags[: args.feed_from_tags]

-    # Prepare sentence pieces
+    # Prepare sentence pieces for each category
     artist_tag = tag_by_category[TAG_ARTIST]
     if artist_tag:
         artist_list = [str(tp[1]).removeprefix("by ") for tp in artist_tag[:4]]
|
|
1113 |
species_txt += format_nl_list([tp[1] for tp in species_tag[:4]])
|
1114 |
else:
|
1115 |
if character_tag:
|
1116 |
-
species_txt =
|
|
|
|
|
1117 |
else:
|
1118 |
species_txt = ""
|
1119 |
|
@@ -1123,8 +975,11 @@ def prompt_from_tags(
         copyright_txt = f"from {format_nl_list(tags)}"
     else:
         copyright_txt = ""
+
+    # Prepare the remaining tags as a string
     tag_string = ", ".join(tp[1] for tp in other_tags)

+    # Extract the prefix and suffix around the word "image" from the base prompt
     image_pos = base_prompt.find("image")
     if image_pos < 0:
         raise ValueError("Base prompt must contain the word 'image'")
@@ -1132,20 +987,19 @@ def prompt_from_tags(
     base_prompt_prefix = base_prompt[:image_pos].rstrip()
     base_prompt_suffix = base_prompt[image_pos:].lstrip()

-
-
-
-
-
-
-
-
-
-
-
-
-
-    )
+    pieces = [
+        base_prompt_prefix,
+        artist_txt,
+        species_txt,
+        character_txt,
+        copyright_txt,
+        base_prompt_suffix,
+        tag_string_prefix,
+        tag_string,
+    ]
+    logging.debug("Prompt pieces: %r", pieces)
+    custom_prompt = " ".join(p for p in pieces if p)
+    custom_prompt = custom_prompt.replace(" .", ".").replace(" ,", ",")
     return custom_prompt

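
The final prompt is the non-empty pieces joined with spaces, with stray spaces before punctuation collapsed; roughly (piece values invented for illustration):

    pieces = ["Write a descriptive caption for this", "by some_artist", "of a fox", "",
              "image in a formal tone.", "Use these tags to construct your caption:", "red fur, forest"]
    prompt = " ".join(p for p in pieces if p).replace(" .", ".").replace(" ,", ",")
    # "Write a descriptive caption for this by some_artist of a fox image in a formal tone.
    #  Use these tags to construct your caption: red fur, forest"
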