Upload 8 files
- tagger/fl2sd3longcap.py +9 -3
- tagger/tagger.py +11 -4
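Both files get the same ZeroGPU-oriented rework: model loading is wrapped in try/except so a failed download leaves the Space importable instead of crashing at startup, the @spaces.GPU decorators gain an explicit duration=30, and the models are moved between CPU and GPU around each call instead of being pinned to CUDA at import time.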
tagger/fl2sd3longcap.py CHANGED

@@ -8,9 +8,13 @@ import subprocess
 subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
-fl_model = AutoModelForCausalLM.from_pretrained('gokaygokay/Florence-2-SD3-Captioner', trust_remote_code=True).to(device).eval()
-fl_processor = AutoProcessor.from_pretrained('gokaygokay/Florence-2-SD3-Captioner', trust_remote_code=True)
 
+try:
+    fl_model = AutoModelForCausalLM.from_pretrained('gokaygokay/Florence-2-SD3-Captioner', trust_remote_code=True).to("cpu").eval()
+    fl_processor = AutoProcessor.from_pretrained('gokaygokay/Florence-2-SD3-Captioner', trust_remote_code=True)
+except Exception as e:
+    print(e)
+    fl_model = fl_processor = None
 
 def fl_modify_caption(caption: str) -> str:
     """
@@ -41,7 +45,7 @@ def fl_modify_caption(caption: str) -> str:
     return modified_caption if modified_caption != caption else caption
 
 
-@spaces.GPU
+@spaces.GPU(duration=30)
 def fl_run_example(image):
     task_prompt = "<DESCRIPTION>"
     prompt = task_prompt + "Describe this image in great detail."
@@ -50,6 +54,7 @@ def fl_run_example(image):
     if image.mode != "RGB":
         image = image.convert("RGB")
 
+    fl_model.to(device)
     inputs = fl_processor(text=prompt, images=image, return_tensors="pt").to(device)
     generated_ids = fl_model.generate(
         input_ids=inputs["input_ids"],
@@ -57,6 +62,7 @@ def fl_run_example(image):
         max_new_tokens=1024,
         num_beams=3
     )
+    fl_model.to("cpu")
     generated_text = fl_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
     parsed_answer = fl_processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.width, image.height))
     return fl_modify_caption(parsed_answer["<DESCRIPTION>"])
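Pulled out of the diff, the pattern above is the usual ZeroGPU recipe: keep the weights on CPU at import time, request a GPU only for the decorated call, and park the weights back on CPU before returning. A minimal self-contained sketch under those assumptions — MODEL_ID, run(), and the prompt are placeholder names, not taken from the Space:

import spaces
import torch
from transformers import AutoModelForCausalLM, AutoProcessor

MODEL_ID = "some-org/some-captioner"  # hypothetical id; the Space uses gokaygokay/Florence-2-SD3-Captioner

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load on CPU at import time: under ZeroGPU there is no CUDA outside
# @spaces.GPU calls, and the try/except keeps the app importable even
# if the checkpoint fails to download.
try:
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True).to("cpu").eval()
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
except Exception as e:
    print(e)
    model = processor = None

@spaces.GPU(duration=30)  # borrow a GPU for at most ~30 s per call
def run(image):
    model.to(device)  # weights onto the freshly allocated GPU
    inputs = processor(text="Describe this image.", images=image, return_tensors="pt").to(device)
    generated_ids = model.generate(**inputs, max_new_tokens=256)
    model.to("cpu")  # hand the memory back before the allocation window closes
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

The .to("cpu") on the way out mirrors the commit: the process outlives the GPU allocation, so the weights are parked on CPU rather than left on a device that is about to be revoked.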
tagger/tagger.py CHANGED

@@ -12,10 +12,15 @@ from pathlib import Path
 WD_MODEL_NAMES = ["p1atdev/wd-swinv2-tagger-v3-hf"]
 WD_MODEL_NAME = WD_MODEL_NAMES[0]
 
-
-
-wd_processor = AutoImageProcessor.from_pretrained(WD_MODEL_NAME, trust_remote_code=True)
+device = "cuda" if torch.cuda.is_available() else "cpu"
+default_device = device
 
+try:
+    wd_model = AutoModelForImageClassification.from_pretrained(WD_MODEL_NAME, trust_remote_code=True).to(default_device).eval()
+    wd_processor = AutoImageProcessor.from_pretrained(WD_MODEL_NAME, trust_remote_code=True)
+except Exception as e:
+    print(e)
+    wd_model = wd_processor = None
 
 def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
     return (
@@ -506,7 +511,7 @@ def gen_prompt(rating: list[str], character: list[str], general: list[str]):
     return ", ".join(all_tags)
 
 
-@spaces.GPU()
+@spaces.GPU(duration=30)
 def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_threshold: float = 0.8):
     inputs = wd_processor.preprocess(image, return_tensors="pt")
 
@@ -514,9 +519,11 @@ def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_threshold: float = 0.8):
     logits = torch.sigmoid(outputs.logits[0]) # take the first logits
 
     # get probabilities
+    if device != default_device: wd_model.to(device=device)
     results = {
         wd_model.config.id2label[i]: float(logit.float()) for i, logit in enumerate(logits)
     }
+    if device != default_device: wd_model.to(device=default_device)
     # rating, character, general
     rating, character, general = postprocess_results(
         results, general_threshold, character_threshold
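tagger.py takes a slightly different route: the model stays on default_device (whatever was available at import) and only hops devices when the current target differs, toggling back afterwards. A standalone sketch of that toggle, assuming the same checkpoint — tag() and its threshold handling are illustrative, and the guards are shown around the forward pass rather than at the exact lines the commit touches:

import spaces
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

TAGGER_ID = "p1atdev/wd-swinv2-tagger-v3-hf"  # same checkpoint as the Space

device = "cuda" if torch.cuda.is_available() else "cpu"
default_device = device  # where the weights normally live

try:
    model = AutoModelForImageClassification.from_pretrained(TAGGER_ID, trust_remote_code=True).to(default_device).eval()
    processor = AutoImageProcessor.from_pretrained(TAGGER_ID, trust_remote_code=True)
except Exception as e:
    print(e)
    model = processor = None

@spaces.GPU(duration=30)
def tag(image: Image.Image, threshold: float = 0.3) -> list[str]:
    inputs = processor.preprocess(image, return_tensors="pt")
    # hop to the requested device only when it differs from the resident one
    if device != default_device: model.to(device)
    with torch.no_grad():
        outputs = model(**inputs.to(model.device))
    if device != default_device: model.to(default_device)
    probs = torch.sigmoid(outputs.logits[0])  # multi-label head, so sigmoid rather than softmax
    return [model.config.id2label[i] for i, p in enumerate(probs) if float(p) > threshold]

When device and default_device resolve to the same thing, both guards are no-ops, which is presumably why the commit leaves them in as cheap insurance rather than making the moves unconditional.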