|
|
|
|
|
|
|
""" |
|
JoyCaption Alpha One

This module provides functionality for generating captions for images using a
combination of CLIP, LLM, and custom image adapters. It supports various
caption types, tones, and lengths.

The main components include:
- Loading and initializing models (CLIP, LLM, image adapter)
- Processing images and generating captions
- Command-line interface for batch processing images in a directory
|
""" |
|
|
|
import os |
|
import argparse |
|
import re |
|
import random |
|
from pathlib import Path |
|
from typing import List, Tuple, Dict |
|
from PIL import Image |
|
import pillow_jxl |
|
import torch |
|
import torchvision.transforms.functional as TVF |
|
from transformers import ( |
|
AutoModel, |
|
AutoTokenizer, |
|
AutoModelForCausalLM, |
|
PreTrainedTokenizer, |
|
PreTrainedTokenizerFast, |
|
) |
|
from torch import nn |
|
from e6db_reader import TagSetNormalizer, tag_category2id, tag_rank_to_freq |
|
|
|
# Model identifiers handed to ``from_pretrained`` when the models are loaded.
CLIP_PATH = "google/siglip-so400m-patch14-384"
MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B"

# Checkpoint directory located next to this source file.
CHECKPOINT_PATH = Path(__file__).resolve().with_name("9em124t2-499968")

# Prompt templates keyed by
#   (caption_type, caption_tone, length_is_str, length_is_int).
# Templates may contain the ``{length}`` and ``{word_count}`` placeholders,
# which are filled in with ``str.format`` when the prompt is built.
CAPTION_TYPE_MAP = {
    # --- descriptive, formal tone ---
    ("descriptive", "formal", False, False): [
        "Write a descriptive caption for this image in a formal tone.",
    ],
    ("descriptive", "formal", False, True): [
        (
            "Write a descriptive caption for this image "
            "in a formal tone within {word_count} words."
        ),
    ],
    ("descriptive", "formal", True, False): [
        "Write a {length} descriptive caption for this image in a formal tone.",
    ],
    # --- descriptive, casual tone ---
    ("descriptive", "informal", False, False): [
        "Write a descriptive caption for this image in a casual tone.",
    ],
    ("descriptive", "informal", False, True): [
        (
            "Write a descriptive caption for this image "
            "in a casual tone within {word_count} words."
        ),
    ],
    ("descriptive", "informal", True, False): [
        "Write a {length} descriptive caption for this image in a casual tone.",
    ],
    # --- stable diffusion training prompts ---
    ("training_prompt", "formal", False, False): [
        "Write a stable diffusion prompt for this image.",
    ],
    ("training_prompt", "formal", False, True): [
        (
            "Write a stable diffusion prompt for this image "
            "within {word_count} words."
        ),
    ],
    ("training_prompt", "formal", True, False): [
        "Write a {length} stable diffusion prompt for this image.",
    ],
    # --- Booru tag lists ---
    ("rng-tags", "formal", False, False): [
        "Write a list of Booru tags for this image.",
    ],
    ("rng-tags", "formal", False, True): [
        "Write a list of Booru tags for this image within {word_count} words.",
    ],
    ("rng-tags", "formal", True, False): [
        "Write a {length} list of Booru tags for this image.",
    ],
}

# Optional Hugging Face access token taken from the environment
# (``None`` when the variable is not set).
HF_TOKEN = os.getenv("HF_TOKEN")
|
|
|
|
|
class ImageAdapter(nn.Module):
    """
    Project vision-model hidden states into the text model's embedding space.

    The adapter picks one hidden state (or a concatenation of five when
    ``deep_extract`` is on), optionally applies layer normalization and a
    learned positional embedding, runs a two-layer MLP with GELU, and wraps
    the result between two learned marker embeddings. A third learned
    embedding is exposed through :meth:`get_eot_embedding`.

    Args:
        input_features (int):
            Feature size of a single vision hidden state.
        output_features (int):
            Feature size produced by the adapter.
        ln1 (bool):
            Apply ``nn.LayerNorm`` to the selected features when True,
            otherwise pass them through unchanged.
        pos_emb (bool):
            Add a learned ``(num_image_tokens, features)`` embedding
            (initialized to zeros) when True.
        num_image_tokens (int):
            Number of image tokens; only used to size the positional
            embedding.
        deep_extract (bool):
            Concatenate five hidden states along the feature axis instead of
            using only the second-to-last one.
    """

    def __init__(
        self,
        input_features: int,
        output_features: int,
        ln1: bool,
        pos_emb: bool,
        num_image_tokens: int,
        deep_extract: bool,
    ):
        super().__init__()
        self.deep_extract = deep_extract

        # Deep extraction stacks five hidden states on the feature axis, so
        # everything that consumes the raw features is five times wider.
        feature_dim = input_features * 5 if deep_extract else input_features

        # Attribute names are the state-dict keys; they must stay as they are
        # for existing checkpoints to load.
        self.linear1 = nn.Linear(feature_dim, output_features)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(output_features, output_features)
        self.ln1 = nn.LayerNorm(feature_dim) if ln1 else nn.Identity()

        if pos_emb:
            self.pos_emb = nn.Parameter(
                torch.zeros(num_image_tokens, feature_dim)
            )
        else:
            self.pos_emb = None

        # Rows 0 and 1 bracket the image tokens in forward(); row 2 is the
        # embedding returned by get_eot_embedding().
        self.other_tokens = nn.Embedding(3, output_features)
        self.other_tokens.weight.data.normal_(mean=0.0, std=0.02)

    def forward(self, vision_outputs: torch.Tensor):
        """
        Adapt vision hidden states for the text model.

        Args:
            vision_outputs:
                Indexable collection of hidden states; index ``-2`` is always
                read, and indices 3, 7, 13 and 20 as well when
                ``deep_extract`` is enabled.

        Returns:
            torch.Tensor: The adapted features with one marker embedding
            prepended and one appended along dimension 1.
        """
        penultimate = vision_outputs[-2]

        if self.deep_extract:
            selected = (
                penultimate,
                vision_outputs[3],
                vision_outputs[7],
                vision_outputs[13],
                vision_outputs[20],
            )
            x = torch.concat(selected, dim=-1)
            assert len(x.shape) == 3, f"Expected 3, got {len(x.shape)}"
            expected_shape = penultimate.shape[-1] * 5
            assert (
                x.shape[-1] == expected_shape
            ), f"Expected {expected_shape}, got {x.shape[-1]}"
        else:
            x = penultimate

        x = self.ln1(x)

        if self.pos_emb is not None:
            assert (
                x.shape[-2:] == self.pos_emb.shape
            ), f"Expected {self.pos_emb.shape}, got {x.shape[-2:]}"
            x = x + self.pos_emb

        # Two-layer MLP projection.
        x = self.linear2(self.activation(self.linear1(x)))

        # Look up marker rows 0 and 1 once per batch element.
        batch_size = x.shape[0]
        marker_ids = torch.tensor(
            [0, 1], device=self.other_tokens.weight.device
        ).expand(batch_size, -1)
        other_tokens = self.other_tokens(marker_ids)

        expected_markers = (batch_size, 2, x.shape[2])
        assert (
            other_tokens.shape == expected_markers
        ), f"Expected {expected_markers}, got {other_tokens.shape}"

        leading, trailing = other_tokens[:, 0:1], other_tokens[:, 1:2]
        return torch.cat((leading, x, trailing), dim=1)

    def get_eot_embedding(self):
        """
        Return the learned end-of-text embedding.

        Returns:
            torch.Tensor: Row 2 of ``other_tokens`` as a 1-D tensor of size
            ``output_features``.
        """
        eot_index = torch.tensor(
            [2], device=self.other_tokens.weight.device
        )
        return self.other_tokens(eot_index).squeeze(0)
|
|
|
|
|
class JoyCaptionModel: |
|
""" |
|
A class for generating captions for images using CLIP, LLM, |
|
and custom image adapters. |
|
|
|
This class encapsulates the functionality to load and initialize |
|
various models (CLIP, LLM, image adapter) and use them to process |
|
images and generate captions. |
|
|
|
It supports different caption types, tones, and lengths. |
|
|
|
Attributes: |
|
clip_model: The CLIP vision model for processing images. |
|
text_model: The language model for generating captions. |
|
image_adapter: Custom adapter for processing CLIP vision outputs. |
|
tokenizer: Tokenizer for the language model. |
|
|
|
Methods: |
|
load_models(): Load and initialize all required models. |
|
process_image(input_image, caption_type, caption_tone, caption_length): |
|
Process an input image and generate a caption |
|
based on specified parameters. |
|
""" |
|
|
|
def __init__(self): |
|
self.clip_model = None |
|
self.text_model = None |
|
self.image_adapter = None |
|
self.tokenizer = None |
|
|
|
def load_models(self): |
|
""" |
|
Load and initialize all required models (CLIP, LLM, image adapter). |
|
""" |
|
print("Loading CLIP") |
|
self.clip_model = AutoModel.from_pretrained(CLIP_PATH) |
|
self.clip_model = self.clip_model.vision_model |
|
|
|
if (CHECKPOINT_PATH / "clip_model.pt").exists(): |
|
print("Loading VLM's custom vision model") |
|
checkpoint = torch.load( |
|
CHECKPOINT_PATH / "clip_model.pt", map_location="cpu" |
|
) |
|
checkpoint = { |
|
k.replace("_orig_mod.module.", ""): v |
|
for k, v in checkpoint.items() |
|
} |
|
self.clip_model.load_state_dict(checkpoint) |
|
del checkpoint |
|
|
|
self.clip_model.eval() |
|
self.clip_model.requires_grad_(False) |
|
self.clip_model.to("cuda") |
|
|
|
print("Loading tokenizer") |
|
self.tokenizer = AutoTokenizer.from_pretrained( |
|
MODEL_PATH, use_fast=False |
|
) |
|
assert isinstance( |
|
self.tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast) |
|
) |
|
|
|
print("Loading LLM") |
|
if (CHECKPOINT_PATH / "text_model").exists(): |
|
print("Loading VLM's custom text model") |
|
self.text_model = AutoModelForCausalLM.from_pretrained( |
|
CHECKPOINT_PATH / "text_model", |
|
device_map=0, |
|
torch_dtype=torch.bfloat16 |
|
) |
|
else: |
|
self.text_model = AutoModelForCausalLM.from_pretrained( |
|
MODEL_PATH, device_map="auto", torch_dtype=torch.bfloat16 |
|
) |
|
|
|
self.text_model.eval() |
|
|
|
print("Loading image adapter") |
|
self.image_adapter = ImageAdapter( |
|
self.clip_model.config.hidden_size, |
|
self.text_model.config.hidden_size, |
|
False, |
|
False, |
|
38, |
|
False, |
|
) |
|
self.image_adapter.load_state_dict( |
|
torch.load( |
|
CHECKPOINT_PATH / "image_adapter.pt", |
|
map_location="cpu" |
|
) |
|
) |
|
self.image_adapter.eval() |
|
self.image_adapter.to("cuda") |
|
|
|
@torch.no_grad() |
|
def process_image( |
|
self, |
|
input_image: Image.Image, |
|
caption_type: str, |
|
caption_tone: str, |
|
caption_length: str | int, |
|
custom_prompt: str | None = None, |
|
) -> str: |
|
""" |
|
Process an input image and generate a caption based on specified |
|
parameters. |
|
""" |
|
torch.cuda.empty_cache() |
|
|
|
if custom_prompt is not None: |
|
prompt_str = custom_prompt |
|
else: |
|
prompt_str = self._get_prompt_string( |
|
caption_type, caption_tone, caption_length |
|
) |
|
print(f"Prompt: {prompt_str}") |
|
|
|
pixel_values = self._preprocess_image(input_image) |
|
prompt = self._tokenize_prompt(prompt_str) |
|
|
|
embedded_images = self._embed_image(pixel_values) |
|
inputs_embeds, input_ids, attention_mask = self._construct_inputs( |
|
embedded_images, prompt |
|
) |
|
|
|
generate_ids = self._generate_caption(inputs_embeds, |
|
input_ids, |
|
attention_mask) |
|
caption = self._decode_caption(generate_ids, input_ids) |
|
|
|
return caption.strip() |
|
|
|
def generate_valid_caption( |
|
self, |
|
input_image: Image.Image, |
|
caption_type: str, |
|
caption_tone: str, |
|
caption_length: str | int, |
|
custom_prompt: str | None = None, |
|
) -> str: |
|
""" |
|
Generate a valid caption, retrying if the caption contains only special |
|
characters or does not end with a period, exclamation mark, or |
|
question mark. |
|
""" |
|
while True: |
|
caption = self.process_image( |
|
input_image, caption_type, caption_tone, |
|
caption_length, custom_prompt |
|
) |
|
|
|
|
|
|
|
|
|
if re.search(r'\w', caption) and caption[-1] in {'.', '!', '?'}: |
|
return caption |
|
print("Generated caption is invalid. Retrying...") |
|
|
|
def _get_prompt_string(self, caption_type, caption_tone, caption_length): |
|
length = None if caption_length == "any" else caption_length |
|
|
|
if isinstance(length, str): |
|
try: |
|
length = int(length) |
|
except ValueError: |
|
pass |
|
|
|
if caption_type in {"rng-tags", "training_prompt"}: |
|
caption_tone = "formal" |
|
|
|
prompt_key = ( |
|
caption_type, |
|
caption_tone, |
|
isinstance(length, str), |
|
isinstance(length, int), |
|
) |
|
if prompt_key not in CAPTION_TYPE_MAP: |
|
raise ValueError(f"Invalid caption type: {prompt_key}") |
|
|
|
prompt_str = CAPTION_TYPE_MAP[prompt_key][0].format( |
|
length=length, word_count=length |
|
) |
|
return prompt_str |
|
|
|
def _preprocess_image(self, input_image): |
|
image = input_image.resize((384, 384), Image.LANCZOS) |
|
pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0 |
|
pixel_values = TVF.normalize(pixel_values, [0.5], [0.5]) |
|
pixel_values = pixel_values.to("cuda") |
|
return pixel_values |
|
|
|
def _tokenize_prompt(self, prompt_str): |
|
prompt = self.tokenizer.encode( |
|
prompt_str, |
|
return_tensors="pt", |
|
padding=False, |
|
truncation=False, |
|
add_special_tokens=False, |
|
) |
|
return prompt |
|
|
|
def _embed_image(self, pixel_values): |
|
with torch.amp.autocast_mode.autocast("cuda", enabled=True): |
|
vision_outputs = self.clip_model( |
|
pixel_values=pixel_values, output_hidden_states=True |
|
) |
|
image_features = vision_outputs.hidden_states |
|
embedded_images = self.image_adapter(image_features) |
|
embedded_images = embedded_images.to("cuda") |
|
return embedded_images |
|
|
|
def _construct_inputs(self, embedded_images, prompt): |
|
prompt_embeds = self.text_model.model.embed_tokens(prompt.to("cuda")) |
|
assert prompt_embeds.shape == ( |
|
1, |
|
prompt.shape[1], |
|
self.text_model.config.hidden_size, |
|
), ( |
|
f"Prompt shape is {prompt_embeds.shape}, expected " |
|
f"{(1, prompt.shape[1], self.text_model.config.hidden_size)}" |
|
) |
|
|
|
embedded_bos = self.text_model.model.embed_tokens( |
|
torch.tensor( |
|
[[self.tokenizer.bos_token_id]], |
|
device=self.text_model.device, |
|
dtype=torch.int64, |
|
) |
|
) |
|
|
|
eot_embed = ( |
|
self.image_adapter.get_eot_embedding() |
|
.unsqueeze(0) |
|
.to(dtype=self.text_model.dtype) |
|
) |
|
|
|
inputs_embeds = torch.cat( |
|
[ |
|
embedded_bos.expand(embedded_images.shape[0], -1, -1), |
|
embedded_images.to(dtype=embedded_bos.dtype), |
|
prompt_embeds.expand(embedded_images.shape[0], -1, -1), |
|
eot_embed.expand(embedded_images.shape[0], -1, -1), |
|
], |
|
dim=1, |
|
) |
|
|
|
input_ids = torch.cat( |
|
[ |
|
torch.tensor( |
|
[[self.tokenizer.bos_token_id]], dtype=torch.long |
|
), |
|
torch.zeros( |
|
(1, embedded_images.shape[1]), dtype=torch.long |
|
), |
|
prompt, |
|
torch.tensor( |
|
[[self.tokenizer.eos_token_id]], dtype=torch.long |
|
), |
|
], |
|
dim=1, |
|
).to("cuda") |
|
attention_mask = torch.ones_like(input_ids) |
|
|
|
return inputs_embeds, input_ids, attention_mask |
|
|
|
def _generate_caption(self, inputs_embeds, input_ids, attention_mask): |
|
generate_ids = self.text_model.generate( |
|
input_ids, |
|
inputs_embeds=inputs_embeds, |
|
attention_mask=attention_mask, |
|
max_new_tokens=300, |
|
do_sample=True, |
|
suppress_tokens=None, |
|
) |
|
return generate_ids |
|
|
|
def _decode_caption(self, generate_ids, input_ids): |
|
generate_ids = generate_ids[:, input_ids.shape[1]:] |
|
|
|
if (generate_ids[0][-1] == self.tokenizer.eos_token_id or |
|
generate_ids[0][-1] == self.tokenizer.convert_tokens_to_ids( |
|
"<|eot_id|>")): |
|
generate_ids = generate_ids[:, :-1] |
|
|
|
caption = self.tokenizer.batch_decode( |
|
generate_ids, |
|
skip_special_tokens=False, |
|
clean_up_tokenization_spaces=False |
|
)[0] |
|
return caption |
|
|
|
|
|
def main():
    """
    Generate captions for images in a directory
    and save them as .caption files.

    The target directory is searched recursively for .webp, .png, .jpeg,
    .jpg and .jxl files. Images that already have a ``.caption`` file next
    to them are skipped. Argument combinations are validated before any
    model is loaded, so an invalid invocation exits immediately.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Generate captions for images in a directory and save them as "
            ".caption files."
        )
    )
    parser.add_argument(
        "directory", type=str, help="Target directory containing images."
    )
    parser.add_argument(
        "--caption_type",
        type=str,
        default="descriptive",
        choices=["descriptive", "training_prompt", "rng-tags", "custom"],
        help="Type of caption to generate.",
    )
    parser.add_argument(
        "--caption_tone",
        type=str,
        default="formal",
        choices=["formal", "informal"],
        help="Tone of the caption.",
    )
    parser.add_argument(
        "--caption_length",
        type=str,
        default="any",
        help="Length of the caption."
    )
    parser.add_argument(
        "--dont-strip-commas",
        action="store_true",
        help=(
            "If set, commas will not be stripped from the generated captions."
        ),
    )
    parser.add_argument(
        "--custom_prompt",
        type=str,
        help=(
            "Custom prompt for the captioner. "
            "Use with --caption_type custom."
        ),
    )
    parser.add_argument(
        "--add-commas-to-sentence-ends",
        action="store_true",
        help="Add commas after periods in sentences",
    )
    parser.add_argument(
        "--feed-from-tags",
        type=int,
        nargs="?",
        const=-1,
        help=(
            "Use .txt files with the same base filename "
            "as the images as input to the captioner. "
            "Optionally specify the number of tags to use."
        ),
    )
    parser.add_argument(
        "--random-tags",
        type=int,
        help=(
            "Randomly select n number of tags. "
            "Only works if --feed-from-tags is enabled."
        ),
    )

    args = parser.parse_args()

    # Validate every argument combination before the expensive model load;
    # parser.error() prints the message and exits.
    if args.random_tags is not None and args.feed_from_tags is None:
        parser.error(
            "--random-tags can only be used when --feed-from-tags is enabled"
        )
    if args.caption_type == "custom" and not args.custom_prompt:
        parser.error(
            "--custom_prompt is required when using --caption_type custom"
        )
    elif args.caption_type != "custom" and args.custom_prompt:
        parser.error(
            "--custom_prompt can only be used with --caption_type custom"
        )

    # The tag database is only consulted when prompts are built from tags.
    tagset_normalizer = None
    if args.feed_from_tags is not None:
        print("Loading e621 tag data")
        tagset_normalizer = make_tagset_normalizer()

    # Initialize and load models
    joy_caption_model = JoyCaptionModel()
    joy_caption_model.load_models()

    image_extensions = {".webp", ".png", ".jpeg", ".jpg", ".jxl"}
    for image_path in Path(args.directory).rglob("*"):
        if image_path.suffix.lower() not in image_extensions:
            continue

        caption_file = image_path.with_suffix(".caption")

        # Skip if the caption file already exists
        if caption_file.exists():
            print(f"Skipping {image_path}: Caption file already exists.")
            continue

        # convert() returns an independent in-memory copy, so the underlying
        # file can be closed as soon as the conversion is done.
        with Image.open(image_path) as source_image:
            input_image = source_image.convert("RGB")

        # An explicit custom prompt wins; otherwise optionally build one
        # from the image's tag file (None when no tag file is found).
        custom_prompt = None
        if args.caption_type == "custom":
            custom_prompt = args.custom_prompt
        elif args.feed_from_tags is not None:
            custom_prompt = prompt_from_tags(
                args, image_path, tagset_normalizer
            )

        print(f"Custom prompt: {custom_prompt}")

        caption = joy_caption_model.generate_valid_caption(
            input_image,
            args.caption_type,
            args.caption_tone,
            args.caption_length,
            custom_prompt=custom_prompt,
        )

        # Strip commas if the --dont-strip-commas flag is not set
        if not args.dont_strip_commas:
            # Replace a comma (and following whitespace) with a single space
            # unless the next character is a digit.
            caption = re.sub(r",\s*([^\d])", r" \1", caption)

        # Add commas after periods if specified
        if args.add_commas_to_sentence_ends:
            caption = re.sub(r"(\.)(\s+)([A-Z])", r"\1,\2\3", caption)

        # Keep the caption on a single line
        caption = caption.replace("\n", " ")

        print(f"Caption for {image_path}:\n\n{caption}\n\n")

        # Save the caption next to the image
        with open(caption_file, "w", encoding="utf-8") as f:
            f.write(caption)
        print(f"Caption saved to {caption_file}")
|
|
|
|
|
# Directory holding the tag database, located next to this source file.
E6DB_DATA = Path(__file__).resolve().with_name("data")

# Matches a trailing parenthesized qualifier such as "_(artist)".
RE_PARENS_SUFFIX = re.compile(r"_\([^)]+\)$")
|
|
|
|
|
def make_tagset_normalizer():
    """
    Build a TagSetNormalizer from ``E6DB_DATA`` with extra input spellings.

    For every known tag an ``input_map`` generator proposes alternative
    spellings: the tag without its parenthesized suffix, artist tags with and
    without the ``by_`` prefix / ``_(artist)`` suffix, category-suffixed
    variants for character, lore, species and copyright tags, and a variant
    with ``:`` replaced by ``_``. The spellings are registered through
    ``map_inputs`` with ``on_conflict="ignore"``.

    Returns:
        The value returned by ``TagSetNormalizer.map_inputs``.
    """
    normalizer = TagSetNormalizer(E6DB_DATA)

    categories = normalizer.tag_normalizer.tag_categories
    artist_cat = tag_category2id["artist"]
    # Category id -> suffix to append, e.g. "_(species)".
    suffix_by_cat = {
        tag_category2id[name]: f"_({name})"
        for name in ("character", "lore", "species", "copyright")
    }

    def input_map(tag, tid):
        """Yield alternative spellings for ``tag`` (id ``tid`` or None)."""
        # Bare form of the tag, without a trailing "_(...)" qualifier.
        stem = RE_PARENS_SUFFIX.sub("", tag)
        suffixed = stem != tag
        if suffixed:
            yield stem

        category = -1 if tid is None else categories[tid]
        if category == artist_cat:
            # Offer the opposite "by_" form, plus its "_(artist)" variant
            # when the tag carried no suffix of its own.
            bare_artist = stem.removeprefix("by_")
            if bare_artist == stem:
                yield f"by_{bare_artist}"
                if not suffixed:
                    yield f"by_{bare_artist}_(artist)"
            else:
                yield bare_artist
                if not suffixed:
                    yield f"{bare_artist}_(artist)"
        elif not suffixed:
            # Offer the category-qualified spelling where one is defined.
            category_suffix = suffix_by_cat.get(category)
            if category_suffix is not None:
                yield f"{stem}{category_suffix}"

        # Colon-free spelling of tags that contain a colon.
        if ":" in tag:
            yield tag.replace(":", "_")

    return normalizer.map_inputs(input_map, on_conflict="ignore")
|
|
|
|
|
def format_nl_list(word_list):
    """
    Join words into an English enumeration.

    One word is returned unchanged, two are joined with "and", three or more
    are comma-separated with ", and" before the last one. The list must not
    be empty.
    """
    count = len(word_list)
    assert count > 0

    if count >= 3:
        leading = ", ".join(word_list[:-1])
        return leading + ", and " + word_list[-1]
    if count == 2:
        first, second = word_list
        return f"{first} and {second}"
    return word_list[0]
|
|
|
|
|
# Category ids for the tag categories that get special treatment when a
# prompt is built from tags.
(
    TAG_SPECIES,
    TAG_CHARACTER,
    TAG_ARTIST,
    TAG_COPYRIGHT,
    TAG_META,
) = (
    tag_category2id[category]
    for category in ("species", "character", "artist", "copyright", "meta")
)

# Tags whose frequency is below this value are dropped.
TAG_FREQ_THRESH = 0
|
|
|
|
|
def prompt_from_tags(args, image_path: Path,
                     tagset_normalizer: TagSetNormalizer):
    """
    Build a captioning prompt from the tag file that accompanies an image.

    The tag file (found with ``find_tag_file``) is read as a comma-separated,
    lower-cased list. Known tags are sorted into artist, character,
    copyright and species buckets, meta tags are skipped, and tags implied
    by other tags are dropped. Up to four tags of each bucket are woven into
    the instruction sentence; the remaining tags are appended as a list.

    Args:
        args: Parsed command-line namespace. ``random_tags`` and
            ``feed_from_tags`` are read; ``feed_from_tags`` must not be None
            when ``random_tags`` is None.
        image_path (Path):
            The path to the image file.
        tagset_normalizer (TagSetNormalizer):
            Normalizer used to encode tags and look up categories and
            implications.

    Returns:
        str | None: The prompt, or None when the image has no tag file.
    """
    tag_file = find_tag_file(image_path)
    if tag_file is None:
        return None

    with open(tag_file, "r", encoding="utf-8") as f:
        raw_tags = f.read().lower().split(",")
    # Strip once up front and drop blank entries (trailing commas, empty
    # files) so neither the categorized nor the randomly sampled tags carry
    # stray whitespace or empty strings into the prompt.
    tags = [tag for tag in map(str.strip, raw_tags) if tag]

    tag_id_to_cat_id = tagset_normalizer.tag_normalizer.tag_categories
    encode = tagset_normalizer.tag_normalizer.encode

    # Entries are (frequency, tag text, tag id) tuples.
    tag_by_category: Dict[int, List[Tuple[int, str, int]]] = {
        cat: []
        for cat in [TAG_ARTIST, TAG_CHARACTER, TAG_COPYRIGHT, TAG_SPECIES]
    }
    other_tags: List[Tuple[int, str, int]] = []
    implied: set = set()
    for tag in tags:
        # Encode the tag into a numerical id
        tag_id = encode(tag.replace(" ", "_"))
        if tag_id is None:
            # Unknown tags are kept verbatim with frequency and id 0.
            other_tags.append((0, tag, 0))
            # NOTE(review): this always looks up key 0 regardless of the
            # tag; kept as-is, confirm the intended key with e6db_reader.
            implied.update(tagset_normalizer.implications_rej.get(0, ()))
            continue

        cat_id = tag_id_to_cat_id[tag_id]
        # Skip meta tags
        if cat_id == TAG_META:
            continue
        implied.update(tagset_normalizer.implications.get(tag_id, ()))

        freq = tag_rank_to_freq(tag_id)
        if freq < TAG_FREQ_THRESH:
            continue
        # Tags outside the four tracked categories land in other_tags.
        tag_by_category.get(cat_id, other_tags).append(
            (int(freq), tag, tag_id)
        )

    # Drop implied tags and sort each group (ascending tuple order).
    other_tags = sorted(
        (int(freq), tag, tag_id)
        for freq, tag, tag_id in other_tags
        if tag_id not in implied
    )

    for cat_id, cat_list in tag_by_category.items():
        tag_by_category[cat_id] = sorted(
            (int(freq), tag, tag_id)
            for freq, tag, tag_id in cat_list
            if tag_id not in implied
        )

    if args.random_tags is not None:
        # Sample from the first 1.5 * n tags of the file, never asking for
        # more than the number of remaining general tags and never for a
        # negative amount (random.sample would raise).
        num_tags = max(0, min(args.random_tags, len(other_tags)))
        candidates = [
            (i, tag, 0)
            for i, tag in enumerate(tags[: round(args.random_tags * 1.5)])
        ]
        other_tags = random.sample(candidates, num_tags)
    elif args.feed_from_tags > 0:
        # Use the specified number of tags if --feed-from-tags has a value
        other_tags = other_tags[: args.feed_from_tags]

    # Prepare sentence pieces
    artist_tag = tag_by_category[TAG_ARTIST]
    if artist_tag:
        artist_list = [str(tp[1]).removeprefix('by ')
                       for tp in artist_tag[:4]]
        artist_txt = f"by {format_nl_list(artist_list)}"
    else:
        artist_txt = ""

    character_tag = tag_by_category[TAG_CHARACTER]
    if character_tag:
        character_names = [tag for _, tag, _ in character_tag[:4]]
        character_txt = f"named {format_nl_list(character_names)}"
    else:
        character_txt = ""

    species_tag = tag_by_category[TAG_SPECIES]
    if species_tag:
        species_txt = (
            "of a "
            if len(character_tag) <= 1 and len(species_tag) <= 1
            else "of "
        )
        species_txt += format_nl_list([tp[1] for tp in species_tag[:4]])
    elif character_tag:
        # No species known: fall back to a generic noun. The "of" keeps the
        # sentence grammatical ("... this image of a character named X").
        species_txt = (
            "of a character" if len(character_tag) <= 1 else "of characters"
        )
    else:
        species_txt = ""

    copyright_tag = tag_by_category[TAG_COPYRIGHT]
    if copyright_tag:
        copyright_names = [tag for _, tag, *_ in copyright_tag[:4]]
        copyright_txt = f"from {format_nl_list(copyright_names)}"
    else:
        copyright_txt = ""

    tag_string = ", ".join(tp[1] for tp in other_tags)
    # Empty pieces are skipped so the sentence has no double spaces.
    custom_prompt = ' '.join(s for s in [
        "Write a descriptive caption for this image",
        artist_txt, species_txt, character_txt, copyright_txt,
        "in a formal tone. Use these tags to construct your caption:",
        tag_string,
    ] if s)
    return custom_prompt
|
|
|
|
|
def find_tag_file(image_path):
    """
    Find the corresponding .txt file for the given image path.

    The file with the image's name and a ``.txt`` extension is tried first.
    If it does not exist and the image stem ends in ``-<number>``, the stem
    without that suffix is tried as well.

    Args:
        image_path (Path): Path to the image file.

    Returns:
        Path | None: The existing tag file, or None if there is none.
    """
    tag_file = image_path.with_suffix(".txt")
    if tag_file.exists():
        return tag_file

    # Handle -(number) suffix
    match = re.match(r"(.+)-\d+$", image_path.stem)
    if match:
        # Append ".txt" to the trimmed stem directly. Using with_suffix()
        # here would replace everything after the last dot of a dotted stem
        # (e.g. "img.v2" would become "img.txt" instead of "img.v2.txt").
        tag_file = image_path.with_name(f"{match.group(1)}.txt")
        if tag_file.exists():
            return tag_file

    return None
|
|
|
|
|
# Run the batch captioning CLI only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|