levihsu commited on
Commit
067863e
·
verified ·
1 Parent(s): a0690fd

Update ootd/inference_ootd_dc.py

Browse files
Files changed (1) hide show
  1. ootd/inference_ootd_dc.py +7 -7
ootd/inference_ootd_dc.py CHANGED
@@ -32,7 +32,7 @@ MODEL_PATH = "./checkpoints/ootd"
32
  class OOTDiffusionDC:
33
 
34
  def __init__(self, gpu_id):
35
- self.gpu_id = 'cuda:' + str(gpu_id)
36
 
37
  vae = AutoencoderKL.from_pretrained(
38
  VAE_PATH,
@@ -63,12 +63,12 @@ class OOTDiffusionDC:
63
  use_safetensors=True,
64
  safety_checker=None,
65
  requires_safety_checker=False,
66
- ).to(self.gpu_id)
67
 
68
  self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
69
 
70
  self.auto_processor = AutoProcessor.from_pretrained(VIT_PATH)
71
- self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(VIT_PATH).to(self.gpu_id)
72
 
73
  self.tokenizer = CLIPTokenizer.from_pretrained(
74
  MODEL_PATH,
@@ -77,7 +77,7 @@ class OOTDiffusionDC:
77
  self.text_encoder = CLIPTextModel.from_pretrained(
78
  MODEL_PATH,
79
  subfolder="text_encoder",
80
- ).to(self.gpu_id)
81
 
82
 
83
  def tokenize_captions(self, captions, max_length):
@@ -106,14 +106,14 @@ class OOTDiffusionDC:
106
  generator = torch.manual_seed(seed)
107
 
108
  with torch.no_grad():
109
- prompt_image = self.auto_processor(images=image_garm, return_tensors="pt").to(self.gpu_id)
110
  prompt_image = self.image_encoder(prompt_image.data['pixel_values']).image_embeds
111
  prompt_image = prompt_image.unsqueeze(1)
112
  if model_type == 'hd':
113
- prompt_embeds = self.text_encoder(self.tokenize_captions([""], 2).to(self.gpu_id))[0]
114
  prompt_embeds[:, 1:] = prompt_image[:]
115
  elif model_type == 'dc':
116
- prompt_embeds = self.text_encoder(self.tokenize_captions([category], 3).to(self.gpu_id))[0]
117
  prompt_embeds = torch.cat([prompt_embeds, prompt_image], dim=1)
118
  else:
119
  raise ValueError("model_type must be \'hd\' or \'dc\'!")
 
32
  class OOTDiffusionDC:
33
 
34
  def __init__(self, gpu_id):
35
+ # self.gpu_id = 'cuda:' + str(gpu_id)
36
 
37
  vae = AutoencoderKL.from_pretrained(
38
  VAE_PATH,
 
63
  use_safetensors=True,
64
  safety_checker=None,
65
  requires_safety_checker=False,
66
+ )#.to(self.gpu_id)
67
 
68
  self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
69
 
70
  self.auto_processor = AutoProcessor.from_pretrained(VIT_PATH)
71
+ self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(VIT_PATH)#.to(self.gpu_id)
72
 
73
  self.tokenizer = CLIPTokenizer.from_pretrained(
74
  MODEL_PATH,
 
77
  self.text_encoder = CLIPTextModel.from_pretrained(
78
  MODEL_PATH,
79
  subfolder="text_encoder",
80
+ )#.to(self.gpu_id)
81
 
82
 
83
  def tokenize_captions(self, captions, max_length):
 
106
  generator = torch.manual_seed(seed)
107
 
108
  with torch.no_grad():
109
+ prompt_image = self.auto_processor(images=image_garm, return_tensors="pt").to('cuda')
110
  prompt_image = self.image_encoder(prompt_image.data['pixel_values']).image_embeds
111
  prompt_image = prompt_image.unsqueeze(1)
112
  if model_type == 'hd':
113
+ prompt_embeds = self.text_encoder(self.tokenize_captions([""], 2).to('cuda'))[0]
114
  prompt_embeds[:, 1:] = prompt_image[:]
115
  elif model_type == 'dc':
116
+ prompt_embeds = self.text_encoder(self.tokenize_captions([category], 3).to('cuda'))[0]
117
  prompt_embeds = torch.cat([prompt_embeds, prompt_image], dim=1)
118
  else:
119
  raise ValueError("model_type must be \'hd\' or \'dc\'!")