update joy and crawl UwU
Changed files: crawl/crawl (+13 -6), joy (+205 -132)
crawl/crawl
CHANGED

@@ -16,7 +16,7 @@ import platform
 from concurrent.futures import ThreadPoolExecutor, as_completed
 import time
 import argparse
-from urllib.parse import urljoin
+from urllib.parse import urljoin, urlparse
 import requests
 try:
     from crawl4ai import WebCrawler  # type: ignore

@@ -75,11 +75,18 @@ def download_image(session, image_url, save_dir, base_url):
         The base URL of the page being crawled.
     """
     try:
-        #
+        # Parse the base URL to get the scheme and netloc
+        parsed_base_url = urlparse(base_url)
+        base_image_url = (
+            f"{parsed_base_url.scheme}://{parsed_base_url.netloc}/"
+        )
+
+        # Ensure the URL has a scheme and is properly joined with
+        # the base image URL
+        if not re.match(r"^https?://", image_url):
+            image_url = urljoin(
+                base_image_url, image_url.lstrip("/")
+            )

         image_filename = os.path.basename(image_url).split("?")[0]
         sanitized_image_filename = sanitize_filename(image_filename)
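
The new block in download_image rebuilds the page's origin with urlparse and joins scheme-less image paths onto it before the filename is extracted. A rough standalone sketch of that behaviour (the helper name and the example URLs below are illustrative, not part of the commit):

import re
from urllib.parse import urljoin, urlparse

def resolve_image_url(image_url: str, base_url: str) -> str:
    # Rebuild the site origin (scheme://netloc/) from the page URL, then
    # anchor relative or scheme-less image paths to it.
    parsed = urlparse(base_url)
    root = f"{parsed.scheme}://{parsed.netloc}/"
    if not re.match(r"^https?://", image_url):
        image_url = urljoin(root, image_url.lstrip("/"))
    return image_url

# resolve_image_url("/img/cat.jpg?size=large", "https://example.com/gallery/page2")
# -> "https://example.com/img/cat.jpg?size=large"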

joy
CHANGED

@@ -18,7 +18,6 @@ import os
 import argparse
 import re
 import random
-from collections import Counter
 from pathlib import Path
 from PIL import Image
 import pillow_jxl

@@ -26,14 +25,14 @@ import torch
 import torchvision.transforms.functional as TVF
 from transformers import (
     AutoModel,
-    AutoProcessor,
     AutoTokenizer,
     AutoModelForCausalLM,
     PreTrainedTokenizer,
     PreTrainedTokenizerFast,
 )
 from torch import nn
-from e6db_reader import
+from e6db_reader import TagSetNormalizer, tag_category2id, tag_rank_to_freq
+from typing import List, Tuple, Dict

 CLIP_PATH = "google/siglip-so400m-patch14-384"
 MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B"

@@ -63,8 +62,7 @@ CAPTION_TYPE_MAP = {
         "Write a stable diffusion prompt for this image."
     ],
     ("training_prompt", "formal", False, True): [
-        "Write a stable diffusion prompt for this image within {word_count} "
-        "words."
+        "Write a stable diffusion prompt for this image within {word_count} " "words."
     ],
     ("training_prompt", "formal", True, False): [
         "Write a {length} stable diffusion prompt for this image."

@@ -82,6 +80,7 @@ CAPTION_TYPE_MAP = {

 HF_TOKEN = os.environ.get("HF_TOKEN", None)

+
 class ImageAdapter(nn.Module):
     """
     Custom image adapter module for processing CLIP vision outputs.

@@ -118,8 +117,10 @@ class ImageAdapter(nn.Module):
         self.activation = nn.GELU()
         self.linear2 = nn.Linear(output_features, output_features)
         self.ln1 = nn.Identity() if not ln1 else nn.LayerNorm(input_features)
-        self.pos_emb =
+        self.pos_emb = (
+            None
+            if not pos_emb
+            else nn.Parameter(torch.zeros(num_image_tokens, input_features))
         )

         self.other_tokens = nn.Embedding(3, output_features)

@@ -136,26 +137,29 @@ class ImageAdapter(nn.Module):
             torch.Tensor: Adapted image features.
         """
         if self.deep_extract:
-            x = torch.concat(
-                f"Expected {vision_outputs[-2].shape[-1] * 5}, got {x.shape[-1]}"
+            x = torch.concat(
+                (
+                    vision_outputs[-2],
+                    vision_outputs[3],
+                    vision_outputs[7],
+                    vision_outputs[13],
+                    vision_outputs[20],
+                ),
+                dim=-1,
             )
+            assert len(x.shape) == 3, f"Expected 3, got {len(x.shape)}"
+            assert (
+                x.shape[-1] == vision_outputs[-2].shape[-1] * 5
+            ), f"Expected {vision_outputs[-2].shape[-1] * 5}, got {x.shape[-1]}"
         else:
             x = vision_outputs[-2]

         x = self.ln1(x)

         if self.pos_emb is not None:
-            assert
-            )
+            assert (
+                x.shape[-2:] == self.pos_emb.shape
+            ), f"Expected {self.pos_emb.shape}, got {x.shape[-2:]}"
             x = x + self.pos_emb

         x = self.linear1(x)

@@ -167,9 +171,11 @@ class ImageAdapter(nn.Module):
                 x.shape[0], -1
             )
         )
-        assert other_tokens.shape == (
+        assert other_tokens.shape == (
+            x.shape[0],
+            2,
+            x.shape[2],
+        ), f"Expected {(x.shape[0], 2, x.shape[2])}, got {other_tokens.shape}"
         x = torch.cat((other_tokens[:, 0:1], x, other_tokens[:, 1:2]), dim=1)

         return x
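
With deep_extract enabled, the forward pass above now concatenates five hidden states from the vision tower along the feature dimension, so the adapter sees a feature width five times that of a single layer; the new asserts just spell out that contract. A minimal sketch of the same shape arithmetic (the tensor sizes here are made up for illustration):

import torch

# Pretend hidden states from a vision tower: batch 1, 729 patch tokens, width 1152.
hidden = [torch.randn(1, 729, 1152) for _ in range(25)]

x = torch.concat(
    (hidden[-2], hidden[3], hidden[7], hidden[13], hidden[20]), dim=-1
)
assert x.shape == (1, 729, 1152 * 5)  # token count unchanged, feature width times five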

@@ -185,6 +191,7 @@ class ImageAdapter(nn.Module):
             torch.tensor([2], device=self.other_tokens.weight.device)
         ).squeeze(0)

+
 class JoyCaptionModel:
     """
     A class for generating captions for images using CLIP, LLM, and custom image adapters.

@@ -221,8 +228,12 @@ class JoyCaptionModel:

         if (CHECKPOINT_PATH / "clip_model.pt").exists():
             print("Loading VLM's custom vision model")
-            checkpoint = torch.load(
+            checkpoint = torch.load(
+                CHECKPOINT_PATH / "clip_model.pt", map_location="cpu"
+            )
+            checkpoint = {
+                k.replace("_orig_mod.module.", ""): v for k, v in checkpoint.items()
+            }
             self.clip_model.load_state_dict(checkpoint)
             del checkpoint

@@ -240,15 +251,11 @@ class JoyCaptionModel:
         if (CHECKPOINT_PATH / "text_model").exists():
             print("Loading VLM's custom text model")
             self.text_model = AutoModelForCausalLM.from_pretrained(
-                CHECKPOINT_PATH / "text_model",
-                device_map=0,
-                torch_dtype=torch.bfloat16
+                CHECKPOINT_PATH / "text_model", device_map=0, torch_dtype=torch.bfloat16
             )
         else:
             self.text_model = AutoModelForCausalLM.from_pretrained(
-                MODEL_PATH,
-                device_map="auto",
-                torch_dtype=torch.bfloat16
+                MODEL_PATH, device_map="auto", torch_dtype=torch.bfloat16
             )

         self.text_model.eval()

@@ -260,7 +267,7 @@ class JoyCaptionModel:
             False,
             False,
             38,
-            False
+            False,
         )
         self.image_adapter.load_state_dict(
             torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location="cpu")

@@ -269,12 +276,14 @@ class JoyCaptionModel:
         self.image_adapter.to("cuda")

     @torch.no_grad()
-    def process_image(
+    def process_image(
+        self,
+        input_image: Image.Image,
+        caption_type: str,
+        caption_tone: str,
+        caption_length: str | int,
+        custom_prompt: str | None = None,
+    ) -> str:
         """
         Process an input image and generate a caption based on specified parameters.
         """
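
The widened process_image signature above matches how main() invokes it later in this diff; a call would look roughly like the following (the model object and image variable are placeholders, and the surrounding model-loading code is in parts of joy not shown here):

caption = joy_caption_model.process_image(
    input_image,          # a PIL.Image.Image
    "descriptive",        # caption_type
    "formal",             # caption_tone
    "any",                # caption_length (str or int)
    custom_prompt=None,   # or a prompt built by prompt_from_tags()
)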

@@ -283,14 +292,18 @@ class JoyCaptionModel:
         if custom_prompt is not None:
             prompt_str = custom_prompt
         else:
-            prompt_str = self._get_prompt_string(
+            prompt_str = self._get_prompt_string(
+                caption_type, caption_tone, caption_length
+            )
         print(f"Prompt: {prompt_str}")

         pixel_values = self._preprocess_image(input_image)
         prompt = self._tokenize_prompt(prompt_str)

         embedded_images = self._embed_image(pixel_values)
-        inputs_embeds, input_ids, attention_mask = self._construct_inputs(
+        inputs_embeds, input_ids, attention_mask = self._construct_inputs(
+            embedded_images, prompt
+        )

         generate_ids = self._generate_caption(inputs_embeds, input_ids, attention_mask)
         caption = self._decode_caption(generate_ids, input_ids)

@@ -313,7 +326,7 @@ class JoyCaptionModel:
             caption_type,
             caption_tone,
             isinstance(length, str),
-            isinstance(length, int)
+            isinstance(length, int),
         )
         if prompt_key not in CAPTION_TYPE_MAP:
             raise ValueError(f"Invalid caption type: {prompt_key}")

@@ -327,57 +340,73 @@ class JoyCaptionModel:
         image = input_image.resize((384, 384), Image.LANCZOS)
         pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
         pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
-        pixel_values = pixel_values.to(
+        pixel_values = pixel_values.to("cuda")
         return pixel_values

     def _tokenize_prompt(self, prompt_str):
         prompt = self.tokenizer.encode(
             prompt_str,
-            return_tensors=
+            return_tensors="pt",
             padding=False,
             truncation=False,
-            add_special_tokens=False
+            add_special_tokens=False,
         )
         return prompt

     def _embed_image(self, pixel_values):
-        with torch.amp.autocast_mode.autocast(
-            vision_outputs = self.clip_model(
+        with torch.amp.autocast_mode.autocast("cuda", enabled=True):
+            vision_outputs = self.clip_model(
+                pixel_values=pixel_values, output_hidden_states=True
+            )
             image_features = vision_outputs.hidden_states
             embedded_images = self.image_adapter(image_features)
-            embedded_images = embedded_images.to(
+            embedded_images = embedded_images.to("cuda")
         return embedded_images

     def _construct_inputs(self, embedded_images, prompt):
-        prompt_embeds = self.text_model.model.embed_tokens(prompt.to(
-        assert prompt_embeds.shape == (
+        prompt_embeds = self.text_model.model.embed_tokens(prompt.to("cuda"))
+        assert prompt_embeds.shape == (
+            1,
+            prompt.shape[1],
+            self.text_model.config.hidden_size,
+        ), (
             f"Prompt shape is {prompt_embeds.shape}, expected "
             f"{(1, prompt.shape[1], self.text_model.config.hidden_size)}"
         )

         embedded_bos = self.text_model.model.embed_tokens(
-            torch.tensor(
+            torch.tensor(
+                [[self.tokenizer.bos_token_id]],
+                device=self.text_model.device,
+                dtype=torch.int64,
+            )
+        )
+
+        eot_embed = (
+            self.image_adapter.get_eot_embedding()
+            .unsqueeze(0)
+            .to(dtype=self.text_model.dtype)
         )

+        inputs_embeds = torch.cat(
+            [
+                embedded_bos.expand(embedded_images.shape[0], -1, -1),
+                embedded_images.to(dtype=embedded_bos.dtype),
+                prompt_embeds.expand(embedded_images.shape[0], -1, -1),
+                eot_embed.expand(embedded_images.shape[0], -1, -1),
+            ],
+            dim=1,
         )

-            torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
-            prompt,
-            torch.tensor([[self.tokenizer.eos_token_id]], dtype=torch.long),
-        ], dim=1).to('cuda')
+        input_ids = torch.cat(
+            [
+                torch.tensor([[self.tokenizer.bos_token_id]], dtype=torch.long),
+                torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
+                prompt,
+                torch.tensor([[self.tokenizer.eos_token_id]], dtype=torch.long),
+            ],
+            dim=1,
+        ).to("cuda")
         attention_mask = torch.ones_like(input_ids)

         return inputs_embeds, input_ids, attention_mask
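
The rewritten _construct_inputs above lines the generation input up as [BOS] + image tokens + prompt + [EOS]: inputs_embeds carries the real embeddings (with the adapter's get_eot_embedding() in the last slot), while input_ids holds zeros under the image span and the tokenizer's EOS id at the end, so both tensors stay the same length. A shape-only sketch, with sizes that are illustrative rather than taken from the model:

import torch

n_image, n_prompt, hidden = 300, 12, 4096      # illustrative sizes only
seq_len = 1 + n_image + n_prompt + 1           # [BOS] + image + prompt + [EOS]

inputs_embeds = torch.zeros(1, seq_len, hidden)
input_ids = torch.cat(
    [
        torch.tensor([[1]], dtype=torch.long),         # stand-in BOS id
        torch.zeros((1, n_image), dtype=torch.long),   # zeros under the image span
        torch.zeros((1, n_prompt), dtype=torch.long),  # stand-in prompt ids
        torch.tensor([[2]], dtype=torch.long),         # stand-in EOS id
    ],
    dim=1,
)
attention_mask = torch.ones_like(input_ids)
assert input_ids.shape == (1, seq_len) and attention_mask.shape == (1, seq_len)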

@@ -389,21 +418,20 @@ class JoyCaptionModel:
             attention_mask=attention_mask,
             max_new_tokens=300,
             do_sample=True,
-            suppress_tokens=None
+            suppress_tokens=None,
         )
         return generate_ids

     def _decode_caption(self, generate_ids, input_ids):
-        generate_ids = generate_ids[:, input_ids.shape[1]:]
+        generate_ids = generate_ids[:, input_ids.shape[1] :]

-        if
+        if generate_ids[0][-1] == self.tokenizer.eos_token_id or generate_ids[0][
+            -1
+        ] == self.tokenizer.convert_tokens_to_ids("<|eot_id|>"):
             generate_ids = generate_ids[:, :-1]

         caption = self.tokenizer.batch_decode(
-            generate_ids,
-            skip_special_tokens=False,
-            clean_up_tokenization_spaces=False
+            generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
         )[0]
         return caption

@@ -413,53 +441,52 @@ def main():
     parser = argparse.ArgumentParser(
         description="Generate captions for images in a directory and save them as .caption files."
     )
-    parser.add_argument(
+    parser.add_argument(
+        "directory", type=str, help="Target directory containing images."
+    )
     parser.add_argument(
         "--caption_type",
         type=str,
         default="descriptive",
         choices=["descriptive", "training_prompt", "rng-tags", "custom"],
-        help="Type of caption to generate."
+        help="Type of caption to generate.",
     )
     parser.add_argument(
         "--caption_tone",
         type=str,
         default="formal",
         choices=["formal", "informal"],
-        help="Tone of the caption."
+        help="Tone of the caption.",
     )
     parser.add_argument(
-        "--caption_length",
-        type=str,
-        default="any",
-        help="Length of the caption."
+        "--caption_length", type=str, default="any", help="Length of the caption."
     )
     parser.add_argument(
         "--dont-strip-commas",
         action="store_true",
-        help="If set, commas will not be stripped from the generated captions."
+        help="If set, commas will not be stripped from the generated captions.",
     )
     parser.add_argument(
         "--custom_prompt",
         type=str,
-        help="Custom prompt for the captioner. Use with --caption_type custom."
+        help="Custom prompt for the captioner. Use with --caption_type custom.",
     )
     parser.add_argument(
-        action=
-        help=
+        "--add-commas-to-sentence-ends",
+        action="store_true",
+        help="Add commas after periods in sentences",
    )
     parser.add_argument(
+        "--feed-from-tags",
         type=int,
-        nargs=
+        nargs="?",
         const=-1,
-        help=
+        help="Use .txt files with the same base filename as the images as input to the captioner. Optionally specify the number of tags to use.",
     )
     parser.add_argument(
+        "--random-tags",
         type=int,
-        help=
+        help="Randomly select n number of tags. Only works if --feed-from-tags is enabled.",
     )

     args = parser.parse_args()

@@ -468,7 +495,7 @@ def main():
     if args.random_tags is not None and args.feed_from_tags is None:
         parser.error("--random-tags can only be used when --feed-from-tags is enabled")

-    print(
+    print("Loading e621 tag data")
     tagset_normalizer = make_tagset_normalizer()

     # Initialize and load models

@@ -484,7 +511,7 @@ def main():
     image_extensions = {".webp", ".png", ".jpeg", ".jpg", ".jxl"}
     for image_path in Path(args.directory).rglob("*"):
         if image_path.suffix.lower() in image_extensions:
-            caption_file = image_path.with_suffix(
+            caption_file = image_path.with_suffix(".caption")

             # Skip if the caption file already exists
             if caption_file.exists():

@@ -501,29 +528,28 @@ def main():
                 custom_prompt = prompt_from_tags(args, image_path, tagset_normalizer)

                 print(f"Custom prompt: {custom_prompt}")
-                continue

             caption = joy_caption_model.process_image(
                 input_image,
                 args.caption_type,
                 args.caption_tone,
                 args.caption_length,
-                custom_prompt=custom_prompt
+                custom_prompt=custom_prompt,
             )

             # Strip commas if the --dont-strip-commas flag is not set
             if not args.dont_strip_commas:
                 # Existing comma stripping logic
-                caption = re.sub(r
+                caption = re.sub(r",\s*([^\d])", r" \1", caption)

             # New feature: Add commas after periods if specified
             if args.add_commas_to_sentence_ends:
-                caption = re.sub(r
+                caption = re.sub(r"(\.)(\s+)([A-Z])", r"\1,\2\3", caption)

             print(f"Caption for {image_path}:\n\n{caption}\n\n")

             # Save the caption to a .caption file
-            with open(caption_file,
+            with open(caption_file, "w", encoding="utf-8") as f:
                 f.write(caption)
             print(f"Caption saved to {caption_file}")
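
Taken together, the argparse changes above make the image directory a positional argument and add the tag-feeding switches. A typical invocation after this commit might look like the following (paths and counts are placeholders):

python joy /path/to/images --caption_type descriptive --feed-from-tags 20
python joy /path/to/images --feed-from-tags --random-tags 10

The second form works because --feed-from-tags given without a value falls back to const=-1, which satisfies the later check that --random-tags is only allowed when --feed-from-tags is enabled.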

@@ -531,6 +557,7 @@ def main():
 RE_PARENS_SUFFIX = re.compile(r"_\([^)]+\)$")
 E6DB_DATA = Path(__file__).resolve().parent / "data"

+
 def make_tagset_normalizer():
     """
     Create a TagSetNormalizer for encoding/decoding tags to and from integers.

@@ -581,7 +608,6 @@ def make_tagset_normalizer():
     return tagset_normalizer.map_inputs(input_map, on_conflict="ignore")


-
 def format_nl_list(l):
     n = len(l)
     assert n > 0

@@ -589,36 +615,51 @@ def format_nl_list(l):
         return l[0]
     elif n == 2:
         return f"{l[0]} and {l[1]}"
-    else:
+    else:  # n > 2
         *head, last = l
-        return
+        return ", ".join(head) + ", and " + last

+
+TAG_SPECIES = tag_category2id["species"]
+TAG_CHARACTER = tag_category2id["character"]
+TAG_ARTIST = tag_category2id["artist"]
+TAG_COPYRIGHT = tag_category2id["copyright"]
+TAG_META = tag_category2id["meta"]
 TAG_FREQ_THRESH = 0

+
 def prompt_from_tags(args, image_path: Path, tagset_normalizer: TagSetNormalizer):
+    """
+    Generates a prompt from tags associated with the given image.
+
+    Args:
+        args: Additional arguments for the function.
+        image_path (Path): The path to the image file.
+        tagset_normalizer (TagSetNormalizer): An instance to normalize the tag set.
+
+    Returns:
+        None
+    """
     tag_file = find_tag_file(image_path)
     if tag_file is None:
         return None

-    with open(tag_file,
-        tags = f.read().lower().split(
+    with open(tag_file, "r", encoding="utf-8") as f:
+        tags = f.read().lower().split(",")

     tag_id_to_cat_id = tagset_normalizer.tag_normalizer.tag_categories
     encode = tagset_normalizer.tag_normalizer.encode

     # These lists contain tuples (freq, tag, tag_id)
-    tag_by_category
+    tag_by_category: Dict[int, List[Tuple[int, str, int]]] = {
+        cat: [] for cat in [TAG_ARTIST, TAG_CHARACTER, TAG_COPYRIGHT, TAG_SPECIES]
+    }
+    other_tags: List[Tuple[int, str, int]] = []
+    implied: set = set()
     for tag in tags:
         tag = tag.strip()
         # Encode the tag into a numerical id
-        tag_id = encode(tag.replace(
+        tag_id = encode(tag.replace(" ", "_"))
         if tag_id is None:
             other_tags.append((0, tag, None))
             implied.update(tagset_normalizer.implications_rej.get(tag_id, ()))

@@ -633,69 +674,101 @@ def prompt_from_tags(args, image_path: Path, tagset_normalizer: TagSetNormalizer):
         freq = tag_rank_to_freq(tag_id)
         if freq < TAG_FREQ_THRESH:
             continue
-        tag_by_category.get(cat_id, other_tags).append((freq, tag, tag_id))
+        tag_by_category.get(cat_id, other_tags).append((int(freq), tag, tag_id))

-    other_tags = sorted(
+    other_tags = sorted(
+        (int(freq), tag, tag_id)
+        for freq, tag, tag_id in other_tags
+        if tag_id not in implied
+    )
     for cat_id, cat_list in tag_by_category.items():
-        tag_by_category[cat_id] = sorted(
+        tag_by_category[cat_id] = sorted(
+            (int(freq), tag, tag_id)
+            for freq, tag, tag_id in cat_list
+            if tag_id not in implied
+        )

     if args.random_tags is not None:
         # Randomly select tags if --random-tags is specified
         num_tags = min(args.random_tags, len(other_tags))
-        other_tags = random.sample(
+        other_tags = random.sample(
+            [
+                (i, tag, tag_id)
+                for i, tag, tag_id in enumerate(tags[: round(args.random_tags * 1.5)])
+            ],
+            num_tags,
+        )
     elif args.feed_from_tags > 0:
         # Use specified number of tags if --feed-from-tags has a positive value
-        other_tags = other_tags[:args.feed_from_tags]
+        other_tags = other_tags[: args.feed_from_tags]

     # Prepare sentence pieces
     artist_tag = tag_by_category[TAG_ARTIST]
     if artist_tag:
+        artist_list = [str(tag).removeprefix('by ')
+                       for *_, tag in artist_tag[:4]]
+        artist_txt = f"by {format_nl_list(artist_list)}"
     else:
-        artist_txt =
+        artist_txt = ""
     character_tag = tag_by_category[TAG_CHARACTER]
     if character_tag:
+        tags = [tag for _, tag, *_ in character_tag[:4]]
+        character_txt = f" named {format_nl_list(tags)}"
     else:
-        character_txt =
+        character_txt = ""
     species_tag = tag_by_category[TAG_SPECIES]
     if species_tag:
-        species_txt =
+        species_txt = "of a" if len(character_tag) <= 1 and len(species_tag) <= 1 else "of"
+        species_txt += format_nl_list([tag for *_, tag in species_tag[:4]])
     else:
         if character_tag:
-            species_txt =
+            species_txt = (
+                " a character"
+                if len(character_tag) <= 1
+                else " characters"
+            )
         else:
-            species_txt =
+            species_txt = ""
     copyright_tag = tag_by_category[TAG_COPYRIGHT]
     if copyright_tag:
+        tags = [tag for _, tag, *_ in copyright_tag[:4]]
+        copyright_txt = f" from {format_nl_list(tags)}"
     else:
-        copyright_txt =
-    tag_string =
-    custom_prompt =
+        copyright_txt = ""
+
+    tag_string = ", ".join(tag for *_, tag in other_tags)
+    custom_prompt = (
+        f"Write a descriptive caption for this image {artist_txt}"
+        f"of {species_txt}"
+        f"{character_txt}"
+        f"{copyright_txt}"
+        f" in a formal tone. Use these tags to construct your caption: "
+        f"{tag_string}"
+    )
     return custom_prompt

+
 def find_tag_file(image_path):
     """
     Find the corresponding .txt file for the given image path.
     Handles cases where the image has a -(number) suffix.
     """
     base_name = image_path.stem
-    tag_file = image_path.with_suffix(
+    tag_file = image_path.with_suffix(".txt")

     if tag_file.exists():
         return tag_file

     # Handle -(number) suffix
-    match = re.match(r
+    match = re.match(r"(.+)-\d+$", base_name)
     if match:
         base_name = match.group(1)
-        tag_file = image_path.with_name(base_name).with_suffix(
+        tag_file = image_path.with_name(base_name).with_suffix(".txt")
         if tag_file.exists():
             return tag_file

     return None

+
 if __name__ == "__main__":
     main()
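
The reworked format_nl_list above joins longer lists with an Oxford comma, which is what the artist, character, species, and copyright phrases in prompt_from_tags rely on; for example:

format_nl_list(["fox"])                    # -> "fox"
format_nl_list(["fox", "wolf"])            # -> "fox and wolf"
format_nl_list(["fox", "wolf", "dragon"])  # -> "fox, wolf, and dragon"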
|