levihsu commited on
Commit
7bcec6f
β€’
1 Parent(s): e629020

Update ootd/inference_ootd_hd.py

Browse files
Files changed (1) hide show
  1. ootd/inference_ootd_hd.py +21 -24
ootd/inference_ootd_hd.py CHANGED
@@ -29,34 +29,35 @@ VAE_PATH = "levihsu/ootd"
29
  UNET_PATH = "levihsu/ootd"
30
  MODEL_PATH = "levihsu/ootd"
31
 
32
- # ootd_hd/checkpoint-36000/
33
-
34
  class OOTDiffusionHD:
35
 
36
  def __init__(self, gpu_id):
37
  self.gpu_id = 'cuda:' + str(gpu_id)
38
 
39
- # vae = AutoencoderKL.from_pretrained(
40
- # VAE_PATH,
41
- # subfolder="vae",
42
- # torch_dtype=torch.float16,
43
- # )
44
-
45
- # unet_garm = UNetGarm2DConditionModel.from_pretrained(
46
- # UNET_PATH,
47
- # subfolder="unet_garm",
48
- # torch_dtype=torch.float16,
49
- # use_safetensors=True,
50
- # )
51
- # unet_vton = UNetVton2DConditionModel.from_pretrained(
52
- # UNET_PATH,
53
- # subfolder="unet_vton",
54
- # torch_dtype=torch.float16,
55
- # use_safetensors=True,
56
- # )
57
 
58
  self.pipe = OotdPipeline.from_pretrained(
59
  MODEL_PATH,
 
 
 
60
  torch_dtype=torch.float16,
61
  variant="fp16",
62
  use_safetensors=True,
@@ -64,10 +65,6 @@ class OOTDiffusionHD:
64
  requires_safety_checker=False,
65
  ).to(self.gpu_id)
66
 
67
- # vae=vae,
68
- # unet_garm=unet_garm,
69
- # unet_vton=unet_vton,
70
-
71
  self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
72
 
73
  self.auto_processor = AutoProcessor.from_pretrained(VIT_PATH)
 
29
  UNET_PATH = "levihsu/ootd"
30
  MODEL_PATH = "levihsu/ootd"
31
 
 
 
32
  class OOTDiffusionHD:
33
 
34
  def __init__(self, gpu_id):
35
  self.gpu_id = 'cuda:' + str(gpu_id)
36
 
37
+ vae = AutoencoderKL.from_pretrained(
38
+ VAE_PATH,
39
+ subfolder="vae",
40
+ torch_dtype=torch.float16,
41
+ )
42
+
43
+ unet_garm = UNetGarm2DConditionModel.from_pretrained(
44
+ UNET_PATH,
45
+ subfolder="ootd_hd/checkpoint-36000/unet_garm",
46
+ torch_dtype=torch.float16,
47
+ use_safetensors=True,
48
+ )
49
+ unet_vton = UNetVton2DConditionModel.from_pretrained(
50
+ UNET_PATH,
51
+ subfolder="ootd_hd/checkpoint-36000/unet_vton",
52
+ torch_dtype=torch.float16,
53
+ use_safetensors=True,
54
+ )
55
 
56
  self.pipe = OotdPipeline.from_pretrained(
57
  MODEL_PATH,
58
+ vae=vae,
59
+ unet_garm=unet_garm,
60
+ unet_vton=unet_vton,
61
  torch_dtype=torch.float16,
62
  variant="fp16",
63
  use_safetensors=True,
 
65
  requires_safety_checker=False,
66
  ).to(self.gpu_id)
67
 
 
 
 
 
68
  self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
69
 
70
  self.auto_processor = AutoProcessor.from_pretrained(VIT_PATH)