John6666 commited on
Commit
f59d99b
·
verified ·
1 Parent(s): a4464c2

Upload florence2_sd3_tagger8.py

Browse files
Files changed (1) hide show
  1. florence2_sd3_tagger8.py +120 -0
florence2_sd3_tagger8.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import re
4
+ from PIL import Image
5
+
6
+ import logging
7
+ logger = logging.getLogger(__name__)
8
+
9
+ DEFAULT_FLORENCE2_SD3_CAP_REPO = 'John6666/gokaygokay-Florence-2-SD3-Captioner-8bit'
10
+
11
+ def fl_modify_caption(caption: str) -> str:
12
+ """
13
+ Removes specific prefixes from captions if present, otherwise returns the original caption.
14
+ Args:
15
+ caption (str): A string containing a caption.
16
+ Returns:
17
+ str: The caption with the prefix removed if it was present, or the original caption.
18
+ """
19
+ # Define the prefixes to remove
20
+ prefix_substrings = [
21
+ ('captured from ', ''),
22
+ ('captured at ', '')
23
+ ]
24
+
25
+ # Create a regex pattern to match any of the prefixes
26
+ pattern = '|'.join([re.escape(opening) for opening, _ in prefix_substrings])
27
+ replacers = {opening.lower(): replacer for opening, replacer in prefix_substrings}
28
+
29
+ # Function to replace matched prefix with its corresponding replacement
30
+ def replace_fn(match):
31
+ return replacers[match.group(0).lower()]
32
+
33
+ # Apply the regex to the caption
34
+ modified_caption = re.sub(pattern, replace_fn, caption, count=1, flags=re.IGNORECASE)
35
+
36
+ # If the caption was modified, return the modified version; otherwise, return the original
37
+ return modified_caption if modified_caption != caption else caption
38
+
39
+
40
+ def fl_run_example(image, fl_model, fl_processor):
41
+ image = Image.open(image)
42
+ task_prompt = "<DESCRIPTION>"
43
+ prompt = task_prompt + "Describe this image in great detail."
44
+
45
+ # Ensure the image is in RGB mode
46
+ if image.mode != "RGB":
47
+ image = image.convert("RGB")
48
+
49
+ inputs = fl_processor(text=prompt, images=image, return_tensors="pt").to("cuda")
50
+ generated_ids = fl_model.generate(
51
+ input_ids=inputs["input_ids"],
52
+ pixel_values=inputs["pixel_values"],
53
+ max_new_tokens=1024,
54
+ num_beams=3
55
+ )
56
+ generated_text = fl_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
57
+ parsed_answer = fl_processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.width, image.height))
58
+ return fl_modify_caption(parsed_answer["<DESCRIPTION>"])
59
+
60
+
61
+ def predict_tags_fl2_sd3(image, fl_model, fl_processor):
62
+ tag = fl_run_example(image, fl_model, fl_processor)
63
+ return tag
64
+
65
+
66
+ def main(args):
67
+ # model location is model_dir + repo_id
68
+ # repo id may be like "user/repo" or "user/repo/branch", so we need to remove slash
69
+ model_location = os.path.join(args.model_dir, args.repo_id.replace("/", "_"))
70
+
71
+ if not os.path.exists(model_location) or args.force_download:
72
+ os.makedirs(args.model_dir, exist_ok=True)
73
+ logger.info(f"downloading Florence-2-SD3-Captioner model from hf_hub. id: {args.repo_id}")
74
+ from huggingface_hub import snapshot_download
75
+ snapshot_download(repo_id=args.repo_id, local_dir=model_location, local_dir_use_symlinks=False)
76
+ else:
77
+ logger.info("using existing Florence-2-SD3-Captioner model")
78
+
79
+ from transformers import AutoProcessor, AutoModelForCausalLM
80
+ import torch
81
+ fl_model = AutoModelForCausalLM.from_pretrained(f"{model_location}", torch_dtype=torch.float32, low_cpu_mem_usage=True, trust_remote_code=True)
82
+ fl_processor = AutoProcessor.from_pretrained(f"{model_location}", trust_remote_code=True)
83
+
84
+ image_path = args.image_path
85
+
86
+ tag = predict_tags_fl2_sd3(image_path, fl_model, fl_processor)
87
+
88
+ print(tag)
89
+
90
+
91
+ def setup_parser() -> argparse.ArgumentParser:
92
+ parser = argparse.ArgumentParser()
93
+ parser.add_argument('image_path')
94
+ parser.add_argument(
95
+ "--repo_id",
96
+ type=str,
97
+ default=DEFAULT_FLORENCE2_SD3_CAP_REPO,
98
+ help="repo id for gokaygokay's Florence-2-SD3-Captioner on Hugging Face",
99
+ )
100
+ parser.add_argument(
101
+ "--model_dir",
102
+ type=str,
103
+ default="Florence-2-SD3-Captioner_model",
104
+ help="directory to store Florence-2-SD3-Captioner model",
105
+ )
106
+ parser.add_argument(
107
+ "--force_download",
108
+ action="store_true",
109
+ help="force downloading Florence-2-SD3-Captioner model",
110
+ )
111
+
112
+ return parser
113
+
114
+
115
+ if __name__ == "__main__":
116
+ parser = setup_parser()
117
+
118
+ args = parser.parse_args()
119
+
120
+ main(args)