VikramSingh178
commited on
Commit
•
ebbf256
1
Parent(s):
ca2a4e0
refactor: Update SDXL-LoRA inference pipeline to load multiple adapter weights
Browse files
product_diffusion_api/routers/sdxl_text_to_image.py
CHANGED
@@ -82,11 +82,13 @@ def pil_to_s3_json(image: Image.Image,file_name) -> str:
|
|
82 |
|
83 |
|
84 |
@lru_cache(maxsize=1)
|
85 |
-
def load_pipeline(model_name, adapter_name):
|
86 |
pipe = DiffusionPipeline.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(
|
87 |
"cuda"
|
88 |
)
|
89 |
pipe.load_lora_weights(adapter_name)
|
|
|
|
|
90 |
pipe.unload_lora_weights()
|
91 |
pipe.unet.to(memory_format=torch.channels_last)
|
92 |
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
@@ -96,7 +98,7 @@ def load_pipeline(model_name, adapter_name):
|
|
96 |
return pipe
|
97 |
|
98 |
|
99 |
-
loaded_pipeline = load_pipeline(config.MODEL_NAME, config.ADAPTER_NAME)
|
100 |
|
101 |
|
102 |
# SDXLLoraInference class for running inference
|
|
|
82 |
|
83 |
|
84 |
@lru_cache(maxsize=1)
|
85 |
+
def load_pipeline(model_name, adapter_name,adapter_name_2):
|
86 |
pipe = DiffusionPipeline.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(
|
87 |
"cuda"
|
88 |
)
|
89 |
pipe.load_lora_weights(adapter_name)
|
90 |
+
pipe.load_lora_weights(adapter_name_2)
|
91 |
+
pipe.set_adapters([adapter_name, adapter_name_2], adapter_weights=[0.7, 0.8])
|
92 |
pipe.unload_lora_weights()
|
93 |
pipe.unet.to(memory_format=torch.channels_last)
|
94 |
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
|
|
98 |
return pipe
|
99 |
|
100 |
|
101 |
+
loaded_pipeline = load_pipeline(config.MODEL_NAME, config.ADAPTER_NAME,config.ADAPTER_NAME_2)
|
102 |
|
103 |
|
104 |
# SDXLLoraInference class for running inference
|
scripts/config.py
CHANGED
@@ -1,9 +1,10 @@
|
|
1 |
MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
|
2 |
ADAPTER_NAME = "VikramSingh178/sdxl-lora-finetune-product-caption"
|
|
|
3 |
VAE_NAME= "madebyollin/sdxl-vae-fp16-fix"
|
4 |
DATASET_NAME= "hahminlew/kream-product-blip-captions"
|
5 |
PROJECT_NAME = "Product Photography"
|
6 |
-
PRODUCTS_10k_DATASET = "
|
7 |
CAPTIONING_MODEL_NAME = "Salesforce/blip-image-captioning-base"
|
8 |
|
9 |
|
@@ -14,7 +15,7 @@ class Config:
|
|
14 |
self.pretrained_vae_model_name_or_path = VAE_NAME
|
15 |
self.revision = None
|
16 |
self.variant = None
|
17 |
-
self.dataset_name =
|
18 |
self.dataset_config_name = None
|
19 |
self.train_data_dir = None
|
20 |
self.image_column = 'image'
|
|
|
1 |
MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
|
2 |
ADAPTER_NAME = "VikramSingh178/sdxl-lora-finetune-product-caption"
|
3 |
+
ADAPTER_NAME_2 = "VikramSingh178/Products10k-SDXL-Lora"
|
4 |
VAE_NAME= "madebyollin/sdxl-vae-fp16-fix"
|
5 |
DATASET_NAME= "hahminlew/kream-product-blip-captions"
|
6 |
PROJECT_NAME = "Product Photography"
|
7 |
+
PRODUCTS_10k_DATASET = "VikramSingh178/Products-10k-BLIP-captions"
|
8 |
CAPTIONING_MODEL_NAME = "Salesforce/blip-image-captioning-base"
|
9 |
|
10 |
|
|
|
15 |
self.pretrained_vae_model_name_or_path = VAE_NAME
|
16 |
self.revision = None
|
17 |
self.variant = None
|
18 |
+
self.dataset_name = PRODUCTS_10k_DATASET
|
19 |
self.dataset_config_name = None
|
20 |
self.train_data_dir = None
|
21 |
self.image_column = 'image'
|
scripts/sdxl_lora_tuner.py
CHANGED
@@ -51,16 +51,8 @@ from diffusers.utils.torch_utils import is_compiled_module
|
|
51 |
logger = get_logger(__name__)
|
52 |
|
53 |
|
|
|
54 |
def save_model_card(
|
55 |
-
repo_id: str,
|
56 |
-
images: list = None,
|
57 |
-
base_model: str = None,
|
58 |
-
dataset_name: str = None,
|
59 |
-
train_text_encoder: bool = False,
|
60 |
-
repo_folder: str = None,
|
61 |
-
vae_path: str = None,
|
62 |
-
):
|
63 |
-
def save_model_card(
|
64 |
repo_id: str,
|
65 |
images: list = None,
|
66 |
base_model: str = None,
|
@@ -533,7 +525,7 @@ def main():
|
|
533 |
else:
|
534 |
data_files = {}
|
535 |
if config.train_data_dir is not None:
|
536 |
-
data_files["
|
537 |
dataset = load_dataset(
|
538 |
"imagefolder",
|
539 |
data_files=data_files,
|
@@ -544,7 +536,7 @@ def main():
|
|
544 |
|
545 |
# Preprocessing the datasets.
|
546 |
# We need to tokenize inputs and targets.
|
547 |
-
column_names = dataset["
|
548 |
|
549 |
# 6. Get the column names for input/target.
|
550 |
DATASET_NAME_MAPPING = {
|
@@ -651,13 +643,13 @@ def main():
|
|
651 |
|
652 |
with accelerator.main_process_first():
|
653 |
if config.max_train_samples is not None:
|
654 |
-
dataset["
|
655 |
-
dataset["
|
656 |
.shuffle(seed=config.seed)
|
657 |
.select(range(config.max_train_samples))
|
658 |
)
|
659 |
# Set the training transforms
|
660 |
-
train_dataset = dataset["
|
661 |
preprocess_train, output_all_columns=True
|
662 |
)
|
663 |
|
|
|
51 |
logger = get_logger(__name__)
|
52 |
|
53 |
|
54 |
+
|
55 |
def save_model_card(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
repo_id: str,
|
57 |
images: list = None,
|
58 |
base_model: str = None,
|
|
|
525 |
else:
|
526 |
data_files = {}
|
527 |
if config.train_data_dir is not None:
|
528 |
+
data_files["test"] = os.path.join(config.train_data_dir, "**")
|
529 |
dataset = load_dataset(
|
530 |
"imagefolder",
|
531 |
data_files=data_files,
|
|
|
536 |
|
537 |
# Preprocessing the datasets.
|
538 |
# We need to tokenize inputs and targets.
|
539 |
+
column_names = dataset["test"].column_names
|
540 |
|
541 |
# 6. Get the column names for input/target.
|
542 |
DATASET_NAME_MAPPING = {
|
|
|
643 |
|
644 |
with accelerator.main_process_first():
|
645 |
if config.max_train_samples is not None:
|
646 |
+
dataset["test"] = (
|
647 |
+
dataset["test"]
|
648 |
.shuffle(seed=config.seed)
|
649 |
.select(range(config.max_train_samples))
|
650 |
)
|
651 |
# Set the training transforms
|
652 |
+
train_dataset = dataset["test"].with_transform(
|
653 |
preprocess_train, output_all_columns=True
|
654 |
)
|
655 |
|