import torch
from transformers import T5Tokenizer


class T5TextConditionProcessor:

    def __init__(self, tokens_length, processor_path):
        self.tokens_length = tokens_length
        self.processor = T5Tokenizer.from_pretrained(processor_path)

    def encode(self, text=None, negative_text=None):
        # Tokenize the prompt, truncate to the fixed token budget, then pad
        # input_ids and attention_mask out to exactly `tokens_length`.
        encoded = self.processor(text, max_length=self.tokens_length, truncation=True)
        pad_length = self.tokens_length - len(encoded['input_ids'])
        input_ids = encoded['input_ids'] + [self.processor.pad_token_id] * pad_length
        attention_mask = encoded['attention_mask'] + [0] * pad_length
        condition_model_input = {
            'input_ids': torch.tensor(input_ids, dtype=torch.long),
            'attention_mask': torch.tensor(attention_mask, dtype=torch.long)
        }

        if negative_text is not None:
            # Tokenize the negative prompt and force it to the same unpadded
            # length as the positive prompt: truncate, terminate with EOS, then
            # pad. The positive prompt's attention mask is reused so both
            # inputs share an identical mask layout.
            negative_encoded = self.processor(negative_text, max_length=self.tokens_length, truncation=True)
            negative_input_ids = negative_encoded['input_ids'][:len(encoded['input_ids'])]
            negative_input_ids[-1] = self.processor.eos_token_id
            negative_pad_length = self.tokens_length - len(negative_input_ids)
            negative_input_ids = negative_input_ids + [self.processor.pad_token_id] * negative_pad_length
            negative_attention_mask = encoded['attention_mask'] + [0] * pad_length
            negative_condition_model_input = {
                'input_ids': torch.tensor(negative_input_ids, dtype=torch.long),
                'attention_mask': torch.tensor(negative_attention_mask, dtype=torch.long)
            }
        else:
            negative_condition_model_input = None
        return condition_model_input, negative_condition_model_input
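

# Usage sketch (illustrative, not part of the original file). The token budget
# and tokenizer checkpoint below are assumptions; substitute whatever values
# the surrounding pipeline actually uses.
if __name__ == '__main__':
    processor = T5TextConditionProcessor(tokens_length=128, processor_path='google/flan-ul2')
    cond, negative_cond = processor.encode(
        text='a photo of a red fox in a snowy forest',
        negative_text='low quality, blurry, deformed'
    )
    # Both dictionaries hold fixed-length LongTensors of shape (tokens_length,),
    # ready to be batched and fed to the downstream T5 text encoder.
    print(cond['input_ids'].shape, cond['attention_mask'].shape)
    print(negative_cond['input_ids'].shape, negative_cond['attention_mask'].shape)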