John6666 committed on
Commit
5f0104b
β€’
1 Parent(s): 31b1b9f

Upload 8 files

Browse files
Files changed (2) hide show
  1. tagger/fl2sd3longcap.py +9 -3
  2. tagger/tagger.py +11 -4
tagger/fl2sd3longcap.py CHANGED
@@ -8,9 +8,13 @@ import subprocess
8
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
9
 
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
- fl_model = AutoModelForCausalLM.from_pretrained('gokaygokay/Florence-2-SD3-Captioner', trust_remote_code=True).to(device).eval()
12
- fl_processor = AutoProcessor.from_pretrained('gokaygokay/Florence-2-SD3-Captioner', trust_remote_code=True)
13
 
 
 
 
 
 
 
14
 
15
  def fl_modify_caption(caption: str) -> str:
16
  """
@@ -41,7 +45,7 @@ def fl_modify_caption(caption: str) -> str:
41
  return modified_caption if modified_caption != caption else caption
42
 
43
 
44
- @spaces.GPU
45
  def fl_run_example(image):
46
  task_prompt = "<DESCRIPTION>"
47
  prompt = task_prompt + "Describe this image in great detail."
@@ -50,6 +54,7 @@ def fl_run_example(image):
50
  if image.mode != "RGB":
51
  image = image.convert("RGB")
52
 
 
53
  inputs = fl_processor(text=prompt, images=image, return_tensors="pt").to(device)
54
  generated_ids = fl_model.generate(
55
  input_ids=inputs["input_ids"],
@@ -57,6 +62,7 @@ def fl_run_example(image):
57
  max_new_tokens=1024,
58
  num_beams=3
59
  )
 
60
  generated_text = fl_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
61
  parsed_answer = fl_processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.width, image.height))
62
  return fl_modify_caption(parsed_answer["<DESCRIPTION>"])
 
8
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
9
 
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
11
 
12
+ try:
13
+ fl_model = AutoModelForCausalLM.from_pretrained('gokaygokay/Florence-2-SD3-Captioner', trust_remote_code=True).to("cpu").eval()
14
+ fl_processor = AutoProcessor.from_pretrained('gokaygokay/Florence-2-SD3-Captioner', trust_remote_code=True)
15
+ except Exception as e:
16
+ print(e)
17
+ fl_model = fl_processor = None
18
 
19
  def fl_modify_caption(caption: str) -> str:
20
  """
 
45
  return modified_caption if modified_caption != caption else caption
46
 
47
 
48
+ @spaces.GPU(duration=30)
49
  def fl_run_example(image):
50
  task_prompt = "<DESCRIPTION>"
51
  prompt = task_prompt + "Describe this image in great detail."
 
54
  if image.mode != "RGB":
55
  image = image.convert("RGB")
56
 
57
+ fl_model.to(device)
58
  inputs = fl_processor(text=prompt, images=image, return_tensors="pt").to(device)
59
  generated_ids = fl_model.generate(
60
  input_ids=inputs["input_ids"],
 
62
  max_new_tokens=1024,
63
  num_beams=3
64
  )
65
+ fl_model.to("cpu")
66
  generated_text = fl_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
67
  parsed_answer = fl_processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.width, image.height))
68
  return fl_modify_caption(parsed_answer["<DESCRIPTION>"])
tagger/tagger.py CHANGED
@@ -12,10 +12,15 @@ from pathlib import Path
12
  WD_MODEL_NAMES = ["p1atdev/wd-swinv2-tagger-v3-hf"]
13
  WD_MODEL_NAME = WD_MODEL_NAMES[0]
14
 
15
- wd_model = AutoModelForImageClassification.from_pretrained(WD_MODEL_NAME, trust_remote_code=True)
16
- wd_model.to("cuda" if torch.cuda.is_available() else "cpu")
17
- wd_processor = AutoImageProcessor.from_pretrained(WD_MODEL_NAME, trust_remote_code=True)
18
 
 
 
 
 
 
 
19
 
20
  def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
21
  return (
@@ -506,7 +511,7 @@ def gen_prompt(rating: list[str], character: list[str], general: list[str]):
506
  return ", ".join(all_tags)
507
 
508
 
509
- @spaces.GPU()
510
  def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_threshold: float = 0.8):
511
  inputs = wd_processor.preprocess(image, return_tensors="pt")
512
 
@@ -514,9 +519,11 @@ def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_t
514
  logits = torch.sigmoid(outputs.logits[0]) # take the first logits
515
 
516
  # get probabilities
 
517
  results = {
518
  wd_model.config.id2label[i]: float(logit.float()) for i, logit in enumerate(logits)
519
  }
 
520
  # rating, character, general
521
  rating, character, general = postprocess_results(
522
  results, general_threshold, character_threshold
 
12
  WD_MODEL_NAMES = ["p1atdev/wd-swinv2-tagger-v3-hf"]
13
  WD_MODEL_NAME = WD_MODEL_NAMES[0]
14
 
15
+ device = "cuda" if torch.cuda.is_available() else "cpu"
16
+ default_device = device
 
17
 
18
+ try:
19
+ wd_model = AutoModelForImageClassification.from_pretrained(WD_MODEL_NAME, trust_remote_code=True).to(default_device).eval()
20
+ wd_processor = AutoImageProcessor.from_pretrained(WD_MODEL_NAME, trust_remote_code=True)
21
+ except Exception as e:
22
+ print(e)
23
+ wd_model = wd_processor = None
24
 
25
  def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
26
  return (
 
511
  return ", ".join(all_tags)
512
 
513
 
514
+ @spaces.GPU(duration=30)
515
  def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_threshold: float = 0.8):
516
  inputs = wd_processor.preprocess(image, return_tensors="pt")
517
 
 
519
  logits = torch.sigmoid(outputs.logits[0]) # take the first logits
520
 
521
  # get probabilities
522
+ if device != default_device: wd_model.to(device=device)
523
  results = {
524
  wd_model.config.id2label[i]: float(logit.float()) for i, logit in enumerate(logits)
525
  }
526
+ if device != default_device: wd_model.to(device=default_device)
527
  # rating, character, general
528
  rating, character, general = postprocess_results(
529
  results, general_threshold, character_threshold