import torch
from transformers import T5Tokenizer


class T5TextConditionProcessor:
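    """Tokenizes a prompt (and an optional negative prompt) into fixed-length
    input ids and attention masks for a T5 text encoder."""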

    def __init__(self, tokens_length, processor_path):
        self.tokens_length = tokens_length
        self.processor = T5Tokenizer.from_pretrained(processor_path)

    def encode(self, text=None, negative_text=None):
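        # Tokenize the prompt, truncating to the token budget, then right-pad
        # the ids and mask so the sequence is exactly `tokens_length` long.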
        encoded = self.processor(text, max_length=self.tokens_length, truncation=True)
        pad_length = self.tokens_length - len(encoded['input_ids'])
        input_ids = encoded['input_ids'] + [self.processor.pad_token_id] * pad_length
        attention_mask = encoded['attention_mask'] + [0] * pad_length
        condition_model_input = {
            'input_ids': torch.tensor(input_ids, dtype=torch.long),
            'attention_mask': torch.tensor(attention_mask, dtype=torch.long)
        }

        if negative_text is not None:
            negative_encoded = self.processor(negative_text, max_length=self.tokens_length, truncation=True)
            # Clip the negative prompt so it is never longer than the positive
            # prompt, and re-terminate the clipped sequence with EOS.
            negative_input_ids = negative_encoded['input_ids'][:len(encoded['input_ids'])]
            negative_input_ids[-1] = self.processor.eos_token_id
            negative_pad_length = self.tokens_length - len(negative_input_ids)
            negative_input_ids = negative_input_ids + [self.processor.pad_token_id] * negative_pad_length
            # Pad the negative prompt's own attention mask with its own pad
            # length, so positions that hold padding tokens stay masked out.
            negative_attention_mask = negative_encoded['attention_mask'][:len(encoded['input_ids'])]
            negative_attention_mask = negative_attention_mask + [0] * negative_pad_length
            negative_condition_model_input = {
                'input_ids': torch.tensor(negative_input_ids, dtype=torch.long),
                'attention_mask': torch.tensor(negative_attention_mask, dtype=torch.long)
            }
        else:
            negative_condition_model_input = None
        return condition_model_input, negative_condition_model_input
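

# A minimal usage sketch, not part of the original module. The token budget
# and checkpoint name below are assumptions for illustration; any T5 tokenizer
# checkpoint with pad and EOS tokens works the same way.
if __name__ == '__main__':
    processor = T5TextConditionProcessor(tokens_length=128, processor_path='t5-base')
    cond, neg_cond = processor.encode(text='a photo of a cat', negative_text='blurry')
    # Every returned tensor is padded/truncated to exactly `tokens_length`.
    print(cond['input_ids'].shape, cond['attention_mask'].shape)
    print(neg_cond['input_ids'].shape, neg_cond['attention_mask'].shape)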