import csv
import json
import os
import random
import time
from collections import defaultdict

import gradio as gr
import safetensors.torch
import timm
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from torchvision.transforms import InterpolationMode
from torchvision.transforms import transforms
from transformers import AutoProcessor

from florence2_implementation.modeling_florence2 import Florence2ForConditionalGeneration

torch.set_grad_enabled(False)

# HF now (Feb 20, 2025) imposes a 1 GB storage limit, so the JTP checkpoint has to be pulled from elsewhere.
os.system(
    "wget -nv https://huggingface.co/spaces/RedRocket/JointTaggerProject-Inference-Beta/resolve/main/JTP_PILOT2-2-e3-vit_so400m_patch14_siglip_384.safetensors"
)

category_id_to_str = {
    "0": "general",
    # "3" is copyright (unused here)
    "4": "character",
    "5": "species",
    "7": "meta",
    "8": "lore",
    "1": "artist",
}


class Pruner:
    """Filters raw tagger output down to allowed tags and formats the species tags."""

    def __init__(self, path_to_tag_list_csv):
        species_tags = set()
        allowed_tags = set()
        with open(path_to_tag_list_csv, "r") as f:
            reader = csv.reader(f)
            header = next(reader)
            name_index = header.index("name")
            category_index = header.index("category")
            post_count_index = header.index("post_count")
            for row in reader:
                if int(row[post_count_index]) > 20:
                    category = row[category_index]
                    name = row[name_index]
                    if category == "5":
                        species_tags.add(name)
                        allowed_tags.add(name)
                    elif category == "0":
                        allowed_tags.add(name)
                    elif category == "7":
                        allowed_tags.add(name)
        self.species_tags = species_tags
        self.allowed_tags = allowed_tags

    def _prune_not_allowed_tags(self, raw_tags):
        this_allowed_tags = set()
        for tag in raw_tags:
            if tag in self.allowed_tags:
                this_allowed_tags.add(tag)
        return this_allowed_tags

    def _find_and_format_species_tags(self, tag_set):
        this_specie_tags = []
        for tag in tag_set:
            if tag in self.species_tags:
                this_specie_tags.append(tag)
        formatted_tags = f"species: {' '.join(this_specie_tags)}\n"
        return formatted_tags, this_specie_tags

    def prompt_construction_pipeline_florence2(self, tags, length):
        if isinstance(tags, str):
            tags = tags.split(" ")
        random.shuffle(tags)
        tags = self._prune_not_allowed_tags(tags)
        formatted_species_tags, this_specie_tags = self._find_and_format_species_tags(tags)
        non_species_tags = [t for t in tags if t not in this_specie_tags]
        prompt = f"{' '.join(non_species_tags)}\n{formatted_species_tags}\nlength: {length}\n\nSTYLE1 FURRY CAPTION:"
        return prompt


class Fit(torch.nn.Module):
    """Resizes an image to fit within `bounds` while preserving aspect ratio, optionally padding to the bounds."""

    def __init__(
        self,
        bounds: tuple[int, int] | int,
        interpolation=InterpolationMode.LANCZOS,
        grow: bool = True,
        pad: float | None = None,
    ):
        super().__init__()
        self.bounds = (bounds, bounds) if isinstance(bounds, int) else bounds
        self.interpolation = interpolation
        self.grow = grow
        self.pad = pad

    def forward(self, img: Image.Image) -> Image.Image:
        wimg, himg = img.size
        hbound, wbound = self.bounds
        hscale = hbound / himg
        wscale = wbound / wimg
        if not self.grow:
            hscale = min(hscale, 1.0)
            wscale = min(wscale, 1.0)
        scale = min(hscale, wscale)
        if scale == 1.0:
            return img
        hnew = min(round(himg * scale), hbound)
        wnew = min(round(wimg * scale), wbound)
        img = TF.resize(img, (hnew, wnew), self.interpolation)
        if self.pad is None:
            return img
        hpad = hbound - hnew
        wpad = wbound - wnew
        tpad = hpad // 2
        bpad = hpad - tpad
        lpad = wpad // 2
        rpad = wpad - lpad
        return TF.pad(img, (lpad, tpad, rpad, bpad), self.pad)

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            + f"bounds={self.bounds}, "
            + f"interpolation={self.interpolation.value}, "
            + f"grow={self.grow}, "
            + f"pad={self.pad})"
        )
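
# Sketch (comments only): how Fit composes with the rest of the tagger preprocessing
# defined below in `tagger_transform`. The file name is a placeholder, not part of this repo.
# For a hypothetical 1000x500 RGBA input, Fit((384, 384)) resizes it to 384x192
# (aspect ratio preserved) and CompositeAlpha(0.5) flattens the alpha channel onto mid-grey:
#
#   img = Image.open("example.png").convert("RGBA")
#   x = CompositeAlpha(0.5)(transforms.ToTensor()(Fit((384, 384))(img)))
#   x.shape  # torch.Size([3, 192, 384]), before Normalize and CenterCrop
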

class CompositeAlpha(torch.nn.Module):
    """Composites an RGBA tensor onto a solid background colour, returning 3 channels."""

    def __init__(
        self,
        background: tuple[float, float, float] | float,
    ):
        super().__init__()
        self.background = (background, background, background) if isinstance(background, float) else background
        self.background = torch.tensor(self.background).unsqueeze(1).unsqueeze(2)

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        if img.shape[-3] == 3:
            return img
        alpha = img[..., 3, None, :, :]
        img[..., :3, :, :] *= alpha
        background = self.background.expand(-1, img.shape[-2], img.shape[-1])
        if background.ndim == 1:
            background = background[:, None, None]
        elif background.ndim == 2:
            background = background[None, :, :]
        img[..., :3, :, :] += (1.0 - alpha) * background
        return img[..., :3, :, :]

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            + f"background={self.background})"
        )


class GatedHead(torch.nn.Module):
    """Sigmoid-gated classification head used by the JTP PILOT2 tagger."""

    def __init__(
        self,
        num_features: int,
        num_classes: int,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.linear = torch.nn.Linear(num_features, num_classes * 2)
        self.act = torch.nn.Sigmoid()
        self.gate = torch.nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear(x)
        x = self.act(x[:, :self.num_classes]) * self.gate(x[:, self.num_classes:])
        return x


model_id = "lodestone-horizon/furrence2-large"
model = Florence2ForConditionalGeneration.from_pretrained(model_id).eval()
processor = AutoProcessor.from_pretrained("./florence2_implementation/", trust_remote_code=True)

# Tag implication tree: maps each consequent tag to the antecedent tags that imply it.
tree = defaultdict(list)
with open("tag_implications-2024-05-05.csv", "rt") as csvfile:
    reader = csv.DictReader(csvfile)
    for row in reader:
        if row["status"] == "active":
            tree[row["consequent_name"]].append(row["antecedent_name"])

title = """

Furrence2 Captioner Demo

""" description=( """
The captioner is prompted with tags from the JTP PILOT2 tagger; you may use hand-curated tags to get better results.
This demo runs on CPU. For faster inference without waiting in the queue, you can duplicate this Space and upgrade to a GPU in its settings.""")

tagger_transform = transforms.Compose([
    Fit((384, 384)),
    transforms.ToTensor(),
    CompositeAlpha(0.5),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    transforms.CenterCrop((384, 384)),
])

THRESHOLD = 0.2

tagger_model = timm.create_model(
    "vit_so400m_patch14_siglip_384.webli",
    pretrained=False,
    num_classes=9083,
)  # type: VisionTransformer
tagger_model.head = GatedHead(min(tagger_model.head.weight.shape), 9083)
safetensors.torch.load_model(tagger_model, "JTP_PILOT2-2-e3-vit_so400m_patch14_siglip_384.safetensors")
tagger_model.eval()

with open("JTP_PILOT2_tags.json", "r") as file:
    tags = json.load(file)  # type: dict
allowed_tags = list(tags.keys())

pruner = Pruner("tags-2024-05-05.csv")


def generate_prompt(image, expected_caption_length):
    # Tag the image with JTP PILOT2, keep everything above THRESHOLD (most confident first),
    # then build the Florence-2 task prompt from the surviving tags.
    tagger_input = tagger_transform(image.convert("RGBA")).unsqueeze(0)
    probabilities = tagger_model(tagger_input)
    for prob in probabilities:
        indices = torch.where(prob > THRESHOLD)[0]
        sorted_indices = torch.argsort(prob[indices], descending=True)
        final_tags = []
        for i in sorted_indices:
            final_tags.append(allowed_tags[indices[i]])
    final_tags = " ".join(final_tags)
    task_prompt = pruner.prompt_construction_pipeline_florence2(final_tags, expected_caption_length)
    return task_prompt


def inference_caption(image, expected_caption_length, seq_len=512):
    start_time = time.time()
    prompt_input = generate_prompt(image, expected_caption_length)
    end_time = time.time()
    execution_time = end_time - start_time
    print(f"Finished tagging in {execution_time:.3f} seconds")
    try:
        pixel_values = processor.image_processor(image, return_tensors="pt")["pixel_values"]
        encoder_inputs = processor.tokenizer(
            text=prompt_input,
            return_tensors="pt",
            # padding="max_length",
            # truncation=True,
            # max_length=256,
            # don't add these; they cause problems during inference
        )
        start_time = time.time()
        generated_ids = model.generate(
            input_ids=encoder_inputs["input_ids"],
            attention_mask=encoder_inputs["attention_mask"],
            pixel_values=pixel_values,
            max_new_tokens=seq_len,
            early_stopping=False,
            do_sample=False,
            num_beams=3,
        )
        end_time = time.time()
        execution_time = end_time - start_time
        print(f"Finished captioning in {execution_time:.3f} seconds")
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return generated_text
    except Exception as e:
        print("error message:", e)
        return "An error occurred."


def main():
    with gr.Blocks() as iface:
        gr.Markdown(title)
        gr.Markdown(description)
        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(type="pil")
                seq_len = gr.Number(
                    value=512,
                    label="Output Cutoff Length",
                    precision=0,
                    interactive=True,
                )
                expected_length = gr.Number(
                    minimum=50,
                    maximum=200,
                    value=100,
                    label="Expected Caption Length",
                    precision=0,
                    interactive=True,
                )
            with gr.Column(scale=1):
                with gr.Column():
                    caption_button = gr.Button(
                        value="Caption it!",
                        interactive=True,
                        variant="primary",
                    )
                    caption_output = gr.Textbox(lines=1, label="Caption Output")
        caption_button.click(
            inference_caption,
            [
                image_input,
                expected_length,
                seq_len,
            ],
            [caption_output],
        )
    iface.launch(share=False)


if __name__ == "__main__":
    main()
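
# Sketch (not executed): calling the captioner programmatically instead of through the
# Gradio UI. "example.jpg" is a placeholder path, not a file shipped with this demo.
#
#   from PIL import Image
#   img = Image.open("example.jpg")
#   print(inference_caption(img, expected_caption_length=100, seq_len=512))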