ysharma HF staff commited on
Commit
bca52a3
·
1 Parent(s): 975f6d0
Files changed (1) hide show
  1. modules.py +131 -0
modules.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from functools import partial
4
+
5
+ from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
6
+
7
+
8
+ class AbstractEncoder(nn.Module):
9
+ def __init__(self):
10
+ super().__init__()
11
+
12
+ def encode(self, *args, **kwargs):
13
+ raise NotImplementedError
14
+
15
+
16
+
17
+ class ClassEmbedder(nn.Module):
18
+ def __init__(self, embed_dim, n_classes=1000, key='class'):
19
+ super().__init__()
20
+ self.key = key
21
+ self.embedding = nn.Embedding(n_classes, embed_dim)
22
+
23
+ def forward(self, batch, key=None):
24
+ if key is None:
25
+ key = self.key
26
+ # this is for use in crossattn
27
+ c = batch[key][:, None]
28
+ c = self.embedding(c)
29
+ return c
30
+
31
+
32
+ class TransformerEmbedder(AbstractEncoder):
33
+ """Some transformer encoder layers"""
34
+ def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
35
+ super().__init__()
36
+ self.device = device
37
+ self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
38
+ attn_layers=Encoder(dim=n_embed, depth=n_layer))
39
+
40
+ def forward(self, tokens):
41
+ tokens = tokens.to(self.device) # meh
42
+ z = self.transformer(tokens, return_embeddings=True)
43
+ return z
44
+
45
+ def encode(self, x):
46
+ return self(x)
47
+
48
+
49
+ class BERTTokenizer(AbstractEncoder):
50
+ """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
51
+ def __init__(self, device="cuda", vq_interface=True, max_length=77):
52
+ super().__init__()
53
+ from transformers import BertTokenizerFast # TODO: add to reuquirements
54
+ self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
55
+ self.device = device
56
+ self.vq_interface = vq_interface
57
+ self.max_length = max_length
58
+
59
+ def forward(self, text):
60
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
61
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
62
+ tokens = batch_encoding["input_ids"].to(self.device)
63
+ return tokens
64
+
65
+ @torch.no_grad()
66
+ def encode(self, text):
67
+ tokens = self(text)
68
+ if not self.vq_interface:
69
+ return tokens
70
+ return None, None, [None, None, tokens]
71
+
72
+ def decode(self, text):
73
+ return text
74
+
75
+
76
+ class BERTEmbedder(AbstractEncoder):
77
+ """Uses the BERT tokenizr model and add some transformer encoder layers"""
78
+ def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
79
+ device="cuda",use_tokenizer=True, embedding_dropout=0.0):
80
+ super().__init__()
81
+ self.use_tknz_fn = use_tokenizer
82
+ if self.use_tknz_fn:
83
+ self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
84
+ self.device = device
85
+ self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
86
+ attn_layers=Encoder(dim=n_embed, depth=n_layer),
87
+ emb_dropout=embedding_dropout)
88
+
89
+ def forward(self, text):
90
+ if self.use_tknz_fn:
91
+ tokens = self.tknz_fn(text)#.to(self.device)
92
+ else:
93
+ tokens = text
94
+ z = self.transformer(tokens, return_embeddings=True)
95
+ return z
96
+
97
+ def encode(self, text):
98
+ # output of length 77
99
+ return self(text)
100
+
101
+
102
+ class SpatialRescaler(nn.Module):
103
+ def __init__(self,
104
+ n_stages=1,
105
+ method='bilinear',
106
+ multiplier=0.5,
107
+ in_channels=3,
108
+ out_channels=None,
109
+ bias=False):
110
+ super().__init__()
111
+ self.n_stages = n_stages
112
+ assert self.n_stages >= 0
113
+ assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
114
+ self.multiplier = multiplier
115
+ self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
116
+ self.remap_output = out_channels is not None
117
+ if self.remap_output:
118
+ print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
119
+ self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
120
+
121
+ def forward(self,x):
122
+ for stage in range(self.n_stages):
123
+ x = self.interpolator(x, scale_factor=self.multiplier)
124
+
125
+
126
+ if self.remap_output:
127
+ x = self.channel_mapper(x)
128
+ return x
129
+
130
+ def encode(self, x):
131
+ return self(x)