Spaces:
Running
on
Zero
Running
on
Zero
parokshsaxena
committed on
Commit
·
6b06100
1
Parent(s):
b0054c9
moving enhanced garment net initialization to app.py
Browse files- app.py +8 -1
- src/tryon_pipeline.py +3 -3
app.py
CHANGED
@@ -6,6 +6,7 @@ from PIL import Image
|
|
6 |
from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
|
7 |
from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
|
8 |
from src.unet_hacked_tryon import UNet2DConditionModel
|
|
|
9 |
from transformers import (
|
10 |
CLIPImageProcessor,
|
11 |
CLIPVisionModelWithProjection,
|
@@ -51,6 +52,9 @@ unet = UNet2DConditionModel.from_pretrained(
|
|
51 |
torch_dtype=torch.float16,
|
52 |
)
|
53 |
unet.requires_grad_(False)
|
|
|
|
|
|
|
54 |
tokenizer_one = AutoTokenizer.from_pretrained(
|
55 |
base_path,
|
56 |
subfolder="tokenizer",
|
@@ -92,6 +96,7 @@ UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
|
|
92 |
torch_dtype=torch.float16,
|
93 |
)
|
94 |
|
|
|
95 |
parsing_model = Parsing(0)
|
96 |
openpose_model = OpenPose(0)
|
97 |
|
@@ -122,6 +127,7 @@ pipe = TryonPipeline.from_pretrained(
|
|
122 |
torch_dtype=torch.float16,
|
123 |
)
|
124 |
pipe.unet_encoder = UNet_Encoder
|
|
|
125 |
|
126 |
# Standard size of shein images
|
127 |
#WIDTH = int(4160/5)
|
@@ -152,7 +158,8 @@ def start_tryon(human_img_dict,garm_img,garment_des, background_img, is_checked,
|
|
152 |
openpose_model.preprocessor.body_estimation.model.to(device)
|
153 |
pipe.to(device)
|
154 |
pipe.unet_encoder.to(device)
|
155 |
-
|
|
|
156 |
human_img_orig = human_img_dict["background"].convert("RGB") # ImageEditor
|
157 |
#human_img_orig = human_img_dict.convert("RGB") # Image
|
158 |
|
|
|
6 |
from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
|
7 |
from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
|
8 |
from src.unet_hacked_tryon import UNet2DConditionModel
|
9 |
+
from src.enhanced_garment_net import EnhancedGarmentNetWithTimestep
|
10 |
from transformers import (
|
11 |
CLIPImageProcessor,
|
12 |
CLIPVisionModelWithProjection,
|
|
|
52 |
torch_dtype=torch.float16,
|
53 |
)
|
54 |
unet.requires_grad_(False)
|
55 |
+
|
56 |
+
enhancedGarmentNet = EnhancedGarmentNetWithTimestep(dtype=torch.float16)
|
57 |
+
|
58 |
tokenizer_one = AutoTokenizer.from_pretrained(
|
59 |
base_path,
|
60 |
subfolder="tokenizer",
|
|
|
96 |
torch_dtype=torch.float16,
|
97 |
)
|
98 |
|
99 |
+
|
100 |
parsing_model = Parsing(0)
|
101 |
openpose_model = OpenPose(0)
|
102 |
|
|
|
127 |
torch_dtype=torch.float16,
|
128 |
)
|
129 |
pipe.unet_encoder = UNet_Encoder
|
130 |
+
pipe.garment_net = enhancedGarmentNet
|
131 |
|
132 |
# Standard size of shein images
|
133 |
#WIDTH = int(4160/5)
|
|
|
158 |
openpose_model.preprocessor.body_estimation.model.to(device)
|
159 |
pipe.to(device)
|
160 |
pipe.unet_encoder.to(device)
|
161 |
+
pipe.garment_net.to(device)
|
162 |
+
|
163 |
human_img_orig = human_img_dict["background"].convert("RGB") # ImageEditor
|
164 |
#human_img_orig = human_img_dict.convert("RGB") # Image
|
165 |
|
src/tryon_pipeline.py
CHANGED
@@ -57,7 +57,7 @@ from diffusers.utils.torch_utils import randn_tensor
|
|
57 |
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
58 |
|
59 |
# Commenting out for now
|
60 |
-
from src.enhanced_garment_net import EnhancedGarmentNetWithTimestep
|
61 |
|
62 |
|
63 |
|
@@ -401,8 +401,8 @@ class StableDiffusionXLInpaintPipeline(
|
|
401 |
force_zeros_for_empty_prompt: bool = True,
|
402 |
):
|
403 |
super().__init__()
|
404 |
-
#
|
405 |
-
self.garment_net = EnhancedGarmentNetWithTimestep().to(device=self._execution_device, dtype=self.unet.dtype)
|
406 |
|
407 |
|
408 |
|
|
|
57 |
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
58 |
|
59 |
# Commenting out for now
|
60 |
+
# from src.enhanced_garment_net import EnhancedGarmentNetWithTimestep
|
61 |
|
62 |
|
63 |
|
|
|
401 |
force_zeros_for_empty_prompt: bool = True,
|
402 |
):
|
403 |
super().__init__()
|
404 |
+
# This is moved to app.py
|
405 |
+
#self.garment_net = EnhancedGarmentNetWithTimestep().to(device=self._execution_device, dtype=self.unet.dtype)
|
406 |
|
407 |
|
408 |
|