Spaces:
Running
Running
VikramSingh178
committed on
Commit
•
ee3efed
1
Parent(s):
7723596
Update sdxl_text_to_image.py and clear_memory.py
Browse files
- api/routers/sdxl_text_to_image.py +0 -55
- scripts/clear_memory.py +0 -18
api/routers/sdxl_text_to_image.py
CHANGED
@@ -1,27 +1,13 @@
|
|
1 |
-
<<<<<<< HEAD:api/routers/sdxl_text_to_image.py
|
2 |
-
=======
|
3 |
-
import sys
|
4 |
-
sys.path.append("../scripts") # Path of the scripts directory
|
5 |
-
>>>>>>> 846a4a3 (commit):product_diffusion_api/routers/sdxl_text_to_image.py
|
6 |
import config
|
7 |
from fastapi import APIRouter, HTTPException
|
8 |
from typing import List
|
9 |
from diffusers import DiffusionPipeline
|
10 |
import torch
|
11 |
from functools import lru_cache
|
12 |
-
<<<<<<< HEAD:api/routers/sdxl_text_to_image.py
|
13 |
from scripts.api_utils import accelerator
|
14 |
from models.sdxl_input import InputFormat
|
15 |
from async_batcher.batcher import AsyncBatcher
|
16 |
from scripts.api_utils import pil_to_b64_json, pil_to_s3_json
|
17 |
-
=======
|
18 |
-
from s3_manager import S3ManagerService
|
19 |
-
from PIL import Image
|
20 |
-
import io
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
>>>>>>> 846a4a3 (commit):product_diffusion_api/routers/sdxl_text_to_image.py
|
25 |
torch._inductor.config.conv_1x1_as_mm = True
|
26 |
torch._inductor.config.coordinate_descent_tuning = True
|
27 |
torch._inductor.config.epilogue_fusion = False
|
@@ -34,38 +20,10 @@ device = accelerator()
|
|
34 |
router = APIRouter()
|
35 |
|
36 |
|
37 |
-
<<<<<<< HEAD:api/routers/sdxl_text_to_image.py
|
38 |
-
=======
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
def pil_to_b64_json(image):
|
43 |
-
image_id = str(uuid.uuid4())
|
44 |
-
buffered = BytesIO()
|
45 |
-
image.save(buffered, format="PNG")
|
46 |
-
b64_image = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
47 |
-
return {"image_id": image_id, "b64_image": b64_image}
|
48 |
-
|
49 |
-
|
50 |
-
def pil_to_s3_json(image: Image.Image,file_name) -> str:
|
51 |
-
image_id = str(uuid.uuid4())
|
52 |
-
s3_uploader = S3ManagerService()
|
53 |
-
image_bytes = io.BytesIO()
|
54 |
-
image.save(image_bytes, format="PNG")
|
55 |
-
image_bytes.seek(0)
|
56 |
-
|
57 |
-
unique_file_name = s3_uploader.generate_unique_file_name(file_name)
|
58 |
-
s3_uploader.upload_file(image_bytes, unique_file_name)
|
59 |
-
signed_url = s3_uploader.generate_signed_url(
|
60 |
-
unique_file_name, exp=43200
|
61 |
-
) # 12 hours
|
62 |
-
return {"image_id": image_id, "url": signed_url}
|
63 |
-
>>>>>>> 846a4a3 (commit):product_diffusion_api/routers/sdxl_text_to_image.py
|
64 |
|
65 |
|
66 |
# Load the diffusion pipeline
|
67 |
@lru_cache(maxsize=1)
|
68 |
-
<<<<<<< HEAD:api/routers/sdxl_text_to_image.py
|
69 |
def load_pipeline(model_name, adapter_name,enable_compile:bool):
|
70 |
"""
|
71 |
Load the diffusion pipeline with the specified model and adapter names.
|
@@ -85,19 +43,6 @@ def load_pipeline(model_name, adapter_name,enable_compile:bool):
|
|
85 |
if enable_compile is True:
|
86 |
pipe.unet = torch.compile(pipe.unet, mode="max-autotune")
|
87 |
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune")
|
88 |
-
=======
|
89 |
-
def load_pipeline(model_name, adapter_name,adapter_name_2):
|
90 |
-
pipe = DiffusionPipeline.from_pretrained(model_name, torch_dtype= torch.bfloat16 ).to(
|
91 |
-
"cuda"
|
92 |
-
)
|
93 |
-
pipe.load_lora_weights(adapter_name)
|
94 |
-
pipe.load_lora_weights(adapter_name_2)
|
95 |
-
pipe.set_adapters([adapter_name, adapter_name_2], adapter_weights=[0.7, 0.5])
|
96 |
-
pipe.fuse_lora()
|
97 |
-
pipe.unload_lora_weights()
|
98 |
-
pipe.unet.to(memory_format=torch.channels_last)
|
99 |
-
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead")
|
100 |
-
>>>>>>> 846a4a3 (commit):product_diffusion_api/routers/sdxl_text_to_image.py
|
101 |
pipe.fuse_qkv_projections()
|
102 |
return pipe
|
103 |
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import config
|
2 |
from fastapi import APIRouter, HTTPException
|
3 |
from typing import List
|
4 |
from diffusers import DiffusionPipeline
|
5 |
import torch
|
6 |
from functools import lru_cache
|
|
|
7 |
from scripts.api_utils import accelerator
|
8 |
from models.sdxl_input import InputFormat
|
9 |
from async_batcher.batcher import AsyncBatcher
|
10 |
from scripts.api_utils import pil_to_b64_json, pil_to_s3_json
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
torch._inductor.config.conv_1x1_as_mm = True
|
12 |
torch._inductor.config.coordinate_descent_tuning = True
|
13 |
torch._inductor.config.epilogue_fusion = False
|
|
|
20 |
router = APIRouter()
|
21 |
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
|
25 |
# Load the diffusion pipeline
|
26 |
@lru_cache(maxsize=1)
|
|
|
27 |
def load_pipeline(model_name, adapter_name,enable_compile:bool):
|
28 |
"""
|
29 |
Load the diffusion pipeline with the specified model and adapter names.
|
|
|
43 |
if enable_compile is True:
|
44 |
pipe.unet = torch.compile(pipe.unet, mode="max-autotune")
|
45 |
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
pipe.fuse_qkv_projections()
|
47 |
return pipe
|
48 |
|
scripts/clear_memory.py
DELETED
@@ -1,18 +0,0 @@
|
|
1 |
-
import gc
|
2 |
-
import torch
|
3 |
-
from logger import rich_logger as l
|
4 |
-
|
5 |
-
def clear_memory():
|
6 |
-
"""
|
7 |
-
Clears the memory by collecting garbage and emptying the CUDA cache.
|
8 |
-
|
9 |
-
This function is useful when dealing with memory-intensive operations in Python, especially when using libraries like PyTorch.
|
10 |
-
|
11 |
-
Note:
|
12 |
-
This function requires the `gc` and `torch` modules to be imported.
|
13 |
-
|
14 |
-
"""
|
15 |
-
gc.collect()
|
16 |
-
torch.cuda.empty_cache()
|
17 |
-
l.info("Memory Cleared")
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|