# Pyramid_Flow / modeling_text_encoder.py
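"""
Text-encoder wrapper for Pyramid Flow / SD3-style pipelines: two CLIP text encoders
plus a T5 encoder, lazily loaded onto GPU 0. encode_prompt() returns the T5 prompt
embeddings, the T5 attention mask, and the concatenated pooled CLIP embeddings.
"""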
import torch
import torch.nn as nn
import os
from transformers import (
CLIPTextModelWithProjection,
CLIPTokenizer,
T5EncoderModel,
T5TokenizerFast,
)
from typing import Union, List, Optional
class SD3TextEncoderWithMask(nn.Module):
def __init__(self, model_path, torch_dtype):
super().__init__()
# Define the devices for each GPU
        self.device_0 = torch.device('cuda:0')  # GPU 0 hosts the text encoders
        self.device_1 = torch.device('cuda:1')  # GPU 1 is reserved for other components; it is not used inside this class
# Tokenizers for CLIP and T5
self.tokenizer = CLIPTokenizer.from_pretrained(os.path.join(model_path, 'tokenizer'))
self.tokenizer_2 = CLIPTokenizer.from_pretrained(os.path.join(model_path, 'tokenizer_2'))
self.tokenizer_3 = T5TokenizerFast.from_pretrained(os.path.join(model_path, 'tokenizer_3'))
# Lazy loading of models
self.text_encoder = None
self.text_encoder_2 = None
self.text_encoder_3 = None
self.model_path = model_path
self.torch_dtype = torch_dtype
self.tokenizer_max_length = self.tokenizer.model_max_length
# Freeze parameters to avoid training overhead
self._freeze()
def _freeze(self):
""" Freeze all model parameters to avoid training overhead. """
for param in self.parameters():
param.requires_grad = False
    def _load_models_if_needed(self):
        """ Load models only if they haven't been loaded already. """
        if self.text_encoder is None:
            self.text_encoder = CLIPTextModelWithProjection.from_pretrained(
                os.path.join(self.model_path, 'text_encoder'), torch_dtype=self.torch_dtype
            ).to(self.device_0)  # Move to GPU 0
        if self.text_encoder_2 is None:
            self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
                os.path.join(self.model_path, 'text_encoder_2'), torch_dtype=self.torch_dtype
            ).to(self.device_0)  # Move to GPU 0
        if self.text_encoder_3 is None:
            self.text_encoder_3 = T5EncoderModel.from_pretrained(
                os.path.join(self.model_path, 'text_encoder_3'), torch_dtype=self.torch_dtype
            ).to(self.device_0)  # Move to GPU 0
        # _freeze() in __init__ runs before these submodules exist, so freeze the
        # newly loaded encoders here as well to keep them out of any training graph.
        self._freeze()
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_images_per_prompt: int = 1,
device: Optional[torch.device] = None,
max_sequence_length: int = 128,
):
""" Get embeddings from T5 model. """
self._load_models_if_needed() # Lazy loading
prompt = [prompt] if isinstance(prompt, str) else prompt
text_inputs = self.tokenizer_3(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(device)
prompt_attention_mask = text_inputs.attention_mask.to(device)
prompt_embeds = self.text_encoder_3(text_input_ids, attention_mask=prompt_attention_mask)[0]
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_3.dtype, device=device)
# Duplicate embeddings for each image generation
batch_size = len(prompt)
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1).view(batch_size * num_images_per_prompt, seq_len, -1)
prompt_attention_mask = prompt_attention_mask.view(batch_size, -1).repeat(num_images_per_prompt, 1)
return prompt_embeds, prompt_attention_mask
def _get_clip_prompt_embeds(
self,
prompt: Union[str, List[str]],
num_images_per_prompt: int = 1,
device: Optional[torch.device] = None,
clip_model_index: int = 0,
):
""" Get embeddings from CLIP model. """
self._load_models_if_needed() # Lazy loading
clip_tokenizers = [self.tokenizer, self.tokenizer_2]
clip_text_encoders = [self.text_encoder, self.text_encoder_2]
tokenizer = clip_tokenizers[clip_model_index]
text_encoder = clip_text_encoders[clip_model_index]
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(device)
        # CLIPTextModelWithProjection returns the pooled, projected text embedding as its first output
        prompt_embeds = text_encoder(text_input_ids, output_hidden_states=True)[0]
# Duplicate embeddings for each image generation
batch_size = len(prompt)
pooled_prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1).view(batch_size * num_images_per_prompt, -1)
return pooled_prompt_embeds
def encode_prompt(self,
prompt,
num_images_per_prompt=1,
device=None
):
""" Encode the prompt using both CLIP and T5 models. """
prompt = [prompt] if isinstance(prompt, str) else prompt
# Get embeddings from both CLIP models (on GPU 0)
pooled_prompt_embed = self._get_clip_prompt_embeds(prompt, num_images_per_prompt=num_images_per_prompt, device=self.device_0, clip_model_index=0)
pooled_prompt_2_embed = self._get_clip_prompt_embeds(prompt, num_images_per_prompt=num_images_per_prompt, device=self.device_0, clip_model_index=1)
pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
# Get T5 embeddings (on GPU 0)
prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(prompt, num_images_per_prompt=num_images_per_prompt, device=self.device_0)
return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds
def forward(self, input_prompts):
""" Forward pass for encoding prompts. """
with torch.no_grad():
prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.encode_prompt(input_prompts)
return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds
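# --- Usage sketch (illustrative, not from the original file) ----------------------
# A minimal example of how the encoder above might be driven. The checkpoint path is
# a placeholder; it is assumed to contain the tokenizer*/text_encoder* subfolders the
# class loads via os.path.join, and torch.bfloat16 is an assumed dtype.
def _example_encode_prompts():
    encoder = SD3TextEncoderWithMask(
        model_path="/path/to/pyramid_flow_checkpoint",  # placeholder path (assumption)
        torch_dtype=torch.bfloat16,                     # assumed dtype
    )
    prompts = ["a cinematic shot of a red fox in the snow"]
    # forward() runs under torch.no_grad() and returns the T5 embeddings, the T5
    # attention mask, and the concatenated pooled CLIP embeddings, all on GPU 0.
    prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = encoder(prompts)
    print(prompt_embeds.shape, prompt_attention_mask.shape, pooled_prompt_embeds.shape)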
# Example code for using GPU 1 for other parts of the model
class OtherModel(nn.Module):
def __init__(self):
super(OtherModel, self).__init__()
# Define your model layers
self.fc = nn.Linear(512, 512).to('cuda:1') # Example layer on GPU 1
def forward(self, x):
return self.fc(x)
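# --- Cross-GPU hand-off sketch (illustrative, not from the original file) ---------
# The text encoders live on cuda:0, so their outputs must be copied to cuda:1 before
# any module placed on GPU 1 (such as OtherModel above) can consume them. The helper
# name below is hypothetical.
def _example_move_embeddings_to_gpu1(encoder, prompts):
    prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = encoder(prompts)
    # Copy from GPU 0 to GPU 1; shapes and dtypes are unchanged.
    prompt_embeds = prompt_embeds.to('cuda:1', non_blocking=True)
    prompt_attention_mask = prompt_attention_mask.to('cuda:1', non_blocking=True)
    pooled_prompt_embeds = pooled_prompt_embeds.to('cuda:1', non_blocking=True)
    return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds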
# In the main script or generation process, use GPU 1 for other tasks. The demo below
# is guarded so it only runs when this file is executed directly, not on import.
if __name__ == "__main__":
    other_model = OtherModel().to('cuda:1')          # Load on GPU 1
    input_data = torch.randn(64, 512).to('cuda:1')   # Move input data to GPU 1
    # Perform forward pass on GPU 1
    output = other_model(input_data)
    print(output)