xxxpo13 commited on
Commit
1b12be6
·
verified ·
1 Parent(s): d378367

Update modeling_text_encoder.py

Browse files
Files changed (1) hide show
  1. modeling_text_encoder.py +3 -3
modeling_text_encoder.py CHANGED
@@ -43,17 +43,17 @@ class SD3TextEncoderWithMask(nn.Module):
43
  if self.text_encoder is None:
44
  self.text_encoder = CLIPTextModelWithProjection.from_pretrained(
45
  os.path.join(self.model_path, 'text_encoder'), torch_dtype=self.torch_dtype
46
- ).to(self.device_1) # Move to GPU 0
47
 
48
  if self.text_encoder_2 is None:
49
  self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
50
  os.path.join(self.model_path, 'text_encoder_2'), torch_dtype=self.torch_dtype
51
- ).to(self.device_1) # Move to GPU 0
52
 
53
  if self.text_encoder_3 is None:
54
  self.text_encoder_3 = T5EncoderModel.from_pretrained(
55
  os.path.join(self.model_path, 'text_encoder_3'), torch_dtype=self.torch_dtype
56
- ).to(self.device_1) # Move to GPU 0
57
 
58
  def _get_t5_prompt_embeds(
59
  self,
 
43
  if self.text_encoder is None:
44
  self.text_encoder = CLIPTextModelWithProjection.from_pretrained(
45
  os.path.join(self.model_path, 'text_encoder'), torch_dtype=self.torch_dtype
46
+ ).to(self.device_0) # Move to GPU 0
47
 
48
  if self.text_encoder_2 is None:
49
  self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
50
  os.path.join(self.model_path, 'text_encoder_2'), torch_dtype=self.torch_dtype
51
+ ).to(self.device_0) # Move to GPU 0
52
 
53
  if self.text_encoder_3 is None:
54
  self.text_encoder_3 = T5EncoderModel.from_pretrained(
55
  os.path.join(self.model_path, 'text_encoder_3'), torch_dtype=self.torch_dtype
56
+ ).to(self.device_0) # Move to GPU 0
57
 
58
  def _get_t5_prompt_embeds(
59
  self,