VikramSingh178 committed on
Commit
ee3efed
1 Parent(s): 7723596

Update sdxl_text_to_image.py and clear_memory.py

Browse files
api/routers/sdxl_text_to_image.py CHANGED
@@ -1,27 +1,13 @@
1
- <<<<<<< HEAD:api/routers/sdxl_text_to_image.py
2
- =======
3
- import sys
4
- sys.path.append("../scripts") # Path of the scripts directory
5
- >>>>>>> 846a4a3 (commit):product_diffusion_api/routers/sdxl_text_to_image.py
6
  import config
7
  from fastapi import APIRouter, HTTPException
8
  from typing import List
9
  from diffusers import DiffusionPipeline
10
  import torch
11
  from functools import lru_cache
12
- <<<<<<< HEAD:api/routers/sdxl_text_to_image.py
13
  from scripts.api_utils import accelerator
14
  from models.sdxl_input import InputFormat
15
  from async_batcher.batcher import AsyncBatcher
16
  from scripts.api_utils import pil_to_b64_json, pil_to_s3_json
17
- =======
18
- from s3_manager import S3ManagerService
19
- from PIL import Image
20
- import io
21
-
22
-
23
-
24
- >>>>>>> 846a4a3 (commit):product_diffusion_api/routers/sdxl_text_to_image.py
25
  torch._inductor.config.conv_1x1_as_mm = True
26
  torch._inductor.config.coordinate_descent_tuning = True
27
  torch._inductor.config.epilogue_fusion = False
@@ -34,38 +20,10 @@ device = accelerator()
34
  router = APIRouter()
35
 
36
 
37
- <<<<<<< HEAD:api/routers/sdxl_text_to_image.py
38
- =======
39
-
40
-
41
-
42
- def pil_to_b64_json(image):
43
- image_id = str(uuid.uuid4())
44
- buffered = BytesIO()
45
- image.save(buffered, format="PNG")
46
- b64_image = base64.b64encode(buffered.getvalue()).decode("utf-8")
47
- return {"image_id": image_id, "b64_image": b64_image}
48
-
49
-
50
- def pil_to_s3_json(image: Image.Image,file_name) -> str:
51
- image_id = str(uuid.uuid4())
52
- s3_uploader = S3ManagerService()
53
- image_bytes = io.BytesIO()
54
- image.save(image_bytes, format="PNG")
55
- image_bytes.seek(0)
56
-
57
- unique_file_name = s3_uploader.generate_unique_file_name(file_name)
58
- s3_uploader.upload_file(image_bytes, unique_file_name)
59
- signed_url = s3_uploader.generate_signed_url(
60
- unique_file_name, exp=43200
61
- ) # 12 hours
62
- return {"image_id": image_id, "url": signed_url}
63
- >>>>>>> 846a4a3 (commit):product_diffusion_api/routers/sdxl_text_to_image.py
64
 
65
 
66
  # Load the diffusion pipeline
67
  @lru_cache(maxsize=1)
68
- <<<<<<< HEAD:api/routers/sdxl_text_to_image.py
69
  def load_pipeline(model_name, adapter_name,enable_compile:bool):
70
  """
71
  Load the diffusion pipeline with the specified model and adapter names.
@@ -85,19 +43,6 @@ def load_pipeline(model_name, adapter_name,enable_compile:bool):
85
  if enable_compile is True:
86
  pipe.unet = torch.compile(pipe.unet, mode="max-autotune")
87
  pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune")
88
- =======
89
- def load_pipeline(model_name, adapter_name,adapter_name_2):
90
- pipe = DiffusionPipeline.from_pretrained(model_name, torch_dtype= torch.bfloat16 ).to(
91
- "cuda"
92
- )
93
- pipe.load_lora_weights(adapter_name)
94
- pipe.load_lora_weights(adapter_name_2)
95
- pipe.set_adapters([adapter_name, adapter_name_2], adapter_weights=[0.7, 0.5])
96
- pipe.fuse_lora()
97
- pipe.unload_lora_weights()
98
- pipe.unet.to(memory_format=torch.channels_last)
99
- pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead")
100
- >>>>>>> 846a4a3 (commit):product_diffusion_api/routers/sdxl_text_to_image.py
101
  pipe.fuse_qkv_projections()
102
  return pipe
103
 
 
 
 
 
 
 
1
  import config
2
  from fastapi import APIRouter, HTTPException
3
  from typing import List
4
  from diffusers import DiffusionPipeline
5
  import torch
6
  from functools import lru_cache
 
7
  from scripts.api_utils import accelerator
8
  from models.sdxl_input import InputFormat
9
  from async_batcher.batcher import AsyncBatcher
10
  from scripts.api_utils import pil_to_b64_json, pil_to_s3_json
 
 
 
 
 
 
 
 
11
  torch._inductor.config.conv_1x1_as_mm = True
12
  torch._inductor.config.coordinate_descent_tuning = True
13
  torch._inductor.config.epilogue_fusion = False
 
20
  router = APIRouter()
21
 
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
 
25
  # Load the diffusion pipeline
26
  @lru_cache(maxsize=1)
 
27
  def load_pipeline(model_name, adapter_name,enable_compile:bool):
28
  """
29
  Load the diffusion pipeline with the specified model and adapter names.
 
43
  if enable_compile is True:
44
  pipe.unet = torch.compile(pipe.unet, mode="max-autotune")
45
  pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune")
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  pipe.fuse_qkv_projections()
47
  return pipe
48
 
scripts/clear_memory.py DELETED
@@ -1,18 +0,0 @@
1
- import gc
2
- import torch
3
- from logger import rich_logger as l
4
-
5
- def clear_memory():
6
- """
7
- Clears the memory by collecting garbage and emptying the CUDA cache.
8
-
9
- This function is useful when dealing with memory-intensive operations in Python, especially when using libraries like PyTorch.
10
-
11
- Note:
12
- This function requires the `gc` and `torch` modules to be imported.
13
-
14
- """
15
- gc.collect()
16
- torch.cuda.empty_cache()
17
- l.info("Memory Cleared")
18
-