VikramSingh178 committed
Commit: ebbf256
Parent: ca2a4e0

refactor: Update SDXL-LoRA inference pipeline to load multiple adapter weights

product_diffusion_api/routers/sdxl_text_to_image.py CHANGED
@@ -82,11 +82,13 @@ def pil_to_s3_json(image: Image.Image,file_name) -> str:
 
 
 @lru_cache(maxsize=1)
-def load_pipeline(model_name, adapter_name):
+def load_pipeline(model_name, adapter_name,adapter_name_2):
     pipe = DiffusionPipeline.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(
         "cuda"
     )
     pipe.load_lora_weights(adapter_name)
+    pipe.load_lora_weights(adapter_name_2)
+    pipe.set_adapters([adapter_name, adapter_name_2], adapter_weights=[0.7, 0.8])
     pipe.unload_lora_weights()
     pipe.unet.to(memory_format=torch.channels_last)
     pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
@@ -96,7 +98,7 @@ def load_pipeline(model_name, adapter_name):
     return pipe
 
 
-loaded_pipeline = load_pipeline(config.MODEL_NAME, config.ADAPTER_NAME)
+loaded_pipeline = load_pipeline(config.MODEL_NAME, config.ADAPTER_NAME,config.ADAPTER_NAME_2)
 
 
 # SDXLLoraInference class for running inference
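
For context, the multi-adapter pattern above follows the diffusers PEFT integration. A minimal standalone sketch of the same idea is shown below; it is not the repo's code. It assumes a recent diffusers release with the PEFT backend installed, uses explicit adapter labels ("caption" and "products10k" are illustrative names, not taken from the repo), and adds an optional fuse_lora() step so the weighted combination is baked into the UNet before the LoRA layers are unloaded and the UNet is compiled:

```python
import torch
from diffusers import DiffusionPipeline

# Base SDXL checkpoint and the two LoRA repos referenced in scripts/config.py
MODEL_NAME = "stabilityai/stable-diffusion-xl-base-1.0"
ADAPTER_NAME = "VikramSingh178/sdxl-lora-finetune-product-caption"
ADAPTER_NAME_2 = "VikramSingh178/Products10k-SDXL-Lora"

pipe = DiffusionPipeline.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16).to("cuda")

# Register each LoRA under an explicit adapter label so the two can be weighted together.
# ("caption" and "products10k" are illustrative labels, not names used in the commit.)
pipe.load_lora_weights(ADAPTER_NAME, adapter_name="caption")
pipe.load_lora_weights(ADAPTER_NAME_2, adapter_name="products10k")
pipe.set_adapters(["caption", "products10k"], adapter_weights=[0.7, 0.8])

# Optionally fuse the weighted combination into the UNet so it survives
# unload_lora_weights() and is picked up by torch.compile.
pipe.fuse_lora()
pipe.unload_lora_weights()

pipe.unet.to(memory_format=torch.channels_last)
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

image = pipe("studio product photo of a leather backpack", num_inference_steps=30).images[0]
```
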
scripts/config.py CHANGED
@@ -1,9 +1,10 @@
 MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
 ADAPTER_NAME = "VikramSingh178/sdxl-lora-finetune-product-caption"
+ADAPTER_NAME_2 = "VikramSingh178/Products10k-SDXL-Lora"
 VAE_NAME= "madebyollin/sdxl-vae-fp16-fix"
 DATASET_NAME= "hahminlew/kream-product-blip-captions"
 PROJECT_NAME = "Product Photography"
-PRODUCTS_10k_DATASET = "amaye15/Products-10k"
+PRODUCTS_10k_DATASET = "VikramSingh178/Products-10k-BLIP-captions"
 CAPTIONING_MODEL_NAME = "Salesforce/blip-image-captioning-base"
 
 
@@ -14,7 +15,7 @@ class Config:
         self.pretrained_vae_model_name_or_path = VAE_NAME
         self.revision = None
         self.variant = None
-        self.dataset_name = DATASET_NAME
+        self.dataset_name = PRODUCTS_10k_DATASET
         self.dataset_config_name = None
         self.train_data_dir = None
         self.image_column = 'image'
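
Because this change repoints both an adapter repo and the training dataset, a quick check that the new Hub IDs resolve can catch typos before a long training or inference run. A minimal sketch using the huggingface_hub client (only the two IDs are taken from the config; everything else is illustrative):

```python
from huggingface_hub import model_info, dataset_info

ADAPTER_NAME_2 = "VikramSingh178/Products10k-SDXL-Lora"
PRODUCTS_10k_DATASET = "VikramSingh178/Products-10k-BLIP-captions"

# Each call raises RepositoryNotFoundError if the repo ID is misspelled or private.
print(model_info(ADAPTER_NAME_2).id)
print(dataset_info(PRODUCTS_10k_DATASET).id)
```
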
scripts/sdxl_lora_tuner.py CHANGED
@@ -51,16 +51,8 @@ from diffusers.utils.torch_utils import is_compiled_module
 logger = get_logger(__name__)
 
 
+
 def save_model_card(
-    repo_id: str,
-    images: list = None,
-    base_model: str = None,
-    dataset_name: str = None,
-    train_text_encoder: bool = False,
-    repo_folder: str = None,
-    vae_path: str = None,
-):
-def save_model_card(
     repo_id: str,
     images: list = None,
     base_model: str = None,
@@ -533,7 +525,7 @@ def main():
     else:
         data_files = {}
         if config.train_data_dir is not None:
-            data_files["train"] = os.path.join(config.train_data_dir, "**")
+            data_files["test"] = os.path.join(config.train_data_dir, "**")
         dataset = load_dataset(
             "imagefolder",
             data_files=data_files,
@@ -544,7 +536,7 @@ def main():
 
     # Preprocessing the datasets.
     # We need to tokenize inputs and targets.
-    column_names = dataset["train"].column_names
+    column_names = dataset["test"].column_names
 
     # 6. Get the column names for input/target.
     DATASET_NAME_MAPPING = {
@@ -651,13 +643,13 @@ def main():
 
     with accelerator.main_process_first():
         if config.max_train_samples is not None:
-            dataset["train"] = (
-                dataset["train"]
+            dataset["test"] = (
+                dataset["test"]
                 .shuffle(seed=config.seed)
                 .select(range(config.max_train_samples))
             )
         # Set the training transforms
-        train_dataset = dataset["train"].with_transform(
+        train_dataset = dataset["test"].with_transform(
             preprocess_train, output_all_columns=True
         )
 
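
The tuner now reads the "test" split instead of "train", which presumably matches how the Products-10k-BLIP-captions dataset is published on the Hub. A short sketch for confirming which splits and columns the dataset actually exposes, using the standard datasets API (the split and column names printed are whatever the Hub repo ships; nothing here is asserted beyond the repo ID used in the config):

```python
from datasets import load_dataset

# Download the dataset metadata and inspect its splits before pointing the trainer at one.
dataset = load_dataset("VikramSingh178/Products-10k-BLIP-captions")

print(dataset)                       # DatasetDict showing the available splits and row counts
print(dataset["test"].column_names)  # expected to include the image and caption columns
```
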