xxxpo13 commited on
Commit
51e9ef1
·
verified ·
1 Parent(s): f02e267

Update modeling_text_encoder.py

Browse files
Files changed (1) hide show
  1. modeling_text_encoder.py +8 -8
modeling_text_encoder.py CHANGED
@@ -10,7 +10,7 @@ from transformers import (
10
  from typing import Union, List, Optional
11
 
12
  class SD3TextEncoderWithMask(nn.Module):
13
- def __init__(self, model_path, torch_dtype=torch.float16):
14
  super().__init__()
15
 
16
  # Tokenizers for CLIP and T5
@@ -47,10 +47,10 @@ class SD3TextEncoderWithMask(nn.Module):
47
  ).to("cuda")
48
 
49
  if self.text_encoder_3 is None:
50
- # Load the FP8 T5 model (adjust path if needed for specific FP8 model)
51
  self.text_encoder_3 = T5EncoderModel.from_pretrained(
52
- os.path.join(self.model_path, 'text_encoder_3_fp8'), # Use the FP8 version
53
- torch_dtype=torch.float8 # FP8 precision
54
  ).to("cuda")
55
 
56
  def _get_t5_prompt_embeds(
@@ -75,9 +75,9 @@ class SD3TextEncoderWithMask(nn.Module):
75
  text_input_ids = text_inputs.input_ids.to(device)
76
  prompt_attention_mask = text_inputs.attention_mask.to(device)
77
 
78
- # Use the T5 model to generate embeddings in FP8
79
  prompt_embeds = self.text_encoder_3(text_input_ids, attention_mask=prompt_attention_mask)[0]
80
- prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_3.dtype) # Ensure correct dtype
81
 
82
  # Duplicate embeddings for each image generation
83
  batch_size = len(prompt)
@@ -111,7 +111,7 @@ class SD3TextEncoderWithMask(nn.Module):
111
  return_tensors="pt",
112
  )
113
 
114
- text_input_ids = text_inputs.input.ids.to(device)
115
  prompt_embeds = text_encoder(text_input_ids, output_hidden_states=True)[0]
116
 
117
  # Duplicate embeddings for each image generation
@@ -133,7 +133,7 @@ class SD3TextEncoderWithMask(nn.Module):
133
  pooled_prompt_2_embed = self._get_clip_prompt_embeds(prompt, num_images_per_prompt=num_images_per_prompt, device=device, clip_model_index=1)
134
  pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
135
 
136
- # Get T5 embeddings in FP8
137
  prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(prompt, num_images_per_prompt=num_images_per_prompt, device=device)
138
 
139
  return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds
 
10
  from typing import Union, List, Optional
11
 
12
  class SD3TextEncoderWithMask(nn.Module):
13
+ def __init__(self, model_path, torch_dtype=torch.float16): # Default to FP16
14
  super().__init__()
15
 
16
  # Tokenizers for CLIP and T5
 
47
  ).to("cuda")
48
 
49
  if self.text_encoder_3 is None:
50
+ # Load the T5 model in FP16
51
  self.text_encoder_3 = T5EncoderModel.from_pretrained(
52
+ os.path.join(self.model_path, 'text_encoder_3'), # Ensure you're using the correct path
53
+ torch_dtype=self.torch_dtype # Use FP16 precision
54
  ).to("cuda")
55
 
56
  def _get_t5_prompt_embeds(
 
75
  text_input_ids = text_inputs.input_ids.to(device)
76
  prompt_attention_mask = text_inputs.attention_mask.to(device)
77
 
78
+ # Use the T5 model to generate embeddings in FP16
79
  prompt_embeds = self.text_encoder_3(text_input_ids, attention_mask=prompt_attention_mask)[0]
80
+ prompt_embeds = prompt_embeds.to(dtype=self.torch_dtype) # Ensure correct dtype
81
 
82
  # Duplicate embeddings for each image generation
83
  batch_size = len(prompt)
 
111
  return_tensors="pt",
112
  )
113
 
114
+ text_input_ids = text_inputs.input_ids.to(device)
115
  prompt_embeds = text_encoder(text_input_ids, output_hidden_states=True)[0]
116
 
117
  # Duplicate embeddings for each image generation
 
133
  pooled_prompt_2_embed = self._get_clip_prompt_embeds(prompt, num_images_per_prompt=num_images_per_prompt, device=device, clip_model_index=1)
134
  pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
135
 
136
+ # Get T5 embeddings in FP16
137
  prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(prompt, num_images_per_prompt=num_images_per_prompt, device=device)
138
 
139
  return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds