parokshsaxena commited on
Commit
6b06100
Β·
1 Parent(s): b0054c9

moving enhanced garment net initialization to app.py

Browse files
Files changed (2) hide show
  1. app.py +8 -1
  2. src/tryon_pipeline.py +3 -3
app.py CHANGED
@@ -6,6 +6,7 @@ from PIL import Image
6
  from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
7
  from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
8
  from src.unet_hacked_tryon import UNet2DConditionModel
 
9
  from transformers import (
10
  CLIPImageProcessor,
11
  CLIPVisionModelWithProjection,
@@ -51,6 +52,9 @@ unet = UNet2DConditionModel.from_pretrained(
51
  torch_dtype=torch.float16,
52
  )
53
  unet.requires_grad_(False)
 
 
 
54
  tokenizer_one = AutoTokenizer.from_pretrained(
55
  base_path,
56
  subfolder="tokenizer",
@@ -92,6 +96,7 @@ UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
92
  torch_dtype=torch.float16,
93
  )
94
 
 
95
  parsing_model = Parsing(0)
96
  openpose_model = OpenPose(0)
97
 
@@ -122,6 +127,7 @@ pipe = TryonPipeline.from_pretrained(
122
  torch_dtype=torch.float16,
123
  )
124
  pipe.unet_encoder = UNet_Encoder
 
125
 
126
  # Standard size of shein images
127
  #WIDTH = int(4160/5)
@@ -152,7 +158,8 @@ def start_tryon(human_img_dict,garm_img,garment_des, background_img, is_checked,
152
  openpose_model.preprocessor.body_estimation.model.to(device)
153
  pipe.to(device)
154
  pipe.unet_encoder.to(device)
155
-
 
156
  human_img_orig = human_img_dict["background"].convert("RGB") # ImageEditor
157
  #human_img_orig = human_img_dict.convert("RGB") # Image
158
 
 
6
  from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
7
  from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
8
  from src.unet_hacked_tryon import UNet2DConditionModel
9
+ from src.enhanced_garment_net import EnhancedGarmentNetWithTimestep
10
  from transformers import (
11
  CLIPImageProcessor,
12
  CLIPVisionModelWithProjection,
 
52
  torch_dtype=torch.float16,
53
  )
54
  unet.requires_grad_(False)
55
+
56
+ enhancedGarmentNet = EnhancedGarmentNetWithTimestep(dtype=torch.float16)
57
+
58
  tokenizer_one = AutoTokenizer.from_pretrained(
59
  base_path,
60
  subfolder="tokenizer",
 
96
  torch_dtype=torch.float16,
97
  )
98
 
99
+
100
  parsing_model = Parsing(0)
101
  openpose_model = OpenPose(0)
102
 
 
127
  torch_dtype=torch.float16,
128
  )
129
  pipe.unet_encoder = UNet_Encoder
130
+ pipe.garment_net = enhancedGarmentNet
131
 
132
  # Standard size of shein images
133
  #WIDTH = int(4160/5)
 
158
  openpose_model.preprocessor.body_estimation.model.to(device)
159
  pipe.to(device)
160
  pipe.unet_encoder.to(device)
161
+ pipe.garment_net.to(device)
162
+
163
  human_img_orig = human_img_dict["background"].convert("RGB") # ImageEditor
164
  #human_img_orig = human_img_dict.convert("RGB") # Image
165
 
src/tryon_pipeline.py CHANGED
@@ -57,7 +57,7 @@ from diffusers.utils.torch_utils import randn_tensor
57
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline
58
 
59
  # Commenting out for now
60
- from src.enhanced_garment_net import EnhancedGarmentNetWithTimestep
61
 
62
 
63
 
@@ -401,8 +401,8 @@ class StableDiffusionXLInpaintPipeline(
401
  force_zeros_for_empty_prompt: bool = True,
402
  ):
403
  super().__init__()
404
- #self.garment_net = EnhancedGarmentNetWithTimestep()
405
- self.garment_net = EnhancedGarmentNetWithTimestep().to(device=self._execution_device, dtype=self.unet.dtype)
406
 
407
 
408
 
 
57
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline
58
 
59
  # Commenting out for now
60
+ # from src.enhanced_garment_net import EnhancedGarmentNetWithTimestep
61
 
62
 
63
 
 
401
  force_zeros_for_empty_prompt: bool = True,
402
  ):
403
  super().__init__()
404
+ # This is moved to app.py
405
+ #self.garment_net = EnhancedGarmentNetWithTimestep().to(device=self._execution_device, dtype=self.unet.dtype)
406
 
407
 
408