k4d3 commited on
Commit
2e83eb5
·
1 Parent(s): c45b591

update joy and crawl UwU

Browse files
Files changed (2) hide show
  1. crawl/crawl +13 -6
  2. joy +205 -132
crawl/crawl CHANGED
@@ -16,7 +16,7 @@ import platform
16
  from concurrent.futures import ThreadPoolExecutor, as_completed
17
  import time
18
  import argparse
19
- from urllib.parse import urljoin
20
  import requests
21
  try:
22
  from crawl4ai import WebCrawler # type: ignore
@@ -75,11 +75,18 @@ def download_image(session, image_url, save_dir, base_url):
75
  The base URL of the page being crawled.
76
  """
77
  try:
78
- # Ensure the URL has a scheme
79
- if image_url.startswith(".."):
80
- image_url = urljoin(base_url, image_url)
81
- elif not re.match(r"^https?://", image_url):
82
- image_url = "https://" + image_url.lstrip("/")
 
 
 
 
 
 
 
83
 
84
  image_filename = os.path.basename(image_url).split("?")[0]
85
  sanitized_image_filename = sanitize_filename(image_filename)
 
16
  from concurrent.futures import ThreadPoolExecutor, as_completed
17
  import time
18
  import argparse
19
+ from urllib.parse import urljoin, urlparse
20
  import requests
21
  try:
22
  from crawl4ai import WebCrawler # type: ignore
 
75
  The base URL of the page being crawled.
76
  """
77
  try:
78
+ # Parse the base URL to get the scheme and netloc
79
+ parsed_base_url = urlparse(base_url)
80
+ base_image_url = (
81
+ f"{parsed_base_url.scheme}://{parsed_base_url.netloc}/"
82
+ )
83
+
84
+ # Ensure the URL has a scheme and is properly joined with
85
+ # the base image URL
86
+ if not re.match(r"^https?://", image_url):
87
+ image_url = urljoin(
88
+ base_image_url, image_url.lstrip("/")
89
+ )
90
 
91
  image_filename = os.path.basename(image_url).split("?")[0]
92
  sanitized_image_filename = sanitize_filename(image_filename)
joy CHANGED
@@ -18,7 +18,6 @@ import os
18
  import argparse
19
  import re
20
  import random
21
- from collections import Counter
22
  from pathlib import Path
23
  from PIL import Image
24
  import pillow_jxl
@@ -26,14 +25,14 @@ import torch
26
  import torchvision.transforms.functional as TVF
27
  from transformers import (
28
  AutoModel,
29
- AutoProcessor,
30
  AutoTokenizer,
31
  AutoModelForCausalLM,
32
  PreTrainedTokenizer,
33
  PreTrainedTokenizerFast,
34
  )
35
  from torch import nn
36
- from e6db_reader import TagNormalizer, TagSetNormalizer, tag_category2id, tag_rank_to_freq
 
37
 
38
  CLIP_PATH = "google/siglip-so400m-patch14-384"
39
  MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B"
@@ -63,8 +62,7 @@ CAPTION_TYPE_MAP = {
63
  "Write a stable diffusion prompt for this image."
64
  ],
65
  ("training_prompt", "formal", False, True): [
66
- "Write a stable diffusion prompt for this image within {word_count} "
67
- "words."
68
  ],
69
  ("training_prompt", "formal", True, False): [
70
  "Write a {length} stable diffusion prompt for this image."
@@ -82,6 +80,7 @@ CAPTION_TYPE_MAP = {
82
 
83
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
84
 
 
85
  class ImageAdapter(nn.Module):
86
  """
87
  Custom image adapter module for processing CLIP vision outputs.
@@ -118,8 +117,10 @@ class ImageAdapter(nn.Module):
118
  self.activation = nn.GELU()
119
  self.linear2 = nn.Linear(output_features, output_features)
120
  self.ln1 = nn.Identity() if not ln1 else nn.LayerNorm(input_features)
121
- self.pos_emb = None if not pos_emb else nn.Parameter(
122
- torch.zeros(num_image_tokens, input_features)
 
 
123
  )
124
 
125
  self.other_tokens = nn.Embedding(3, output_features)
@@ -136,26 +137,29 @@ class ImageAdapter(nn.Module):
136
  torch.Tensor: Adapted image features.
137
  """
138
  if self.deep_extract:
139
- x = torch.concat((
140
- vision_outputs[-2],
141
- vision_outputs[3],
142
- vision_outputs[7],
143
- vision_outputs[13],
144
- vision_outputs[20],
145
- ), dim=-1)
146
- assert len(x.shape) == 3, f"Expected 3, got {len(x.shape)}"
147
- assert x.shape[-1] == vision_outputs[-2].shape[-1] * 5, (
148
- f"Expected {vision_outputs[-2].shape[-1] * 5}, got {x.shape[-1]}"
149
  )
 
 
 
 
150
  else:
151
  x = vision_outputs[-2]
152
 
153
  x = self.ln1(x)
154
 
155
  if self.pos_emb is not None:
156
- assert x.shape[-2:] == self.pos_emb.shape, (
157
- f"Expected {self.pos_emb.shape}, got {x.shape[-2:]}"
158
- )
159
  x = x + self.pos_emb
160
 
161
  x = self.linear1(x)
@@ -167,9 +171,11 @@ class ImageAdapter(nn.Module):
167
  x.shape[0], -1
168
  )
169
  )
170
- assert other_tokens.shape == (x.shape[0], 2, x.shape[2]), (
171
- f"Expected {(x.shape[0], 2, x.shape[2])}, got {other_tokens.shape}"
172
- )
 
 
173
  x = torch.cat((other_tokens[:, 0:1], x, other_tokens[:, 1:2]), dim=1)
174
 
175
  return x
@@ -185,6 +191,7 @@ class ImageAdapter(nn.Module):
185
  torch.tensor([2], device=self.other_tokens.weight.device)
186
  ).squeeze(0)
187
 
 
188
  class JoyCaptionModel:
189
  """
190
  A class for generating captions for images using CLIP, LLM, and custom image adapters.
@@ -221,8 +228,12 @@ class JoyCaptionModel:
221
 
222
  if (CHECKPOINT_PATH / "clip_model.pt").exists():
223
  print("Loading VLM's custom vision model")
224
- checkpoint = torch.load(CHECKPOINT_PATH / "clip_model.pt", map_location='cpu')
225
- checkpoint = {k.replace("_orig_mod.module.", ""): v for k, v in checkpoint.items()}
 
 
 
 
226
  self.clip_model.load_state_dict(checkpoint)
227
  del checkpoint
228
 
@@ -240,15 +251,11 @@ class JoyCaptionModel:
240
  if (CHECKPOINT_PATH / "text_model").exists():
241
  print("Loading VLM's custom text model")
242
  self.text_model = AutoModelForCausalLM.from_pretrained(
243
- CHECKPOINT_PATH / "text_model",
244
- device_map=0,
245
- torch_dtype=torch.bfloat16
246
  )
247
  else:
248
  self.text_model = AutoModelForCausalLM.from_pretrained(
249
- MODEL_PATH,
250
- device_map="auto",
251
- torch_dtype=torch.bfloat16
252
  )
253
 
254
  self.text_model.eval()
@@ -260,7 +267,7 @@ class JoyCaptionModel:
260
  False,
261
  False,
262
  38,
263
- False
264
  )
265
  self.image_adapter.load_state_dict(
266
  torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location="cpu")
@@ -269,12 +276,14 @@ class JoyCaptionModel:
269
  self.image_adapter.to("cuda")
270
 
271
  @torch.no_grad()
272
- def process_image(self,
273
- input_image: Image.Image,
274
- caption_type: str,
275
- caption_tone: str,
276
- caption_length: str | int,
277
- custom_prompt: str | None = None) -> str:
 
 
278
  """
279
  Process an input image and generate a caption based on specified parameters.
280
  """
@@ -283,14 +292,18 @@ class JoyCaptionModel:
283
  if custom_prompt is not None:
284
  prompt_str = custom_prompt
285
  else:
286
- prompt_str = self._get_prompt_string(caption_type, caption_tone, caption_length)
 
 
287
  print(f"Prompt: {prompt_str}")
288
 
289
  pixel_values = self._preprocess_image(input_image)
290
  prompt = self._tokenize_prompt(prompt_str)
291
 
292
  embedded_images = self._embed_image(pixel_values)
293
- inputs_embeds, input_ids, attention_mask = self._construct_inputs(embedded_images, prompt)
 
 
294
 
295
  generate_ids = self._generate_caption(inputs_embeds, input_ids, attention_mask)
296
  caption = self._decode_caption(generate_ids, input_ids)
@@ -313,7 +326,7 @@ class JoyCaptionModel:
313
  caption_type,
314
  caption_tone,
315
  isinstance(length, str),
316
- isinstance(length, int)
317
  )
318
  if prompt_key not in CAPTION_TYPE_MAP:
319
  raise ValueError(f"Invalid caption type: {prompt_key}")
@@ -327,57 +340,73 @@ class JoyCaptionModel:
327
  image = input_image.resize((384, 384), Image.LANCZOS)
328
  pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
329
  pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
330
- pixel_values = pixel_values.to('cuda')
331
  return pixel_values
332
 
333
  def _tokenize_prompt(self, prompt_str):
334
  prompt = self.tokenizer.encode(
335
  prompt_str,
336
- return_tensors='pt',
337
  padding=False,
338
  truncation=False,
339
- add_special_tokens=False
340
  )
341
  return prompt
342
 
343
  def _embed_image(self, pixel_values):
344
- with torch.amp.autocast_mode.autocast('cuda', enabled=True):
345
- vision_outputs = self.clip_model(pixel_values=pixel_values, output_hidden_states=True)
 
 
346
  image_features = vision_outputs.hidden_states
347
  embedded_images = self.image_adapter(image_features)
348
- embedded_images = embedded_images.to('cuda')
349
  return embedded_images
350
 
351
  def _construct_inputs(self, embedded_images, prompt):
352
- prompt_embeds = self.text_model.model.embed_tokens(prompt.to('cuda'))
353
- assert prompt_embeds.shape == (1, prompt.shape[1], self.text_model.config.hidden_size), (
 
 
 
 
354
  f"Prompt shape is {prompt_embeds.shape}, expected "
355
  f"{(1, prompt.shape[1], self.text_model.config.hidden_size)}"
356
  )
357
 
358
  embedded_bos = self.text_model.model.embed_tokens(
359
- torch.tensor([[self.tokenizer.bos_token_id]],
360
- device=self.text_model.device,
361
- dtype=torch.int64)
 
 
 
 
 
 
 
 
362
  )
363
 
364
- eot_embed = self.image_adapter.get_eot_embedding().unsqueeze(0).to(
365
- dtype=self.text_model.dtype
 
 
 
 
 
 
366
  )
367
 
368
- inputs_embeds = torch.cat([
369
- embedded_bos.expand(embedded_images.shape[0], -1, -1),
370
- embedded_images.to(dtype=embedded_bos.dtype),
371
- prompt_embeds.expand(embedded_images.shape[0], -1, -1),
372
- eot_embed.expand(embedded_images.shape[0], -1, -1),
373
- ], dim=1)
374
-
375
- input_ids = torch.cat([
376
- torch.tensor([[self.tokenizer.bos_token_id]], dtype=torch.long),
377
- torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
378
- prompt,
379
- torch.tensor([[self.tokenizer.eos_token_id]], dtype=torch.long),
380
- ], dim=1).to('cuda')
381
  attention_mask = torch.ones_like(input_ids)
382
 
383
  return inputs_embeds, input_ids, attention_mask
@@ -389,21 +418,20 @@ class JoyCaptionModel:
389
  attention_mask=attention_mask,
390
  max_new_tokens=300,
391
  do_sample=True,
392
- suppress_tokens=None
393
  )
394
  return generate_ids
395
 
396
  def _decode_caption(self, generate_ids, input_ids):
397
- generate_ids = generate_ids[:, input_ids.shape[1]:]
398
 
399
- if (generate_ids[0][-1] == self.tokenizer.eos_token_id or
400
- generate_ids[0][-1] == self.tokenizer.convert_tokens_to_ids("<|eot_id|>")):
 
401
  generate_ids = generate_ids[:, :-1]
402
 
403
  caption = self.tokenizer.batch_decode(
404
- generate_ids,
405
- skip_special_tokens=False,
406
- clean_up_tokenization_spaces=False
407
  )[0]
408
  return caption
409
 
@@ -413,53 +441,52 @@ def main():
413
  parser = argparse.ArgumentParser(
414
  description="Generate captions for images in a directory and save them as .caption files."
415
  )
416
- parser.add_argument("directory", type=str, help="Target directory containing images.")
 
 
417
  parser.add_argument(
418
  "--caption_type",
419
  type=str,
420
  default="descriptive",
421
  choices=["descriptive", "training_prompt", "rng-tags", "custom"],
422
- help="Type of caption to generate."
423
  )
424
  parser.add_argument(
425
  "--caption_tone",
426
  type=str,
427
  default="formal",
428
  choices=["formal", "informal"],
429
- help="Tone of the caption."
430
  )
431
  parser.add_argument(
432
- "--caption_length",
433
- type=str,
434
- default="any",
435
- help="Length of the caption."
436
  )
437
  parser.add_argument(
438
  "--dont-strip-commas",
439
  action="store_true",
440
- help="If set, commas will not be stripped from the generated captions."
441
  )
442
  parser.add_argument(
443
  "--custom_prompt",
444
  type=str,
445
- help="Custom prompt for the captioner. Use with --caption_type custom."
446
  )
447
  parser.add_argument(
448
- '--add-commas-to-sentence-ends',
449
- action='store_true',
450
- help='Add commas after periods in sentences'
451
  )
452
  parser.add_argument(
453
- '--feed-from-tags',
454
  type=int,
455
- nargs='?',
456
  const=-1,
457
- help='Use .txt files with the same base filename as the images as input to the captioner. Optionally specify the number of tags to use.'
458
  )
459
  parser.add_argument(
460
- '--random-tags',
461
  type=int,
462
- help='Randomly select n number of tags. Only works if --feed-from-tags is enabled.'
463
  )
464
 
465
  args = parser.parse_args()
@@ -468,7 +495,7 @@ def main():
468
  if args.random_tags is not None and args.feed_from_tags is None:
469
  parser.error("--random-tags can only be used when --feed-from-tags is enabled")
470
 
471
- print('Loading e621 tag data')
472
  tagset_normalizer = make_tagset_normalizer()
473
 
474
  # Initialize and load models
@@ -484,7 +511,7 @@ def main():
484
  image_extensions = {".webp", ".png", ".jpeg", ".jpg", ".jxl"}
485
  for image_path in Path(args.directory).rglob("*"):
486
  if image_path.suffix.lower() in image_extensions:
487
- caption_file = image_path.with_suffix('.caption')
488
 
489
  # Skip if the caption file already exists
490
  if caption_file.exists():
@@ -501,29 +528,28 @@ def main():
501
  custom_prompt = prompt_from_tags(args, image_path, tagset_normalizer)
502
 
503
  print(f"Custom prompt: {custom_prompt}")
504
- continue
505
 
506
  caption = joy_caption_model.process_image(
507
  input_image,
508
  args.caption_type,
509
  args.caption_tone,
510
  args.caption_length,
511
- custom_prompt=custom_prompt
512
  )
513
 
514
  # Strip commas if the --dont-strip-commas flag is not set
515
  if not args.dont_strip_commas:
516
  # Existing comma stripping logic
517
- caption = re.sub(r',\s*([^\d])', r' \1', caption)
518
 
519
  # New feature: Add commas after periods if specified
520
  if args.add_commas_to_sentence_ends:
521
- caption = re.sub(r'(\.)(\s+)([A-Z])', r'\1,\2\3', caption)
522
 
523
  print(f"Caption for {image_path}:\n\n{caption}\n\n")
524
 
525
  # Save the caption to a .caption file
526
- with open(caption_file, 'w', encoding='utf-8') as f:
527
  f.write(caption)
528
  print(f"Caption saved to {caption_file}")
529
 
@@ -531,6 +557,7 @@ def main():
531
  RE_PARENS_SUFFIX = re.compile(r"_\([^)]+\)$")
532
  E6DB_DATA = Path(__file__).resolve().parent / "data"
533
 
 
534
  def make_tagset_normalizer():
535
  """
536
  Create a TagSetNormalizer for encoding/decoding tags to and from integers.
@@ -581,7 +608,6 @@ def make_tagset_normalizer():
581
  return tagset_normalizer.map_inputs(input_map, on_conflict="ignore")
582
 
583
 
584
-
585
  def format_nl_list(l):
586
  n = len(l)
587
  assert n > 0
@@ -589,36 +615,51 @@ def format_nl_list(l):
589
  return l[0]
590
  elif n == 2:
591
  return f"{l[0]} and {l[1]}"
592
- else: # n > 2
593
  *head, last = l
594
- return ', '.join(head) + ', and ' + last
595
 
596
- TAG_SPECIES = tag_category2id['species']
597
- TAG_CHARACTER = tag_category2id['character']
598
- TAG_ARTIST = tag_category2id['artist']
599
- TAG_COPYRIGHT = tag_category2id['copyright']
600
- TAG_META = tag_category2id['meta']
 
601
  TAG_FREQ_THRESH = 0
602
 
 
603
  def prompt_from_tags(args, image_path: Path, tagset_normalizer: TagSetNormalizer):
 
 
 
 
 
 
 
 
 
 
 
604
  tag_file = find_tag_file(image_path)
605
  if tag_file is None:
606
  return None
607
 
608
- with open(tag_file, 'r', encoding='utf-8') as f:
609
- tags = f.read().lower().split(',')
610
 
611
  tag_id_to_cat_id = tagset_normalizer.tag_normalizer.tag_categories
612
  encode = tagset_normalizer.tag_normalizer.encode
613
 
614
  # These lists contain tuples (freq, tag, tag_id)
615
- tag_by_category = {cat: [] for cat in [TAG_ARTIST, TAG_CHARACTER, TAG_COPYRIGHT, TAG_SPECIES]}
616
- other_tags = []
617
- implied = set()
 
 
618
  for tag in tags:
619
  tag = tag.strip()
620
  # Encode the tag into a numerical id
621
- tag_id = encode(tag.replace(' ', '_'))
622
  if tag_id is None:
623
  other_tags.append((0, tag, None))
624
  implied.update(tagset_normalizer.implications_rej.get(tag_id, ()))
@@ -633,69 +674,101 @@ def prompt_from_tags(args, image_path: Path, tagset_normalizer: TagSetNormalizer
633
  freq = tag_rank_to_freq(tag_id)
634
  if freq < TAG_FREQ_THRESH:
635
  continue
636
- tag_by_category.get(cat_id, other_tags).append((freq, tag, tag_id))
637
 
638
- other_tags = sorted((freq, tag) for freq, tag, tag_id in other_tags if tag_id not in implied)
 
 
 
 
639
  for cat_id, cat_list in tag_by_category.items():
640
- tag_by_category[cat_id] = sorted((freq, tag) for freq, tag, tag_id in cat_list if tag_id not in implied)
 
 
 
 
641
 
642
  if args.random_tags is not None:
643
  # Randomly select tags if --random-tags is specified
644
  num_tags = min(args.random_tags, len(other_tags))
645
- other_tags = random.sample(tags[:round(args.random_tags * 1.5)], num_tags)
 
 
 
 
 
 
646
  elif args.feed_from_tags > 0:
647
  # Use specified number of tags if --feed-from-tags has a positive value
648
- other_tags = other_tags[:args.feed_from_tags]
649
 
650
  # Prepare sentence pieces
651
  artist_tag = tag_by_category[TAG_ARTIST]
652
  if artist_tag:
653
- artist_txt = f' by {format_nl_list([tag.removeprefix('by ') for _, tag in artist_tag[:4]])}'
 
 
654
  else:
655
- artist_txt = ''
656
  character_tag = tag_by_category[TAG_CHARACTER]
657
  if character_tag:
658
- character_txt = f' named {format_nl_list([tag for _, tag in character_tag[:4]])}'
 
659
  else:
660
- character_txt = ''
661
  species_tag = tag_by_category[TAG_SPECIES]
662
  if species_tag:
663
- species_txt = f' of{" a" if len(character_tag) <= 1 and len(species_tag) <= 1 else ""} {format_nl_list([tag for _, tag in species_tag[:4]])}'
 
664
  else:
665
  if character_tag:
666
- species_txt = f' of{" a character" if len(character_tag) <= 1 else " characters"}'
 
 
 
 
667
  else:
668
- species_txt = ''
669
  copyright_tag = tag_by_category[TAG_COPYRIGHT]
670
  if copyright_tag:
671
- copyright_txt = f' from {format_nl_list([tag for _, tag in copyright_tag[:4]])}'
 
672
  else:
673
- copyright_txt = ''
674
-
675
- tag_string = ', '.join(tag for _, tag in other_tags)
676
- custom_prompt = f"Write a descriptive caption for this image{artist_txt}{species_txt}{character_txt}{copyright_txt} in a formal tone. Use these tags as context clues to construct your caption: {tag_string}"
 
 
 
 
 
 
 
677
  return custom_prompt
678
 
 
679
  def find_tag_file(image_path):
680
  """
681
  Find the corresponding .txt file for the given image path.
682
  Handles cases where the image has a -(number) suffix.
683
  """
684
  base_name = image_path.stem
685
- tag_file = image_path.with_suffix('.txt')
686
 
687
  if tag_file.exists():
688
  return tag_file
689
 
690
  # Handle -(number) suffix
691
- match = re.match(r'(.+)-\d+$', base_name)
692
  if match:
693
  base_name = match.group(1)
694
- tag_file = image_path.with_name(base_name).with_suffix('.txt')
695
  if tag_file.exists():
696
  return tag_file
697
 
698
  return None
699
 
 
700
  if __name__ == "__main__":
701
  main()
 
18
  import argparse
19
  import re
20
  import random
 
21
  from pathlib import Path
22
  from PIL import Image
23
  import pillow_jxl
 
25
  import torchvision.transforms.functional as TVF
26
  from transformers import (
27
  AutoModel,
 
28
  AutoTokenizer,
29
  AutoModelForCausalLM,
30
  PreTrainedTokenizer,
31
  PreTrainedTokenizerFast,
32
  )
33
  from torch import nn
34
+ from e6db_reader import TagSetNormalizer, tag_category2id, tag_rank_to_freq
35
+ from typing import List, Tuple, Dict
36
 
37
  CLIP_PATH = "google/siglip-so400m-patch14-384"
38
  MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B"
 
62
  "Write a stable diffusion prompt for this image."
63
  ],
64
  ("training_prompt", "formal", False, True): [
65
+ "Write a stable diffusion prompt for this image within {word_count} " "words."
 
66
  ],
67
  ("training_prompt", "formal", True, False): [
68
  "Write a {length} stable diffusion prompt for this image."
 
80
 
81
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
82
 
83
+
84
  class ImageAdapter(nn.Module):
85
  """
86
  Custom image adapter module for processing CLIP vision outputs.
 
117
  self.activation = nn.GELU()
118
  self.linear2 = nn.Linear(output_features, output_features)
119
  self.ln1 = nn.Identity() if not ln1 else nn.LayerNorm(input_features)
120
+ self.pos_emb = (
121
+ None
122
+ if not pos_emb
123
+ else nn.Parameter(torch.zeros(num_image_tokens, input_features))
124
  )
125
 
126
  self.other_tokens = nn.Embedding(3, output_features)
 
137
  torch.Tensor: Adapted image features.
138
  """
139
  if self.deep_extract:
140
+ x = torch.concat(
141
+ (
142
+ vision_outputs[-2],
143
+ vision_outputs[3],
144
+ vision_outputs[7],
145
+ vision_outputs[13],
146
+ vision_outputs[20],
147
+ ),
148
+ dim=-1,
 
149
  )
150
+ assert len(x.shape) == 3, f"Expected 3, got {len(x.shape)}"
151
+ assert (
152
+ x.shape[-1] == vision_outputs[-2].shape[-1] * 5
153
+ ), f"Expected {vision_outputs[-2].shape[-1] * 5}, got {x.shape[-1]}"
154
  else:
155
  x = vision_outputs[-2]
156
 
157
  x = self.ln1(x)
158
 
159
  if self.pos_emb is not None:
160
+ assert (
161
+ x.shape[-2:] == self.pos_emb.shape
162
+ ), f"Expected {self.pos_emb.shape}, got {x.shape[-2:]}"
163
  x = x + self.pos_emb
164
 
165
  x = self.linear1(x)
 
171
  x.shape[0], -1
172
  )
173
  )
174
+ assert other_tokens.shape == (
175
+ x.shape[0],
176
+ 2,
177
+ x.shape[2],
178
+ ), f"Expected {(x.shape[0], 2, x.shape[2])}, got {other_tokens.shape}"
179
  x = torch.cat((other_tokens[:, 0:1], x, other_tokens[:, 1:2]), dim=1)
180
 
181
  return x
 
191
  torch.tensor([2], device=self.other_tokens.weight.device)
192
  ).squeeze(0)
193
 
194
+
195
  class JoyCaptionModel:
196
  """
197
  A class for generating captions for images using CLIP, LLM, and custom image adapters.
 
228
 
229
  if (CHECKPOINT_PATH / "clip_model.pt").exists():
230
  print("Loading VLM's custom vision model")
231
+ checkpoint = torch.load(
232
+ CHECKPOINT_PATH / "clip_model.pt", map_location="cpu"
233
+ )
234
+ checkpoint = {
235
+ k.replace("_orig_mod.module.", ""): v for k, v in checkpoint.items()
236
+ }
237
  self.clip_model.load_state_dict(checkpoint)
238
  del checkpoint
239
 
 
251
  if (CHECKPOINT_PATH / "text_model").exists():
252
  print("Loading VLM's custom text model")
253
  self.text_model = AutoModelForCausalLM.from_pretrained(
254
+ CHECKPOINT_PATH / "text_model", device_map=0, torch_dtype=torch.bfloat16
 
 
255
  )
256
  else:
257
  self.text_model = AutoModelForCausalLM.from_pretrained(
258
+ MODEL_PATH, device_map="auto", torch_dtype=torch.bfloat16
 
 
259
  )
260
 
261
  self.text_model.eval()
 
267
  False,
268
  False,
269
  38,
270
+ False,
271
  )
272
  self.image_adapter.load_state_dict(
273
  torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location="cpu")
 
276
  self.image_adapter.to("cuda")
277
 
278
  @torch.no_grad()
279
+ def process_image(
280
+ self,
281
+ input_image: Image.Image,
282
+ caption_type: str,
283
+ caption_tone: str,
284
+ caption_length: str | int,
285
+ custom_prompt: str | None = None,
286
+ ) -> str:
287
  """
288
  Process an input image and generate a caption based on specified parameters.
289
  """
 
292
  if custom_prompt is not None:
293
  prompt_str = custom_prompt
294
  else:
295
+ prompt_str = self._get_prompt_string(
296
+ caption_type, caption_tone, caption_length
297
+ )
298
  print(f"Prompt: {prompt_str}")
299
 
300
  pixel_values = self._preprocess_image(input_image)
301
  prompt = self._tokenize_prompt(prompt_str)
302
 
303
  embedded_images = self._embed_image(pixel_values)
304
+ inputs_embeds, input_ids, attention_mask = self._construct_inputs(
305
+ embedded_images, prompt
306
+ )
307
 
308
  generate_ids = self._generate_caption(inputs_embeds, input_ids, attention_mask)
309
  caption = self._decode_caption(generate_ids, input_ids)
 
326
  caption_type,
327
  caption_tone,
328
  isinstance(length, str),
329
+ isinstance(length, int),
330
  )
331
  if prompt_key not in CAPTION_TYPE_MAP:
332
  raise ValueError(f"Invalid caption type: {prompt_key}")
 
340
  image = input_image.resize((384, 384), Image.LANCZOS)
341
  pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
342
  pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
343
+ pixel_values = pixel_values.to("cuda")
344
  return pixel_values
345
 
346
  def _tokenize_prompt(self, prompt_str):
347
  prompt = self.tokenizer.encode(
348
  prompt_str,
349
+ return_tensors="pt",
350
  padding=False,
351
  truncation=False,
352
+ add_special_tokens=False,
353
  )
354
  return prompt
355
 
356
  def _embed_image(self, pixel_values):
357
+ with torch.amp.autocast_mode.autocast("cuda", enabled=True):
358
+ vision_outputs = self.clip_model(
359
+ pixel_values=pixel_values, output_hidden_states=True
360
+ )
361
  image_features = vision_outputs.hidden_states
362
  embedded_images = self.image_adapter(image_features)
363
+ embedded_images = embedded_images.to("cuda")
364
  return embedded_images
365
 
366
  def _construct_inputs(self, embedded_images, prompt):
367
+ prompt_embeds = self.text_model.model.embed_tokens(prompt.to("cuda"))
368
+ assert prompt_embeds.shape == (
369
+ 1,
370
+ prompt.shape[1],
371
+ self.text_model.config.hidden_size,
372
+ ), (
373
  f"Prompt shape is {prompt_embeds.shape}, expected "
374
  f"{(1, prompt.shape[1], self.text_model.config.hidden_size)}"
375
  )
376
 
377
  embedded_bos = self.text_model.model.embed_tokens(
378
+ torch.tensor(
379
+ [[self.tokenizer.bos_token_id]],
380
+ device=self.text_model.device,
381
+ dtype=torch.int64,
382
+ )
383
+ )
384
+
385
+ eot_embed = (
386
+ self.image_adapter.get_eot_embedding()
387
+ .unsqueeze(0)
388
+ .to(dtype=self.text_model.dtype)
389
  )
390
 
391
+ inputs_embeds = torch.cat(
392
+ [
393
+ embedded_bos.expand(embedded_images.shape[0], -1, -1),
394
+ embedded_images.to(dtype=embedded_bos.dtype),
395
+ prompt_embeds.expand(embedded_images.shape[0], -1, -1),
396
+ eot_embed.expand(embedded_images.shape[0], -1, -1),
397
+ ],
398
+ dim=1,
399
  )
400
 
401
+ input_ids = torch.cat(
402
+ [
403
+ torch.tensor([[self.tokenizer.bos_token_id]], dtype=torch.long),
404
+ torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
405
+ prompt,
406
+ torch.tensor([[self.tokenizer.eos_token_id]], dtype=torch.long),
407
+ ],
408
+ dim=1,
409
+ ).to("cuda")
 
 
 
 
410
  attention_mask = torch.ones_like(input_ids)
411
 
412
  return inputs_embeds, input_ids, attention_mask
 
418
  attention_mask=attention_mask,
419
  max_new_tokens=300,
420
  do_sample=True,
421
+ suppress_tokens=None,
422
  )
423
  return generate_ids
424
 
425
  def _decode_caption(self, generate_ids, input_ids):
426
+ generate_ids = generate_ids[:, input_ids.shape[1] :]
427
 
428
+ if generate_ids[0][-1] == self.tokenizer.eos_token_id or generate_ids[0][
429
+ -1
430
+ ] == self.tokenizer.convert_tokens_to_ids("<|eot_id|>"):
431
  generate_ids = generate_ids[:, :-1]
432
 
433
  caption = self.tokenizer.batch_decode(
434
+ generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
 
 
435
  )[0]
436
  return caption
437
 
 
441
  parser = argparse.ArgumentParser(
442
  description="Generate captions for images in a directory and save them as .caption files."
443
  )
444
+ parser.add_argument(
445
+ "directory", type=str, help="Target directory containing images."
446
+ )
447
  parser.add_argument(
448
  "--caption_type",
449
  type=str,
450
  default="descriptive",
451
  choices=["descriptive", "training_prompt", "rng-tags", "custom"],
452
+ help="Type of caption to generate.",
453
  )
454
  parser.add_argument(
455
  "--caption_tone",
456
  type=str,
457
  default="formal",
458
  choices=["formal", "informal"],
459
+ help="Tone of the caption.",
460
  )
461
  parser.add_argument(
462
+ "--caption_length", type=str, default="any", help="Length of the caption."
 
 
 
463
  )
464
  parser.add_argument(
465
  "--dont-strip-commas",
466
  action="store_true",
467
+ help="If set, commas will not be stripped from the generated captions.",
468
  )
469
  parser.add_argument(
470
  "--custom_prompt",
471
  type=str,
472
+ help="Custom prompt for the captioner. Use with --caption_type custom.",
473
  )
474
  parser.add_argument(
475
+ "--add-commas-to-sentence-ends",
476
+ action="store_true",
477
+ help="Add commas after periods in sentences",
478
  )
479
  parser.add_argument(
480
+ "--feed-from-tags",
481
  type=int,
482
+ nargs="?",
483
  const=-1,
484
+ help="Use .txt files with the same base filename as the images as input to the captioner. Optionally specify the number of tags to use.",
485
  )
486
  parser.add_argument(
487
+ "--random-tags",
488
  type=int,
489
+ help="Randomly select n number of tags. Only works if --feed-from-tags is enabled.",
490
  )
491
 
492
  args = parser.parse_args()
 
495
  if args.random_tags is not None and args.feed_from_tags is None:
496
  parser.error("--random-tags can only be used when --feed-from-tags is enabled")
497
 
498
+ print("Loading e621 tag data")
499
  tagset_normalizer = make_tagset_normalizer()
500
 
501
  # Initialize and load models
 
511
  image_extensions = {".webp", ".png", ".jpeg", ".jpg", ".jxl"}
512
  for image_path in Path(args.directory).rglob("*"):
513
  if image_path.suffix.lower() in image_extensions:
514
+ caption_file = image_path.with_suffix(".caption")
515
 
516
  # Skip if the caption file already exists
517
  if caption_file.exists():
 
528
  custom_prompt = prompt_from_tags(args, image_path, tagset_normalizer)
529
 
530
  print(f"Custom prompt: {custom_prompt}")
 
531
 
532
  caption = joy_caption_model.process_image(
533
  input_image,
534
  args.caption_type,
535
  args.caption_tone,
536
  args.caption_length,
537
+ custom_prompt=custom_prompt,
538
  )
539
 
540
  # Strip commas if the --dont-strip-commas flag is not set
541
  if not args.dont_strip_commas:
542
  # Existing comma stripping logic
543
+ caption = re.sub(r",\s*([^\d])", r" \1", caption)
544
 
545
  # New feature: Add commas after periods if specified
546
  if args.add_commas_to_sentence_ends:
547
+ caption = re.sub(r"(\.)(\s+)([A-Z])", r"\1,\2\3", caption)
548
 
549
  print(f"Caption for {image_path}:\n\n{caption}\n\n")
550
 
551
  # Save the caption to a .caption file
552
+ with open(caption_file, "w", encoding="utf-8") as f:
553
  f.write(caption)
554
  print(f"Caption saved to {caption_file}")
555
 
 
557
  RE_PARENS_SUFFIX = re.compile(r"_\([^)]+\)$")
558
  E6DB_DATA = Path(__file__).resolve().parent / "data"
559
 
560
+
561
  def make_tagset_normalizer():
562
  """
563
  Create a TagSetNormalizer for encoding/decoding tags to and from integers.
 
608
  return tagset_normalizer.map_inputs(input_map, on_conflict="ignore")
609
 
610
 
 
611
  def format_nl_list(l):
612
  n = len(l)
613
  assert n > 0
 
615
  return l[0]
616
  elif n == 2:
617
  return f"{l[0]} and {l[1]}"
618
+ else: # n > 2
619
  *head, last = l
620
+ return ", ".join(head) + ", and " + last
621
 
622
+
623
+ TAG_SPECIES = tag_category2id["species"]
624
+ TAG_CHARACTER = tag_category2id["character"]
625
+ TAG_ARTIST = tag_category2id["artist"]
626
+ TAG_COPYRIGHT = tag_category2id["copyright"]
627
+ TAG_META = tag_category2id["meta"]
628
  TAG_FREQ_THRESH = 0
629
 
630
+
631
  def prompt_from_tags(args, image_path: Path, tagset_normalizer: TagSetNormalizer):
632
+ """
633
+ Generates a prompt from tags associated with the given image.
634
+
635
+ Args:
636
+ args: Additional arguments for the function.
637
+ image_path (Path): The path to the image file.
638
+ tagset_normalizer (TagSetNormalizer): An instance to normalize the tag set.
639
+
640
+ Returns:
641
+ None
642
+ """
643
  tag_file = find_tag_file(image_path)
644
  if tag_file is None:
645
  return None
646
 
647
+ with open(tag_file, "r", encoding="utf-8") as f:
648
+ tags = f.read().lower().split(",")
649
 
650
  tag_id_to_cat_id = tagset_normalizer.tag_normalizer.tag_categories
651
  encode = tagset_normalizer.tag_normalizer.encode
652
 
653
  # These lists contain tuples (freq, tag, tag_id)
654
+ tag_by_category: Dict[int, List[Tuple[int, str, int]]] = {
655
+ cat: [] for cat in [TAG_ARTIST, TAG_CHARACTER, TAG_COPYRIGHT, TAG_SPECIES]
656
+ }
657
+ other_tags: List[Tuple[int, str, int]] = []
658
+ implied: set = set()
659
  for tag in tags:
660
  tag = tag.strip()
661
  # Encode the tag into a numerical id
662
+ tag_id = encode(tag.replace(" ", "_"))
663
  if tag_id is None:
664
  other_tags.append((0, tag, None))
665
  implied.update(tagset_normalizer.implications_rej.get(tag_id, ()))
 
674
  freq = tag_rank_to_freq(tag_id)
675
  if freq < TAG_FREQ_THRESH:
676
  continue
677
+ tag_by_category.get(cat_id, other_tags).append((int(freq), tag, tag_id))
678
 
679
+ other_tags = sorted(
680
+ (int(freq), tag, tag_id)
681
+ for freq, tag, tag_id in other_tags
682
+ if tag_id not in implied
683
+ )
684
  for cat_id, cat_list in tag_by_category.items():
685
+ tag_by_category[cat_id] = sorted(
686
+ (int(freq), tag, tag_id)
687
+ for freq, tag, tag_id in cat_list
688
+ if tag_id not in implied
689
+ )
690
 
691
  if args.random_tags is not None:
692
  # Randomly select tags if --random-tags is specified
693
  num_tags = min(args.random_tags, len(other_tags))
694
+ other_tags = random.sample(
695
+ [
696
+ (i, tag, tag_id)
697
+ for i, tag, tag_id in enumerate(tags[: round(args.random_tags * 1.5)])
698
+ ],
699
+ num_tags,
700
+ )
701
  elif args.feed_from_tags > 0:
702
  # Use specified number of tags if --feed-from-tags has a positive value
703
+ other_tags = other_tags[: args.feed_from_tags]
704
 
705
  # Prepare sentence pieces
706
  artist_tag = tag_by_category[TAG_ARTIST]
707
  if artist_tag:
708
+ artist_list = [str(tag).removeprefix('by ')
709
+ for *_, tag in artist_tag[:4]]
710
+ artist_txt = f"by {format_nl_list(artist_list)}"
711
  else:
712
+ artist_txt = ""
713
  character_tag = tag_by_category[TAG_CHARACTER]
714
  if character_tag:
715
+ tags = [tag for _, tag, *_ in character_tag[:4]]
716
+ character_txt = f" named {format_nl_list(tags)}"
717
  else:
718
+ character_txt = ""
719
  species_tag = tag_by_category[TAG_SPECIES]
720
  if species_tag:
721
+ species_txt = "of a" if len(character_tag) <= 1 and len(species_tag) <= 1 else "of"
722
+ species_txt += format_nl_list([tag for *_, tag in species_tag[:4]])
723
  else:
724
  if character_tag:
725
+ species_txt = (
726
+ " a character"
727
+ if len(character_tag) <= 1
728
+ else " characters"
729
+ )
730
  else:
731
+ species_txt = ""
732
  copyright_tag = tag_by_category[TAG_COPYRIGHT]
733
  if copyright_tag:
734
+ tags = [tag for _, tag, *_ in copyright_tag[:4]]
735
+ copyright_txt = f" from {format_nl_list(tags)}"
736
  else:
737
+ copyright_txt = ""
738
+
739
+ tag_string = ", ".join(tag for *_, tag in other_tags)
740
+ custom_prompt = (
741
+ f"Write a descriptive caption for this image {artist_txt}"
742
+ f"of {species_txt}"
743
+ f"{character_txt}"
744
+ f"{copyright_txt}"
745
+ f" in a formal tone. Use these tags to construct your caption: "
746
+ f"{tag_string}"
747
+ )
748
  return custom_prompt
749
 
750
+
751
  def find_tag_file(image_path):
752
  """
753
  Find the corresponding .txt file for the given image path.
754
  Handles cases where the image has a -(number) suffix.
755
  """
756
  base_name = image_path.stem
757
+ tag_file = image_path.with_suffix(".txt")
758
 
759
  if tag_file.exists():
760
  return tag_file
761
 
762
  # Handle -(number) suffix
763
+ match = re.match(r"(.+)-\d+$", base_name)
764
  if match:
765
  base_name = match.group(1)
766
+ tag_file = image_path.with_name(base_name).with_suffix(".txt")
767
  if tag_file.exists():
768
  return tag_file
769
 
770
  return None
771
 
772
+
773
  if __name__ == "__main__":
774
  main()