VikramSingh178 committed on
Commit
3e01790
1 Parent(s): 5e29265

Update SDXL-LoRA inference pipeline and model weights

Browse files

Former-commit-id: 550c615e6a453f0586ab834a0366c230320361d5

product_diffusion_api/routers/__pycache__/sdxl_text_to_image.cpython-310.pyc CHANGED
Binary files a/product_diffusion_api/routers/__pycache__/sdxl_text_to_image.cpython-310.pyc and b/product_diffusion_api/routers/__pycache__/sdxl_text_to_image.cpython-310.pyc differ
 
product_diffusion_api/routers/sdxl_text_to_image.py CHANGED
@@ -1,3 +1,7 @@
 
 
 
 
1
  from fastapi import APIRouter, HTTPException
2
  from pydantic import BaseModel
3
  import base64
@@ -6,9 +10,17 @@ from typing import List
6
  import uuid
7
  from diffusers import DiffusionPipeline
8
  import torch
 
 
 
 
 
 
 
9
 
10
  router = APIRouter()
11
 
 
12
  # Utility function to convert PIL image to base64 encoded JSON
13
  def pil_to_b64_json(image):
14
  # Generate a UUID for the image
@@ -19,6 +31,27 @@ def pil_to_b64_json(image):
19
  return {"image_id": image_id, "b64_image": b64_image}
20
 
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  # SDXLLoraInference class for running inference
23
  class SDXLLoraInference:
24
  """
@@ -51,12 +84,7 @@ class SDXLLoraInference:
51
  num_inference_steps: int,
52
  guidance_scale: float,
53
  ) -> None:
54
- self.pipe = DiffusionPipeline.from_pretrained(
55
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
56
- )
57
- self.model_path = "VikramSingh178/sdxl-lora-finetune-product-caption"
58
- self.pipe.load_lora_weights(self.model_path)
59
- self.pipe.to('cuda')
60
  self.prompt = prompt
61
  self.negative_prompt = negative_prompt
62
  self.num_images = num_images
@@ -79,6 +107,7 @@ class SDXLLoraInference:
79
  ).images[0]
80
  return pil_to_b64_json(image)
81
 
 
82
  # Input format for single request
83
  class InputFormat(BaseModel):
84
  prompt: str
@@ -87,10 +116,12 @@ class InputFormat(BaseModel):
87
  negative_prompt: str
88
  num_images: int
89
 
 
90
  # Input format for batch requests
91
  class BatchInputFormat(BaseModel):
92
  batch_input: List[InputFormat]
93
 
 
94
  # Endpoint for single request
95
  @router.post("/sdxl_v0_lora_inference")
96
  async def sdxl_v0_lora_inference(data: InputFormat):
@@ -104,6 +135,7 @@ async def sdxl_v0_lora_inference(data: InputFormat):
104
  output_json = inference.run_inference()
105
  return output_json
106
 
 
107
  # Endpoint for batch requests
108
  @router.post("/sdxl_v0_lora_inference/batch")
109
  async def sdxl_v0_lora_inference_batch(data: BatchInputFormat):
@@ -122,7 +154,10 @@ async def sdxl_v0_lora_inference_batch(data: BatchInputFormat):
122
  MAX_QUEUE_SIZE = 64
123
 
124
  if len(data.batch_input) > MAX_QUEUE_SIZE:
125
- raise HTTPException(status_code=400, detail=f"Number of requests exceeds maximum queue size ({MAX_QUEUE_SIZE})")
 
 
 
126
 
127
  processed_requests = []
128
  for item in data.batch_input:
 
1
+ import sys
2
+
3
+ sys.path.append("../scripts") # Path of the scripts directory
4
+ import config
5
  from fastapi import APIRouter, HTTPException
6
  from pydantic import BaseModel
7
  import base64
 
10
  import uuid
11
  from diffusers import DiffusionPipeline
12
  import torch
13
+ import torch_tensorrt
14
+ from functools import lru_cache
15
+
16
+ torch._inductor.config.conv_1x1_as_mm = True
17
+ torch._inductor.config.coordinate_descent_tuning = True
18
+ torch._inductor.config.epilogue_fusion = False
19
+ torch._inductor.config.coordinate_descent_check_all_directions = True
20
 
21
  router = APIRouter()
22
 
23
+
24
  # Utility function to convert PIL image to base64 encoded JSON
25
  def pil_to_b64_json(image):
26
  # Generate a UUID for the image
 
31
  return {"image_id": image_id, "b64_image": b64_image}
32
 
33
 
34
+ @lru_cache(maxsize=1)
35
+ def load_pipeline(model_name, adapter_name):
36
+ pipe = DiffusionPipeline.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(
37
+ "cuda"
38
+ )
39
+ pipe.load_lora_weights(adapter_name)
40
+ pipe.unet.to(memory_format=torch.channels_last)
41
+ pipe.vae.to(memory_format=torch.channels_last)
42
+ # pipe.unet = torch.compile(
43
+ # pipe.unet,
44
+ # mode = 'max-autotune'
45
+ # )
46
+
47
+ pipe.fuse_qkv_projections()
48
+
49
+ return pipe
50
+
51
+
52
+ loaded_pipeline = load_pipeline(config.MODEL_NAME, config.ADAPTER_NAME)
53
+
54
+
55
  # SDXLLoraInference class for running inference
56
  class SDXLLoraInference:
57
  """
 
84
  num_inference_steps: int,
85
  guidance_scale: float,
86
  ) -> None:
87
+ self.pipe = loaded_pipeline
 
 
 
 
 
88
  self.prompt = prompt
89
  self.negative_prompt = negative_prompt
90
  self.num_images = num_images
 
107
  ).images[0]
108
  return pil_to_b64_json(image)
109
 
110
+
111
  # Input format for single request
112
  class InputFormat(BaseModel):
113
  prompt: str
 
116
  negative_prompt: str
117
  num_images: int
118
 
119
+
120
  # Input format for batch requests
121
  class BatchInputFormat(BaseModel):
122
  batch_input: List[InputFormat]
123
 
124
+
125
  # Endpoint for single request
126
  @router.post("/sdxl_v0_lora_inference")
127
  async def sdxl_v0_lora_inference(data: InputFormat):
 
135
  output_json = inference.run_inference()
136
  return output_json
137
 
138
+
139
  # Endpoint for batch requests
140
  @router.post("/sdxl_v0_lora_inference/batch")
141
  async def sdxl_v0_lora_inference_batch(data: BatchInputFormat):
 
154
  MAX_QUEUE_SIZE = 64
155
 
156
  if len(data.batch_input) > MAX_QUEUE_SIZE:
157
+ raise HTTPException(
158
+ status_code=400,
159
+ detail=f"Number of requests exceeds maximum queue size ({MAX_QUEUE_SIZE})",
160
+ )
161
 
162
  processed_requests = []
163
  for item in data.batch_input:
scripts/__init__.py ADDED
File without changes
scripts/__pycache__/config.cpython-310.pyc CHANGED
Binary files a/scripts/__pycache__/config.cpython-310.pyc and b/scripts/__pycache__/config.cpython-310.pyc differ
 
scripts/config.py CHANGED
@@ -1,4 +1,5 @@
1
  MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
 
2
  VAE_NAME= "madebyollin/sdxl-vae-fp16-fix"
3
  DATASET_NAME= "hahminlew/kream-product-blip-captions"
4
  PROJECT_NAME = "Product Photography"
 
1
  MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
2
+ ADAPTER_NAME = "VikramSingh178/sdxl-lora-finetune-product-caption"
3
  VAE_NAME= "madebyollin/sdxl-vae-fp16-fix"
4
  DATASET_NAME= "hahminlew/kream-product-blip-captions"
5
  PROJECT_NAME = "Product Photography"
scripts/load_pipeline.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from config import MODEL_NAME,ADAPTER_NAME
2
+ import torch
3
+ from diffusers import DiffusionPipeline
4
+ from wandb.integration.diffusers import autolog
5
+ from config import PROJECT_NAME
6
+ autolog(init=dict(project=PROJECT_NAME))
7
+
8
+
9
+ def load_pipeline(model_name, adapter_name):
10
+ pipe = DiffusionPipeline.from_pretrained(model_name, torch_dtype=torch.float16).to(
11
+ "cuda"
12
+ )
13
+ pipe.load_lora_weights(adapter_name)
14
+ pipe.unet.to(memory_format=torch.channels_last)
15
+ pipe.vae.to(memory_format=torch.channels_last)
16
+ pipe.unet = torch.compile(pipe.unet, mode="max-autotune", fullgraph=True)
17
+ pipe.vae.decode = torch.compile(
18
+ pipe.vae.decode, mode="max-autotune", fullgraph=True
19
+ )
20
+ pipe.fuse_qkv_projections()
21
+
22
+ return pipe
23
+
24
+ loaded_pipeline = load_pipeline(MODEL_NAME, ADAPTER_NAME)
25
+ images = loaded_pipeline('toaster', num_inference_steps=30).images[0]
scripts/wandb/debug.log CHANGED
@@ -1 +1 @@
1
- run-20240430_104236-lcgqwfyr/logs/debug.log
 
1
+ run-20240507_154024-2j1bt71e/logs/debug.log