k4d3 commited on
Commit
4a1fbfb
1 Parent(s): 0f0ad40

Signed-off-by: Balazs Horvath <acsipont@gmail.com>

Files changed (1) hide show
  1. joy +6 -0
joy CHANGED
@@ -32,6 +32,7 @@ from transformers import (
32
  PreTrainedTokenizerFast,
33
  )
34
  from torch import nn
 
35
 
36
  CLIP_PATH = "google/siglip-so400m-patch14-384"
37
  MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B"
@@ -80,6 +81,8 @@ CAPTION_TYPE_MAP = {
80
 
81
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
82
 
 
 
83
  class ImageAdapter(nn.Module):
84
  """
85
  Custom image adapter module for processing CLIP vision outputs.
@@ -466,6 +469,9 @@ def main():
466
  if args.random_tags is not None and args.feed_from_tags is None:
467
  parser.error("--random-tags can only be used when --feed-from-tags is enabled")
468
 
 
 
 
469
  # Initialize and load models
470
  joy_caption_model = JoyCaptionModel()
471
  joy_caption_model.load_models()
 
32
  PreTrainedTokenizerFast,
33
  )
34
  from torch import nn
35
+ from e6db_reader import TagNormalizer
36
 
37
  CLIP_PATH = "google/siglip-so400m-patch14-384"
38
  MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B"
 
81
 
82
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
83
 
84
+ E6DB_DATA = Path(__file__).resolve().parent / "data"
85
+
86
  class ImageAdapter(nn.Module):
87
  """
88
  Custom image adapter module for processing CLIP vision outputs.
 
469
  if args.random_tags is not None and args.feed_from_tags is None:
470
  parser.error("--random-tags can only be used when --feed-from-tags is enabled")
471
 
472
+ print('Loading e621 tag data')
473
+ tag_normalizer = TagNormalizer(E6DB_DATA)
474
+
475
  # Initialize and load models
476
  joy_caption_model = JoyCaptionModel()
477
  joy_caption_model.load_models()