|
import torch |
|
import torch.nn as nn |
|
import os |
|
from transformers import ( |
|
CLIPTextModelWithProjection, |
|
CLIPTokenizer, |
|
T5EncoderModel, |
|
T5TokenizerFast, |
|
) |
|
from typing import Union, List, Optional |
|
|
|
class SD3TextEncoderWithMask(nn.Module): |
|
def __init__(self, model_path, torch_dtype): |
|
super().__init__() |
|
|
|
|
|
self.device_0 = torch.device('cuda:0') |
|
self.device_1 = torch.device('cuda:1') |
|
|
|
|
|
self.tokenizer = CLIPTokenizer.from_pretrained(os.path.join(model_path, 'tokenizer')) |
|
self.tokenizer_2 = CLIPTokenizer.from_pretrained(os.path.join(model_path, 'tokenizer_2')) |
|
self.tokenizer_3 = T5TokenizerFast.from_pretrained(os.path.join(model_path, 'tokenizer_3')) |
|
|
|
|
|
self.text_encoder = None |
|
self.text_encoder_2 = None |
|
self.text_encoder_3 = None |
|
self.model_path = model_path |
|
self.torch_dtype = torch_dtype |
|
self.tokenizer_max_length = self.tokenizer.model_max_length |
|
|
|
|
|
self._freeze() |
|
|
|
def _freeze(self): |
|
""" Freeze all model parameters to avoid training overhead. """ |
|
for param in self.parameters(): |
|
param.requires_grad = False |
|
|
|
def _load_models_if_needed(self): |
|
""" Load models only if they haven't been loaded already. """ |
|
if self.text_encoder is None: |
|
self.text_encoder = CLIPTextModelWithProjection.from_pretrained( |
|
os.path.join(self.model_path, 'text_encoder'), torch_dtype=self.torch_dtype |
|
).to(self.device_0) |
|
|
|
if self.text_encoder_2 is None: |
|
self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained( |
|
os.path.join(self.model_path, 'text_encoder_2'), torch_dtype=self.torch_dtype |
|
).to(self.device_0) |
|
|
|
if self.text_encoder_3 is None: |
|
self.text_encoder_3 = T5EncoderModel.from_pretrained( |
|
os.path.join(self.model_path, 'text_encoder_3'), torch_dtype=self.torch_dtype |
|
).to(self.device_0) |
|
|
|
def _get_t5_prompt_embeds( |
|
self, |
|
prompt: Union[str, List[str]] = None, |
|
num_images_per_prompt: int = 1, |
|
device: Optional[torch.device] = None, |
|
max_sequence_length: int = 128, |
|
): |
|
""" Get embeddings from T5 model. """ |
|
self._load_models_if_needed() |
|
prompt = [prompt] if isinstance(prompt, str) else prompt |
|
text_inputs = self.tokenizer_3( |
|
prompt, |
|
padding="max_length", |
|
max_length=max_sequence_length, |
|
truncation=True, |
|
add_special_tokens=True, |
|
return_tensors="pt", |
|
) |
|
|
|
text_input_ids = text_inputs.input_ids.to(device) |
|
prompt_attention_mask = text_inputs.attention_mask.to(device) |
|
prompt_embeds = self.text_encoder_3(text_input_ids, attention_mask=prompt_attention_mask)[0] |
|
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_3.dtype, device=device) |
|
|
|
|
|
batch_size = len(prompt) |
|
_, seq_len, _ = prompt_embeds.shape |
|
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1).view(batch_size * num_images_per_prompt, seq_len, -1) |
|
prompt_attention_mask = prompt_attention_mask.view(batch_size, -1).repeat(num_images_per_prompt, 1) |
|
|
|
return prompt_embeds, prompt_attention_mask |
|
|
|
def _get_clip_prompt_embeds( |
|
self, |
|
prompt: Union[str, List[str]], |
|
num_images_per_prompt: int = 1, |
|
device: Optional[torch.device] = None, |
|
clip_model_index: int = 0, |
|
): |
|
""" Get embeddings from CLIP model. """ |
|
self._load_models_if_needed() |
|
|
|
clip_tokenizers = [self.tokenizer, self.tokenizer_2] |
|
clip_text_encoders = [self.text_encoder, self.text_encoder_2] |
|
|
|
tokenizer = clip_tokenizers[clip_model_index] |
|
text_encoder = clip_text_encoders[clip_model_index] |
|
|
|
text_inputs = tokenizer( |
|
prompt, |
|
padding="max_length", |
|
max_length=self.tokenizer_max_length, |
|
truncation=True, |
|
return_tensors="pt", |
|
) |
|
|
|
text_input_ids = text_inputs.input_ids.to(device) |
|
prompt_embeds = text_encoder(text_input_ids, output_hidden_states=True)[0] |
|
|
|
|
|
batch_size = len(prompt) |
|
pooled_prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1).view(batch_size * num_images_per_prompt, -1) |
|
|
|
return pooled_prompt_embeds |
|
|
|
def encode_prompt(self, |
|
prompt, |
|
num_images_per_prompt=1, |
|
device=None |
|
): |
|
""" Encode the prompt using both CLIP and T5 models. """ |
|
prompt = [prompt] if isinstance(prompt, str) else prompt |
|
|
|
|
|
pooled_prompt_embed = self._get_clip_prompt_embeds(prompt, num_images_per_prompt=num_images_per_prompt, device=self.device_0, clip_model_index=0) |
|
pooled_prompt_2_embed = self._get_clip_prompt_embeds(prompt, num_images_per_prompt=num_images_per_prompt, device=self.device_0, clip_model_index=1) |
|
pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1) |
|
|
|
|
|
prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(prompt, num_images_per_prompt=num_images_per_prompt, device=self.device_0) |
|
|
|
return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds |
|
|
|
def forward(self, input_prompts): |
|
""" Forward pass for encoding prompts. """ |
|
with torch.no_grad(): |
|
prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.encode_prompt(input_prompts) |
|
|
|
return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds |
|
|
|
|
|
class OtherModel(nn.Module):
    """A single fully-connected layer placed on a configurable device."""

    def __init__(self, in_features=512, out_features=512, device='cuda:1'):
        """Create the linear layer and move it to ``device``.

        Args:
            in_features: Size of each input sample (default 512).
            out_features: Size of each output sample (default 512).
            device: Device the layer is moved to (default ``'cuda:1'``, the
                value that used to be hard-coded).
        """
        super().__init__()

        # Built on the default device and then moved, exactly as before, so
        # weight initialisation is unchanged for the default arguments.
        self.fc = nn.Linear(in_features, out_features).to(device)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        return self.fc(x)
|
|
|
|
|
# Demo: push one random batch through OtherModel on the second GPU.
# NOTE: these statements are at module level, so they execute on import.

# Build the model, then place it on cuda:1.
other_model = OtherModel()
other_model = other_model.to('cuda:1')

# Sample a (64, 512) batch on the default device, then move it next to the model.
input_data = torch.randn(64, 512)
input_data = input_data.to('cuda:1')

# Run the forward pass and show the result.
output = other_model(input_data)
print(output)
|
|