import torch
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer


class SDXLTextEncoder(torch.nn.Module):
    """Wrapper around HuggingFace text encoders for SDXL.

    Creates two text encoders (a CLIPTextModel and a CLIPTextModelWithProjection) that behave like one.

    Args:
        model_name (str): Name of the model whose text encoders to load. Defaults to
            'stabilityai/stable-diffusion-xl-base-1.0'.
        encode_latents_in_fp16 (bool): Whether to run the text encoders in fp16. Defaults to True.
        torch_dtype (torch.dtype, optional): Dtype for the text encoder weights. If None, defaults to
            torch.float16 when ``encode_latents_in_fp16`` is True. Defaults to None.
    """

    def __init__(self, model_name='stabilityai/stable-diffusion-xl-base-1.0', encode_latents_in_fp16=True, torch_dtype=None):
        super().__init__()
        if torch_dtype is None:
            torch_dtype = torch.float16 if encode_latents_in_fp16 else None
        self.text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder='text_encoder', torch_dtype=torch_dtype)
        self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(model_name,
                                                                          subfolder='text_encoder_2',
                                                                          torch_dtype=torch_dtype)

    @property
    def device(self):
        return self.text_encoder.device

    def forward(self, tokenized_text):
        # First text encoder: take the penultimate hidden states as conditioning
        conditioning = self.text_encoder(tokenized_text[0], output_hidden_states=True).hidden_states[-2]
        # Second text encoder: projected pooled output plus penultimate hidden states
        text_encoder_2_out = self.text_encoder_2(tokenized_text[1], output_hidden_states=True)
        pooled_conditioning = text_encoder_2_out[0]  # (batch_size, 1280)
        conditioning_2 = text_encoder_2_out.hidden_states[-2]  # (batch_size, 77, 1280)
        # Concatenate both encoders' hidden states along the channel dimension
        conditioning = torch.concat([conditioning, conditioning_2], dim=-1)
        return conditioning, pooled_conditioning
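
# A minimal usage sketch for SDXLTextEncoder (illustrative, not part of the module;
# assumes two tensors of token ids, one per tokenizer, are already available):
#
#   encoder = SDXLTextEncoder()
#   conditioning, pooled = encoder([input_ids_1, input_ids_2])
#   # conditioning: (batch_size, 77, 2048) -- 768-dim and 1280-dim hidden states concatenated
#   # pooled:       (batch_size, 1280)     -- projected pooled output of the second encoder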


class SDXLTokenizer:
    """Wrapper around HuggingFace tokenizers for SDXL.

    Tokenizes the prompt with two tokenizers and returns the joined output.

    Args:
        model_name (str): Name of the model whose tokenizers to load. Defaults to
            'stabilityai/stable-diffusion-xl-base-1.0'.
    """

    def __init__(self, model_name='stabilityai/stable-diffusion-xl-base-1.0'):
        self.tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder='tokenizer')
        self.tokenizer_2 = CLIPTokenizer.from_pretrained(model_name, subfolder='tokenizer_2')

    def __call__(self, prompt, padding, truncation, return_tensors, max_length=None):
        tokenized_output = self.tokenizer(
            prompt,
            padding=padding,
            max_length=self.tokenizer.model_max_length if max_length is None else max_length,
            truncation=truncation,
            return_tensors=return_tensors)
        tokenized_output_2 = self.tokenizer_2(
            prompt,
            padding=padding,
            max_length=self.tokenizer_2.model_max_length if max_length is None else max_length,
            truncation=truncation,
            return_tensors=return_tensors)
        # Add second tokenizer output to first tokenizer, so each key maps to a pair of outputs
        for key in tokenized_output.keys():
            tokenized_output[key] = [tokenized_output[key], tokenized_output_2[key]]
        return tokenized_output
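

# A minimal end-to-end sketch, assuming the SDXL base checkpoint can be downloaded
# from the Hugging Face hub; the prompt and printed shapes are illustrative only.
# fp16 is disabled here so the sketch also runs on CPU.
if __name__ == '__main__':
    tokenizer = SDXLTokenizer()
    text_encoder = SDXLTextEncoder(encode_latents_in_fp16=False)

    tokenized = tokenizer(['a photo of an astronaut riding a horse'],
                          padding='max_length',
                          truncation=True,
                          return_tensors='pt')
    # Each key holds a [tokenizer_1, tokenizer_2] pair, which is what forward() expects.
    with torch.no_grad():
        conditioning, pooled_conditioning = text_encoder(tokenized['input_ids'])
    print(conditioning.shape)         # torch.Size([1, 77, 2048])
    print(pooled_conditioning.shape)  # torch.Size([1, 1280])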