Update modeling_text_encoder.py
Browse files- modeling_text_encoder.py +3 -3
modeling_text_encoder.py
CHANGED
@@ -43,17 +43,17 @@ class SD3TextEncoderWithMask(nn.Module):
|
|
43 |
if self.text_encoder is None:
|
44 |
self.text_encoder = CLIPTextModelWithProjection.from_pretrained(
|
45 |
os.path.join(self.model_path, 'text_encoder'), torch_dtype=self.torch_dtype
|
46 |
-
).to(self.
|
47 |
|
48 |
if self.text_encoder_2 is None:
|
49 |
self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
|
50 |
os.path.join(self.model_path, 'text_encoder_2'), torch_dtype=self.torch_dtype
|
51 |
-
).to(self.
|
52 |
|
53 |
if self.text_encoder_3 is None:
|
54 |
self.text_encoder_3 = T5EncoderModel.from_pretrained(
|
55 |
os.path.join(self.model_path, 'text_encoder_3'), torch_dtype=self.torch_dtype
|
56 |
-
).to(self.
|
57 |
|
58 |
def _get_t5_prompt_embeds(
|
59 |
self,
|
|
|
43 |
if self.text_encoder is None:
|
44 |
self.text_encoder = CLIPTextModelWithProjection.from_pretrained(
|
45 |
os.path.join(self.model_path, 'text_encoder'), torch_dtype=self.torch_dtype
|
46 |
+
).to(self.device_0) # Move to GPU 0
|
47 |
|
48 |
if self.text_encoder_2 is None:
|
49 |
self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
|
50 |
os.path.join(self.model_path, 'text_encoder_2'), torch_dtype=self.torch_dtype
|
51 |
+
).to(self.device_0) # Move to GPU 0
|
52 |
|
53 |
if self.text_encoder_3 is None:
|
54 |
self.text_encoder_3 = T5EncoderModel.from_pretrained(
|
55 |
os.path.join(self.model_path, 'text_encoder_3'), torch_dtype=self.torch_dtype
|
56 |
+
).to(self.device_0) # Move to GPU 0
|
57 |
|
58 |
def _get_t5_prompt_embeds(
|
59 |
self,
|