k4d3 committed on
Commit
913d039
·
1 Parent(s): 3461363

joy: some fixes and niceties

Browse files
Files changed (1)
  1. joy +153 -299
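For context, a possible invocation exercising the options this commit adds (--dry-run and -v); the directory path and flag values below are illustrative only, and the script is assumed to be run directly with Python:

python joy /path/to/images --caption_type descriptive --feed-from-tags 15 --dry-run -v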
joy CHANGED
@@ -2,7 +2,7 @@
2
  # -*- coding: utf-8 -*-
3
 
4
  """
5
- JoyCaption Alpha One
6
 
7
  This module provides functionality for generating captions for images using a
8
  combination of CLIP, LLM, and custom image adapters. It supports various
@@ -34,52 +34,53 @@ from transformers import (
34
  )
35
  from torch import nn
36
  from e6db_reader import TagSetNormalizer, tag_category2id, tag_rank_to_freq
 
37
 
38
  CLIP_PATH = "google/siglip-so400m-patch14-384"
39
  MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B"
40
  CHECKPOINT_PATH = Path(__file__).resolve().parent / "cgrkzexw-599808"
41
  CAPTION_TYPE_MAP = {
42
- "Descriptive": [
43
  "Write a descriptive caption for this image in a formal tone.",
44
  "Write a descriptive caption for this image in a formal tone within {word_count} words.",
45
  "Write a {length} descriptive caption for this image in a formal tone.",
46
  ],
47
- "Descriptive (Informal)": [
48
  "Write a descriptive caption for this image in a casual tone.",
49
  "Write a descriptive caption for this image in a casual tone within {word_count} words.",
50
  "Write a {length} descriptive caption for this image in a casual tone.",
51
  ],
52
- "Training Prompt": [
53
  "Write a stable diffusion prompt for this image.",
54
  "Write a stable diffusion prompt for this image within {word_count} words.",
55
  "Write a {length} stable diffusion prompt for this image.",
56
  ],
57
- "MidJourney": [
58
  "Write a MidJourney prompt for this image.",
59
  "Write a MidJourney prompt for this image within {word_count} words.",
60
  "Write a {length} MidJourney prompt for this image.",
61
  ],
62
- "Booru tag list": [
63
  "Write a list of Booru tags for this image.",
64
  "Write a list of Booru tags for this image within {word_count} words.",
65
  "Write a {length} list of Booru tags for this image.",
66
  ],
67
- "Booru-like tag list": [
68
  "Write a list of Booru-like tags for this image.",
69
  "Write a list of Booru-like tags for this image within {word_count} words.",
70
  "Write a {length} list of Booru-like tags for this image.",
71
  ],
72
- "Art Critic": [
73
  "Analyze this image like an art critic would with information about its composition, style, symbolism, the use of color, light, any artistic movement it might belong to, etc.",
74
  "Analyze this image like an art critic would with information about its composition, style, symbolism, the use of color, light, any artistic movement it might belong to, etc. Keep it within {word_count} words.",
75
  "Analyze this image like an art critic would with information about its composition, style, symbolism, the use of color, light, any artistic movement it might belong to, etc. Keep it {length}.",
76
  ],
77
- "Product Listing": [
78
  "Write a caption for this image as though it were a product listing.",
79
  "Write a caption for this image as though it were a product listing. Keep it under {word_count} words.",
80
  "Write a {length} caption for this image as though it were a product listing.",
81
  ],
82
- "Social Media Post": [
83
  "Write a caption for this image as if it were being used for a social media post.",
84
  "Write a caption for this image as if it were being used for a social media post. Limit the caption to {word_count} words.",
85
  "Write a {length} caption for this image as if it were being used for a social media post.",
@@ -208,202 +209,11 @@ class ImageAdapter(nn.Module):
208
  ).squeeze(0)
209
 
210
 
211
- STOP_WORDS: set[str] = {
212
- "the",
213
- "a",
214
- "an",
215
- "and",
216
- "or",
217
- "but",
218
- "in",
219
- "on",
220
- "at",
221
- "to",
222
- "for",
223
- "of",
224
- "with",
225
- "by",
226
- "from",
227
- "up",
228
- "down",
229
- "is",
230
- "are",
231
- "was",
232
- "were",
233
- "be",
234
- "been",
235
- "being",
236
- "have",
237
- "has",
238
- "had",
239
- "do",
240
- "does",
241
- "did",
242
- "will",
243
- "would",
244
- "shall",
245
- "should",
246
- "can",
247
- "could",
248
- "may",
249
- "might",
250
- "must",
251
- "ought",
252
- "i",
253
- "you",
254
- "he",
255
- "she",
256
- "it",
257
- "we",
258
- "they",
259
- "them",
260
- "their",
261
- "this",
262
- "that",
263
- "these",
264
- "those",
265
- "am",
266
- "is",
267
- "are",
268
- "was",
269
- "were",
270
- "be",
271
- "been",
272
- "being",
273
- "have",
274
- "has",
275
- "had",
276
- "do",
277
- "does",
278
- "did",
279
- "will",
280
- "would",
281
- "shall",
282
- "should",
283
- "can",
284
- "could",
285
- "may",
286
- "might",
287
- "must",
288
- "ought",
289
- "i'm",
290
- "you're",
291
- "he's",
292
- "she's",
293
- "it's",
294
- "we're",
295
- "they're",
296
- "i've",
297
- "you've",
298
- "we've",
299
- "they've",
300
- "i'd",
301
- "you'd",
302
- "he'd",
303
- "she'd",
304
- "we'd",
305
- "they'd",
306
- "i'll",
307
- "you'll",
308
- "he'll",
309
- "she'll",
310
- "we'll",
311
- "they'll",
312
- "isn't",
313
- "aren't",
314
- "wasn't",
315
- "weren't",
316
- "hasn't",
317
- "haven't",
318
- "hadn't",
319
- "doesn't",
320
- "don't",
321
- "didn't",
322
- "won't",
323
- "wouldn't",
324
- "shan't",
325
- "shouldn't",
326
- "can't",
327
- "cannot",
328
- "couldn't",
329
- "mustn't",
330
- "let's",
331
- "that's",
332
- "who's",
333
- "what's",
334
- "here's",
335
- "there's",
336
- "when's",
337
- "where's",
338
- "why's",
339
- "how's",
340
- "a",
341
- "an",
342
- "the",
343
- "and",
344
- "but",
345
- "if",
346
- "or",
347
- "because",
348
- "as",
349
- "until",
350
- "while",
351
- "of",
352
- "at",
353
- "by",
354
- "for",
355
- "with",
356
- "about",
357
- "against",
358
- "between",
359
- "into",
360
- "through",
361
- "during",
362
- "before",
363
- "after",
364
- "above",
365
- "below",
366
- "to",
367
- "from",
368
- "up",
369
- "down",
370
- "in",
371
- "out",
372
- "on",
373
- "off",
374
- "over",
375
- "under",
376
- "again",
377
- "further",
378
- "then",
379
- "once",
380
- "here",
381
- "there",
382
- "when",
383
- "where",
384
- "why",
385
- "how",
386
- "all",
387
- "any",
388
- "both",
389
- "each",
390
- "few",
391
- "more",
392
- "most",
393
- "other",
394
- "some",
395
- "such",
396
- "no",
397
- "nor",
398
- "not",
399
- "only",
400
- "own",
401
- "same",
402
- "so",
403
- "than",
404
- "too",
405
- "very",
406
- }
407
 
408
 
409
  class JoyCaptionModel:
@@ -440,12 +250,12 @@ class JoyCaptionModel:
440
  """
441
  Load and initialize all required models (CLIP, LLM, image adapter).
442
  """
443
- print("Loading CLIP")
444
  self.clip_model = AutoModel.from_pretrained(CLIP_PATH)
445
  self.clip_model = self.clip_model.vision_model
446
 
447
  if (CHECKPOINT_PATH / "clip_model.pt").exists():
448
- print("Loading VLM's custom vision model")
449
  checkpoint = torch.load(
450
  CHECKPOINT_PATH / "clip_model.pt", map_location="cpu"
451
  )
@@ -459,15 +269,15 @@ class JoyCaptionModel:
459
  self.clip_model.requires_grad_(False)
460
  self.clip_model.to("cuda")
461
 
462
- print("Loading tokenizer")
463
  self.tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False)
464
  assert isinstance(
465
  self.tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)
466
  )
467
 
468
- print("Loading LLM")
469
  if (CHECKPOINT_PATH / "text_model").exists():
470
- print("Loading VLM's custom text model")
471
  self.text_model = AutoModelForCausalLM.from_pretrained(
472
  CHECKPOINT_PATH / "text_model", device_map=0, torch_dtype=torch.bfloat16
473
  )
@@ -478,7 +288,7 @@ class JoyCaptionModel:
478
 
479
  self.text_model.eval()
480
 
481
- print("Loading image adapter")
482
  self.image_adapter = ImageAdapter(
483
  self.clip_model.config.hidden_size,
484
  self.text_model.config.hidden_size,
@@ -497,9 +307,7 @@ class JoyCaptionModel:
497
  def process_image(
498
  self,
499
  input_image: Image.Image,
500
- caption_type: str,
501
- caption_length: str | int,
502
- custom_prompt: str | None = None,
503
  ) -> Tuple[str, float]:
504
  """
505
  Process an input image and generate a caption based on specified parameters.
@@ -509,12 +317,7 @@ class JoyCaptionModel:
509
  Tuple[str, float]: The generated caption and its entropy.
510
  """
511
  torch.cuda.empty_cache()
512
-
513
- if custom_prompt is not None:
514
- prompt_str = custom_prompt
515
- else:
516
- prompt_str = self._get_prompt_string(caption_type, caption_length)
517
- print(f"Prompt: {prompt_str}")
518
 
519
  pixel_values = self._preprocess_image(input_image)
520
 
@@ -535,9 +338,7 @@ class JoyCaptionModel:
535
  def generate_valid_caption(
536
  self,
537
  input_image: Image.Image,
538
- caption_type: str,
539
- caption_length: str | int,
540
- custom_prompt: str | None = None,
541
  *,
542
  limited_words: Dict[str, int] = {"fluffy": 2},
543
  min_sentence_count: int = 3,
@@ -550,9 +351,7 @@ class JoyCaptionModel:
550
 
551
  Args:
552
  input_image (Image.Image): The input image to caption.
553
- caption_type (str): The type of caption to generate.
554
- caption_length (str | int): The desired length of the caption.
555
- custom_prompt (str | None): A custom prompt for caption generation.
556
  limited_words (Dict[str, int]): Dictionary of words with their maximum allowed occurrences. Default is {"fluffy": 2}.
557
  min_sentence_count (int): Minimum required number of sentences. Default is 3.
558
  max_word_repetitions (int): Maximum allowed repetitions for words longer than 4 characters. Default is 15.
@@ -570,9 +369,7 @@ class JoyCaptionModel:
570
  - The entropy of the caption is below min_entropy
571
  """
572
  while True:
573
- caption, entropy = self.process_image(
574
- input_image, caption_type, caption_length, custom_prompt
575
- )
576
  words = re.findall(r"\b\w+\b", caption.lower())
577
  word_counts = {
578
  word: words.count(word) for word in set(words) if word not in stop_words
@@ -580,11 +377,11 @@ class JoyCaptionModel:
580
  sentence_count = len(re.findall(r"[.!?]", caption))
581
 
582
  if not re.search(r"\w", caption):
583
- print(
584
  f"Retrying: Caption contains only special characters.\nCaption: {caption!r}"
585
  )
586
  elif caption[-1] not in {".", "!", "?"}:
587
- print(
588
  f"Retrying: Caption does not end with proper punctuation.\nCaption: {caption!r}"
589
  )
590
  elif any(
@@ -596,7 +393,7 @@ class JoyCaptionModel:
596
  for word, max_count in limited_words.items()
597
  if caption.lower().count(word) > max_count
598
  ]
599
- print(
600
  f"Retrying: Limited words exceeded: {', '.join(exceeded_words)}.\nCaption: {caption!r}"
601
  )
602
  elif any(
@@ -609,21 +406,22 @@ class JoyCaptionModel:
609
  for word, count in word_counts.items()
610
  if count > max_word_repetitions and len(word) > 4
611
  ]
612
- print(
613
  f"Retrying: Words repeated more than {max_word_repetitions} times: {', '.join(repeated_words)}.\nCaption: {caption!r}"
614
  )
615
  elif sentence_count < min_sentence_count:
616
- print(
617
  f"Retrying: Only {sentence_count} sentences (min: {min_sentence_count}).\nCaption: {caption!r}"
618
  )
619
  elif entropy < min_entropy:
620
- print(
621
  f"Retrying: Low entropy ({entropy:.2f} < {min_entropy}).\nCaption: {caption!r}"
622
  )
623
  else:
624
  return caption
625
 
626
- def _get_prompt_string(self, caption_type, caption_length):
 
627
  length = None if caption_length == "any" else caption_length
628
 
629
  if isinstance(length, str):
@@ -642,13 +440,16 @@ class JoyCaptionModel:
642
  else:
643
  raise ValueError(f"Invalid caption length: {length}")
644
 
 
645
  if caption_type not in CAPTION_TYPE_MAP:
646
  raise ValueError(f"Invalid caption type: {caption_type}")
647
 
648
  prompt_str = CAPTION_TYPE_MAP[caption_type][map_idx]
 
649
  return prompt_str
650
 
651
- def _preprocess_image(self, input_image: Image.Image) -> torch.Tensor:
 
652
  """
653
  Preprocess the input image for the CLIP model.
654
 
@@ -703,6 +504,7 @@ class JoyCaptionModel:
703
  convo_string = self.tokenizer.apply_chat_template(
704
  convo, tokenize=False, add_generation_prompt=True
705
  )
 
706
  convo_tokens = self.tokenizer.encode(
707
  convo_string,
708
  return_tensors="pt",
@@ -756,7 +558,7 @@ class JoyCaptionModel:
756
  input_ids,
757
  inputs_embeds=inputs_embeds,
758
  attention_mask=attention_mask,
759
- max_new_tokens=300,
760
  do_sample=True,
761
  suppress_tokens=None,
762
  repetition_penalty=1.2,
@@ -800,6 +602,38 @@ class JoyCaptionModel:
800
  return entropy
801
 
802
 
803
  def main():
804
  """
805
  Generate captions for images in a directory
@@ -818,7 +652,7 @@ def main():
818
  "--caption_type",
819
  type=str,
820
  default="descriptive",
821
- choices=["descriptive", "training_prompt", "rng-tags", "custom"],
822
  help="Type of caption to generate.",
823
  )
824
  parser.add_argument(
@@ -858,25 +692,37 @@ def main():
858
  "Only works if --feed-from-tags is enabled."
859
  ),
860
  )
861
 
862
  args = parser.parse_args()
863
 
864
  # Validate random-tags usage
865
  if args.random_tags is not None and args.feed_from_tags is None:
866
  parser.error("--random-tags can only be used when --feed-from-tags is enabled")
867
 
868
- print("Loading e621 tag data")
869
- tagset_normalizer = make_tagset_normalizer()
 
870
 
871
- # Initialize and load models
872
- joy_caption_model = JoyCaptionModel()
873
- joy_caption_model.load_models()
874
-
875
- # Validate custom prompt usage
876
- if args.caption_type == "custom" and not args.custom_prompt:
877
- parser.error("--custom_prompt is required when using --caption_type custom")
878
- elif args.caption_type != "custom" and args.custom_prompt:
879
- parser.error("--custom_prompt can only be used with --caption_type custom")
880
 
881
  image_extensions = {".webp", ".png", ".jpeg", ".jpg", ".jxl"}
882
  for image_path in Path(args.directory).rglob("*"):
@@ -885,31 +731,27 @@ def main():
885
 
886
  # Skip if the caption file already exists
887
  if caption_file.exists():
888
- print(f"Skipping {image_path}: Caption file already exists.")
889
  continue
890
 
891
- input_image = Image.open(image_path).convert("RGB")
 
892
 
893
  # Use custom prompt if specified
894
- custom_prompt = None
895
- if args.caption_type == "custom":
896
- custom_prompt = args.custom_prompt
897
- elif args.feed_from_tags is not None:
898
- base_prompt = joy_caption_model._get_prompt_string(
899
- args.caption_type, args.caption_length
900
- )
901
- custom_prompt = prompt_from_tags(
902
- args, image_path, tagset_normalizer, base_prompt
903
- )
904
 
905
- print(f"\nCustom prompt: {custom_prompt}")
 
906
 
907
- caption = joy_caption_model.generate_valid_caption(
908
- input_image,
909
- args.caption_type,
910
- args.caption_length,
911
- custom_prompt=custom_prompt,
912
- )
 
913
 
914
  # Strip commas if the --dont-strip-commas flag is not set
915
  if not args.dont_strip_commas:
@@ -923,12 +765,12 @@ def main():
923
  # Remove all newline characters
924
  caption = caption.replace("\n", " ")
925
 
926
- print(f"Caption for {image_path}:\n\n{caption}\n\n")
927
 
928
  # Save the caption to a .caption file
929
  with open(caption_file, "w", encoding="utf-8") as f:
930
  f.write(caption)
931
- print(f"Caption saved to {caption_file}")
932
 
933
 
934
  RE_PARENS_SUFFIX = re.compile(r"_\([^)]+\)$")
@@ -1005,7 +847,6 @@ TAG_CHARACTER = tag_category2id["character"]
1005
  TAG_ARTIST = tag_category2id["artist"]
1006
  TAG_COPYRIGHT = tag_category2id["copyright"]
1007
  TAG_META = tag_category2id["meta"]
1008
- TAG_FREQ_THRESH = 0
1009
 
1010
 
1011
  def prompt_from_tags(
@@ -1013,6 +854,8 @@ def prompt_from_tags(
1013
  image_path: Path,
1014
  tagset_normalizer: TagSetNormalizer,
1015
  base_prompt: str = "Write a descriptive caption for this image in a formal tone.",
 
 
1016
  ):
1017
  """
1018
  Generates a prompt from tags associated with the given image.
@@ -1023,31 +866,35 @@ def prompt_from_tags(
1023
  The path to the image file.
1024
  tagset_normalizer (TagSetNormalizer):
1025
  An instance to normalize the tag set.
1026
-
1027
- Returns:
1028
- None
1029
  """
 
1030
  tag_file = find_tag_file(image_path)
1031
  if tag_file is None:
1032
- return None
 
1033
 
1034
  with open(tag_file, "r", encoding="utf-8") as f:
1035
  tags = f.read().lower().split(",")
1036
 
 
1037
  tag_id_to_cat_id = tagset_normalizer.tag_normalizer.tag_categories
1038
  encode = tagset_normalizer.tag_normalizer.encode
1039
 
1040
- # These lists contain tuples (freq, tag, tag_id)
 
1041
  tag_by_category: Dict[int, List[Tuple[int, str, int]]] = {
1042
  cat: [] for cat in [TAG_ARTIST, TAG_CHARACTER, TAG_COPYRIGHT, TAG_SPECIES]
1043
  }
1044
  other_tags: List[Tuple[int, str, int]] = []
1045
  implied: set = set()
 
 
1046
  for tag in tags:
1047
  tag = tag.strip()
1048
  # Encode the tag into a numerical id
1049
  tag_id = encode(tag.replace(" ", "_"))
1050
  if tag_id is None:
 
1051
  other_tags.append((0, tag, 0))
1052
  implied.update(tagset_normalizer.implications_rej.get(0, ()))
1053
  continue
@@ -1056,26 +903,29 @@ def prompt_from_tags(
1056
  # Skip meta tags
1057
  if cat_id == TAG_META:
1058
  continue
 
1059
  implied.update(tagset_normalizer.implications.get(tag_id, ()))
1060
  # Get the frequency of the tag
1061
  freq = tag_rank_to_freq(tag_id)
1062
- if freq < TAG_FREQ_THRESH:
1063
  continue
 
1064
  tag_by_category.get(cat_id, other_tags).append((int(freq), tag, tag_id))
1065
 
 
1066
  other_tags = sorted(
1067
- (int(freq), tag, tag_id)
1068
  for freq, tag, tag_id in other_tags
1069
  if tag_id not in implied
1070
  )
1071
 
 
1072
  for cat_id, cat_list in tag_by_category.items():
1073
  tag_by_category[cat_id] = sorted(
1074
- (int(freq), tag, tag_id)
1075
- for freq, tag, tag_id in cat_list
1076
- if tag_id not in implied
1077
  )
1078
 
 
1079
  if args.random_tags is not None:
1080
  # Randomly select tags if --random-tags is specified
1081
  num_tags = min(args.random_tags, len(other_tags))
@@ -1090,7 +940,7 @@ def prompt_from_tags(
1090
  # Use specified number of tags if --feed-from-tags has a positive value
1091
  other_tags = other_tags[: args.feed_from_tags]
1092
 
1093
- # Prepare sentence pieces
1094
  artist_tag = tag_by_category[TAG_ARTIST]
1095
  if artist_tag:
1096
  artist_list = [str(tp[1]).removeprefix("by ") for tp in artist_tag[:4]]
@@ -1113,7 +963,9 @@ def prompt_from_tags(
1113
  species_txt += format_nl_list([tp[1] for tp in species_tag[:4]])
1114
  else:
1115
  if character_tag:
1116
- species_txt = " a character" if len(character_tag) <= 1 else " characters"
 
 
1117
  else:
1118
  species_txt = ""
1119
 
@@ -1123,8 +975,11 @@ def prompt_from_tags(
1123
  copyright_txt = f"from {format_nl_list(tags)}"
1124
  else:
1125
  copyright_txt = ""
 
 
1126
  tag_string = ", ".join(tp[1] for tp in other_tags)
1127
 
 
1128
  image_pos = base_prompt.find("image")
1129
  if image_pos < 0:
1130
  raise ValueError("Base prompt must contain the word 'image'")
@@ -1132,20 +987,19 @@ def prompt_from_tags(
1132
  base_prompt_prefix = base_prompt[:image_pos].rstrip()
1133
  base_prompt_suffix = base_prompt[image_pos:].lstrip()
1134
 
1135
- custom_prompt = " ".join(
1136
- s
1137
- for s in [
1138
- base_prompt_prefix,
1139
- artist_txt,
1140
- species_txt,
1141
- character_txt,
1142
- copyright_txt,
1143
- base_prompt_suffix,
1144
- "Use these tags to construct your caption:",
1145
- tag_string,
1146
- ]
1147
- if s
1148
- )
1149
  return custom_prompt
1150
 
1151
 
 
2
  # -*- coding: utf-8 -*-
3
 
4
  """
5
+ JoyCaption Alpha Two
6
 
7
  This module provides functionality for generating captions for images using a
8
  combination of CLIP, LLM, and custom image adapters. It supports various
 
34
  )
35
  from torch import nn
36
  from e6db_reader import TagSetNormalizer, tag_category2id, tag_rank_to_freq
37
+ import logging
38
 
39
  CLIP_PATH = "google/siglip-so400m-patch14-384"
40
  MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B"
41
  CHECKPOINT_PATH = Path(__file__).resolve().parent / "cgrkzexw-599808"
42
  CAPTION_TYPE_MAP = {
43
+ "descriptive": [
44
  "Write a descriptive caption for this image in a formal tone.",
45
  "Write a descriptive caption for this image in a formal tone within {word_count} words.",
46
  "Write a {length} descriptive caption for this image in a formal tone.",
47
  ],
48
+ "descriptive (informal)": [
49
  "Write a descriptive caption for this image in a casual tone.",
50
  "Write a descriptive caption for this image in a casual tone within {word_count} words.",
51
  "Write a {length} descriptive caption for this image in a casual tone.",
52
  ],
53
+ "training prompt": [
54
  "Write a stable diffusion prompt for this image.",
55
  "Write a stable diffusion prompt for this image within {word_count} words.",
56
  "Write a {length} stable diffusion prompt for this image.",
57
  ],
58
+ "midjourney": [
59
  "Write a MidJourney prompt for this image.",
60
  "Write a MidJourney prompt for this image within {word_count} words.",
61
  "Write a {length} MidJourney prompt for this image.",
62
  ],
63
+ "booru tag list": [
64
  "Write a list of Booru tags for this image.",
65
  "Write a list of Booru tags for this image within {word_count} words.",
66
  "Write a {length} list of Booru tags for this image.",
67
  ],
68
+ "booru-like tag list": [
69
  "Write a list of Booru-like tags for this image.",
70
  "Write a list of Booru-like tags for this image within {word_count} words.",
71
  "Write a {length} list of Booru-like tags for this image.",
72
  ],
73
+ "art critic": [
74
  "Analyze this image like an art critic would with information about its composition, style, symbolism, the use of color, light, any artistic movement it might belong to, etc.",
75
  "Analyze this image like an art critic would with information about its composition, style, symbolism, the use of color, light, any artistic movement it might belong to, etc. Keep it within {word_count} words.",
76
  "Analyze this image like an art critic would with information about its composition, style, symbolism, the use of color, light, any artistic movement it might belong to, etc. Keep it {length}.",
77
  ],
78
+ "product listing": [
79
  "Write a caption for this image as though it were a product listing.",
80
  "Write a caption for this image as though it were a product listing. Keep it under {word_count} words.",
81
  "Write a {length} caption for this image as though it were a product listing.",
82
  ],
83
+ "social media post": [
84
  "Write a caption for this image as if it were being used for a social media post.",
85
  "Write a caption for this image as if it were being used for a social media post. Limit the caption to {word_count} words.",
86
  "Write a {length} caption for this image as if it were being used for a social media post.",
 
209
  ).squeeze(0)
210
 
211
 
212
+ STOP_WORDS: set[str] = set(
213
+ "i'll if we'd can't you'd shouldn't i'd only doesn't further isn't didn't has more aren't during do than were he's too here you against could few for ought won't we until weren't i've they're same up she but are how here's their over can under mustn't while on by had and an each he'd he about she'd am was she'll where's did out or that's it they'd a let's shall what's the to don't when below no any some from is hadn't all they i'm must in before who's own where you've that very them this not because it's shan't wasn't you'll when's most off i at other hasn't nor been such again we'll down above will so should into she's once have these why's be we've as being why those then with after may you're would haven't both wouldn't there cannot they've couldn't how's between does we're through he'll of there's they'll might".split(
214
+ " "
215
+ )
216
+ )
217
 
218
 
219
  class JoyCaptionModel:
 
250
  """
251
  Load and initialize all required models (CLIP, LLM, image adapter).
252
  """
253
+ logging.info("Loading CLIP")
254
  self.clip_model = AutoModel.from_pretrained(CLIP_PATH)
255
  self.clip_model = self.clip_model.vision_model
256
 
257
  if (CHECKPOINT_PATH / "clip_model.pt").exists():
258
+ logging.info("Loading VLM's custom vision model")
259
  checkpoint = torch.load(
260
  CHECKPOINT_PATH / "clip_model.pt", map_location="cpu"
261
  )
 
269
  self.clip_model.requires_grad_(False)
270
  self.clip_model.to("cuda")
271
 
272
+ logging.info("Loading tokenizer")
273
  self.tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False)
274
  assert isinstance(
275
  self.tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)
276
  )
277
 
278
+ logging.info("Loading LLM")
279
  if (CHECKPOINT_PATH / "text_model").exists():
280
+ logging.info("Loading VLM's custom text model")
281
  self.text_model = AutoModelForCausalLM.from_pretrained(
282
  CHECKPOINT_PATH / "text_model", device_map=0, torch_dtype=torch.bfloat16
283
  )
 
288
 
289
  self.text_model.eval()
290
 
291
+ logging.info("Loading image adapter")
292
  self.image_adapter = ImageAdapter(
293
  self.clip_model.config.hidden_size,
294
  self.text_model.config.hidden_size,
 
307
  def process_image(
308
  self,
309
  input_image: Image.Image,
310
+ prompt_str: str,
 
 
311
  ) -> Tuple[str, float]:
312
  """
313
  Process an input image and generate a caption based on specified parameters.
 
317
  Tuple[str, float]: The generated caption and its entropy.
318
  """
319
  torch.cuda.empty_cache()
320
+ logging.info(f"Prompt: {prompt_str}")
321
 
322
  pixel_values = self._preprocess_image(input_image)
323
 
 
338
  def generate_valid_caption(
339
  self,
340
  input_image: Image.Image,
341
+ prompt: str,
 
 
342
  *,
343
  limited_words: Dict[str, int] = {"fluffy": 2},
344
  min_sentence_count: int = 3,
 
351
 
352
  Args:
353
  input_image (Image.Image): The input image to caption.
354
+ prompt (str): The prompt to use for caption generation.
 
 
355
  limited_words (Dict[str, int]): Dictionary of words with their maximum allowed occurrences. Default is {"fluffy": 2}.
356
  min_sentence_count (int): Minimum required number of sentences. Default is 3.
357
  max_word_repetitions (int): Maximum allowed repetitions for words longer than 4 characters. Default is 15.
 
369
  - The entropy of the caption is below min_entropy
370
  """
371
  while True:
372
+ caption, entropy = self.process_image(input_image, prompt)
 
 
373
  words = re.findall(r"\b\w+\b", caption.lower())
374
  word_counts = {
375
  word: words.count(word) for word in set(words) if word not in stop_words
 
377
  sentence_count = len(re.findall(r"[.!?]", caption))
378
 
379
  if not re.search(r"\w", caption):
380
+ logging.info(
381
  f"Retrying: Caption contains only special characters.\nCaption: {caption!r}"
382
  )
383
  elif caption[-1] not in {".", "!", "?"}:
384
+ logging.info(
385
  f"Retrying: Caption does not end with proper punctuation.\nCaption: {caption!r}"
386
  )
387
  elif any(
 
393
  for word, max_count in limited_words.items()
394
  if caption.lower().count(word) > max_count
395
  ]
396
+ logging.info(
397
  f"Retrying: Limited words exceeded: {', '.join(exceeded_words)}.\nCaption: {caption!r}"
398
  )
399
  elif any(
 
406
  for word, count in word_counts.items()
407
  if count > max_word_repetitions and len(word) > 4
408
  ]
409
+ logging.info(
410
  f"Retrying: Words repeated more than {max_word_repetitions} times: {', '.join(repeated_words)}.\nCaption: {caption!r}"
411
  )
412
  elif sentence_count < min_sentence_count:
413
+ logging.info(
414
  f"Retrying: Only {sentence_count} sentences (min: {min_sentence_count}).\nCaption: {caption!r}"
415
  )
416
  elif entropy < min_entropy:
417
+ logging.info(
418
  f"Retrying: Low entropy ({entropy:.2f} < {min_entropy}).\nCaption: {caption!r}"
419
  )
420
  else:
421
  return caption
422
 
423
+ @staticmethod
424
+ def get_prompt_string(caption_type, caption_length):
425
  length = None if caption_length == "any" else caption_length
426
 
427
  if isinstance(length, str):
 
440
  else:
441
  raise ValueError(f"Invalid caption length: {length}")
442
 
443
+ caption_type = caption_type.lower()
444
  if caption_type not in CAPTION_TYPE_MAP:
445
  raise ValueError(f"Invalid caption type: {caption_type}")
446
 
447
  prompt_str = CAPTION_TYPE_MAP[caption_type][map_idx]
448
+ prompt_str = prompt_str.format(length=caption_length, word_count=caption_length)
449
  return prompt_str
450
 
451
+ @staticmethod
452
+ def _preprocess_image(input_image: Image.Image) -> torch.Tensor:
453
  """
454
  Preprocess the input image for the CLIP model.
455
 
 
504
  convo_string = self.tokenizer.apply_chat_template(
505
  convo, tokenize=False, add_generation_prompt=True
506
  )
507
+ logging.debug(f"Convo:\n{convo_string}")
508
  convo_tokens = self.tokenizer.encode(
509
  convo_string,
510
  return_tensors="pt",
 
558
  input_ids,
559
  inputs_embeds=inputs_embeds,
560
  attention_mask=attention_mask,
561
+ max_new_tokens=512,
562
  do_sample=True,
563
  suppress_tokens=None,
564
  repetition_penalty=1.2,
 
602
  return entropy
603
 
604
 
605
+ class ColoredFormatter(logging.Formatter):
606
+ COLORS = {
607
+ "DEBUG": "\033[36m", # Cyan
608
+ "INFO": "\033[32m", # Green
609
+ "WARNING": "\033[33m", # Yellow
610
+ "ERROR": "\033[31m", # Red
611
+ "CRITICAL": "\033[31;1m", # Bright Red
612
+ }
613
+ RESET = "\033[0m"
614
+
615
+ def format(self, record):
616
+ log_message = super().format(record)
617
+ return f"{self.COLORS.get(record.levelname, '')}{log_message}{self.RESET}"
618
+
619
+
620
+ def setup_logging(verbosity):
621
+ if verbosity == 0:
622
+ log_level = logging.INFO
623
+ else:
624
+ log_level = logging.DEBUG
625
+
626
+ handler = logging.StreamHandler()
627
+ formatter = ColoredFormatter(
628
+ fmt="%(asctime)s | %(levelname)-8s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
629
+ )
630
+ handler.setFormatter(formatter)
631
+
632
+ logger = logging.getLogger()
633
+ logger.setLevel(log_level)
634
+ logger.addHandler(handler)
635
+
636
+
637
  def main():
638
  """
639
  Generate captions for images in a directory
 
652
  "--caption_type",
653
  type=str,
654
  default="descriptive",
655
+ choices=CAPTION_TYPE_MAP.keys(),
656
  help="Type of caption to generate.",
657
  )
658
  parser.add_argument(
 
692
  "Only works if --feed-from-tags is enabled."
693
  ),
694
  )
695
+ parser.add_argument(
696
+ "--dry-run",
697
+ action="store_true",
698
+ help="Run in dry-run mode without loading models or generating captions.",
699
+ )
700
+ parser.add_argument(
701
+ "-v",
702
+ "--verbose",
703
+ action="count",
704
+ default=0,
705
+ help="Increase output verbosity (can be repeated)",
706
+ )
707
 
708
  args = parser.parse_args()
709
 
710
+ setup_logging(args.verbose)
711
+
712
  # Validate random-tags usage
713
  if args.random_tags is not None and args.feed_from_tags is None:
714
  parser.error("--random-tags can only be used when --feed-from-tags is enabled")
715
 
716
+ if args.feed_from_tags is not None:
717
+ logging.info("Loading e621 tag data")
718
+ tagset_normalizer = make_tagset_normalizer()
719
 
720
+ # Initialize and load models only if not in dry-run mode
721
+ if not args.dry_run:
722
+ joy_caption_model = JoyCaptionModel()
723
+ joy_caption_model.load_models()
724
+ else:
725
+ logging.info("Running in dry-run mode. Models will not be loaded.")
726
 
727
  image_extensions = {".webp", ".png", ".jpeg", ".jpg", ".jxl"}
728
  for image_path in Path(args.directory).rglob("*"):
 
731
 
732
  # Skip if the caption file already exists
733
  if caption_file.exists():
734
+ logging.info(f"Skipping {image_path}: Caption file already exists.")
735
  continue
736
 
737
+ if not args.dry_run:
738
+ input_image = Image.open(image_path).convert("RGB")
739
 
740
  # Use custom prompt if specified
741
+ prompt = args.custom_prompt or JoyCaptionModel.get_prompt_string(
742
+ args.caption_type, args.caption_length
743
+ )
744
 
745
+ if args.feed_from_tags is not None:
746
+ prompt = prompt_from_tags(args, image_path, tagset_normalizer, prompt)
747
 
748
+ if args.dry_run:
749
+ logging.info(
750
+ f"Dry run: Skipping caption generation for {image_path} with prompt:\n\t{prompt}"
751
+ )
752
+ continue
753
+
754
+ caption = joy_caption_model.generate_valid_caption(input_image, prompt)
755
 
756
  # Strip commas if the --dont-strip-commas flag is not set
757
  if not args.dont_strip_commas:
 
765
  # Remove all newline characters
766
  caption = caption.replace("\n", " ")
767
 
768
+ logging.info(f"Caption for {image_path}:\n\t{caption}\n\n")
769
 
770
  # Save the caption to a .caption file
771
  with open(caption_file, "w", encoding="utf-8") as f:
772
  f.write(caption)
773
+ logging.info(f"Caption saved to {caption_file}")
774
 
775
 
776
  RE_PARENS_SUFFIX = re.compile(r"_\([^)]+\)$")
 
847
  TAG_ARTIST = tag_category2id["artist"]
848
  TAG_COPYRIGHT = tag_category2id["copyright"]
849
  TAG_META = tag_category2id["meta"]
 
850
 
851
 
852
  def prompt_from_tags(
 
854
  image_path: Path,
855
  tagset_normalizer: TagSetNormalizer,
856
  base_prompt: str = "Write a descriptive caption for this image in a formal tone.",
857
+ tag_freq_threshold: int = 0,
858
+ tag_string_prefix: str = "Use these tags to construct your caption:",
859
  ):
860
  """
861
  Generates a prompt from tags associated with the given image.
 
866
  The path to the image file.
867
  tagset_normalizer (TagSetNormalizer):
868
  An instance to normalize the tag set.
869
  """
870
+ # Find and read the corresponding tag file
871
  tag_file = find_tag_file(image_path)
872
  if tag_file is None:
873
+ logging.warning(f"No tag file found for {image_path}")
874
+ return base_prompt
875
 
876
  with open(tag_file, "r", encoding="utf-8") as f:
877
  tags = f.read().lower().split(",")
878
 
879
+ # Get helper functions from the tagset_normalizer
880
  tag_id_to_cat_id = tagset_normalizer.tag_normalizer.tag_categories
881
  encode = tagset_normalizer.tag_normalizer.encode
882
 
883
+ # Initialize dictionaries and lists to store categorized tags
884
+ # These lists will contain tuples (freq, tag, tag_id)
885
  tag_by_category: Dict[int, List[Tuple[int, str, int]]] = {
886
  cat: [] for cat in [TAG_ARTIST, TAG_CHARACTER, TAG_COPYRIGHT, TAG_SPECIES]
887
  }
888
  other_tags: List[Tuple[int, str, int]] = []
889
  implied: set = set()
890
+
891
+ # Process each tag
892
  for tag in tags:
893
  tag = tag.strip()
894
  # Encode the tag into a numerical id
895
  tag_id = encode(tag.replace(" ", "_"))
896
  if tag_id is None:
897
+ # If tag is not recognized, add it to other_tags
898
  other_tags.append((0, tag, 0))
899
  implied.update(tagset_normalizer.implications_rej.get(0, ()))
900
  continue
 
903
  # Skip meta tags
904
  if cat_id == TAG_META:
905
  continue
906
+ # Update implied tags
907
  implied.update(tagset_normalizer.implications.get(tag_id, ()))
908
  # Get the frequency of the tag
909
  freq = tag_rank_to_freq(tag_id)
910
+ if freq < tag_freq_threshold:
911
  continue
912
+ # Add the tag to its category, or other_tags
913
  tag_by_category.get(cat_id, other_tags).append((int(freq), tag, tag_id))
914
 
915
+ # Sort other_tags by frequency (descending) and filter out implied tags
916
  other_tags = sorted(
917
+ (-freq, tag, tag_id)
918
  for freq, tag, tag_id in other_tags
919
  if tag_id not in implied
920
  )
921
 
922
+ # Sort tags within each category, preferring non-implied tags
923
  for cat_id, cat_list in tag_by_category.items():
924
  tag_by_category[cat_id] = sorted(
925
+ ((tag_id in implied, -freq), tag, tag_id) for freq, tag, tag_id in cat_list
 
 
926
  )
927
 
928
+ # Handle random tag selection or tag limit if specified
929
  if args.random_tags is not None:
930
  # Randomly select tags if --random-tags is specified
931
  num_tags = min(args.random_tags, len(other_tags))
 
940
  # Use specified number of tags if --feed-from-tags has a positive value
941
  other_tags = other_tags[: args.feed_from_tags]
942
 
943
+ # Prepare sentence pieces for each category
944
  artist_tag = tag_by_category[TAG_ARTIST]
945
  if artist_tag:
946
  artist_list = [str(tp[1]).removeprefix("by ") for tp in artist_tag[:4]]
 
963
  species_txt += format_nl_list([tp[1] for tp in species_tag[:4]])
964
  else:
965
  if character_tag:
966
+ species_txt = (
967
+ "of a character" if len(character_tag) <= 1 else "of characters"
968
+ )
969
  else:
970
  species_txt = ""
971
 
 
975
  copyright_txt = f"from {format_nl_list(tags)}"
976
  else:
977
  copyright_txt = ""
978
+
979
+ # Prepare the remaining tags as a string
980
  tag_string = ", ".join(tp[1] for tp in other_tags)
981
 
982
+ # Extract the prefix and suffix around the word "image" from the base prompt
983
  image_pos = base_prompt.find("image")
984
  if image_pos < 0:
985
  raise ValueError("Base prompt must contain the word 'image'")
 
987
  base_prompt_prefix = base_prompt[:image_pos].rstrip()
988
  base_prompt_suffix = base_prompt[image_pos:].lstrip()
989
 
990
+ pieces = [
991
+ base_prompt_prefix,
992
+ artist_txt,
993
+ species_txt,
994
+ character_txt,
995
+ copyright_txt,
996
+ base_prompt_suffix,
997
+ tag_string_prefix,
998
+ tag_string,
999
+ ]
1000
+ logging.debug("Prompt pieces: %r", pieces)
1001
+ custom_prompt = " ".join(p for p in pieces if p)
1002
+ custom_prompt = custom_prompt.replace(" .", ".").replace(" ,", ",")
 
1003
  return custom_prompt
1004
 
1005