jiwan-chung committed
Commit 0bf81ba
1 Parent(s): 989194f
Files changed (9)
  1. .gitignore +3 -0
  2. app.py +43 -0
  3. arguments.py +58 -0
  4. clipcap.py +385 -0
  5. load.py +60 -0
  6. policy.py +219 -0
  7. requirements.txt +19 -0
  8. run.py +173 -0
  9. utils.py +28 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
+ data
+ flagged
+ __pycache__
app.py ADDED
@@ -0,0 +1,43 @@
+ from pathlib import Path
+
+ import gdown
+ # from PIL import Image
+ # from numpy import asarray
+
+ from run import launch
+
+ # download
+ if not Path('./data').is_dir():
+     url = 'https://drive.google.com/drive/folders/1hfHWDn5iXsdjB63E5zdZBAoRLWHQC3LD'
+     gdown.download_folder(url, quiet=True, use_cookies=False, output="./data/")
+
+ # example image from COCO data
+ image_urls = {
+     '108953': 'https://farm8.staticflickr.com/7160/6484651991_9d1eaa557a_z.jpg'
+ }
+ images = {}
+ for k, url in image_urls.items():
+     ext = Path(url).suffix
+     output = Path(f"data/images/{k}{ext}")
+     if not output.is_file():
+         output.parent.mkdir(exist_ok=True)
+         gdown.download(url, quiet=True, use_cookies=False, output=str(output))
+     images[k] = str(output)
+
+ '''
+ for k, v in images.items():
+     with Image.open(v) as image:
+         # image = asarray(image)
+         images[k] = image
+ '''
+
+
+ examples = [[
+     v,
+     'My favourite recipe:',
+     20,
+     10,
+     False
+ ] for v in images.values()]
+
+ launch(examples)
arguments.py ADDED
@@ -0,0 +1,58 @@
+ import os
+ import json
+ import argparse
+ import logging
+ from pathlib import Path
+
+ import torch
+
+
+ logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
+ log = logging.getLogger(__name__)
+
+
+ def get_args():
+     parser = argparse.ArgumentParser(description='ESPER')
+
+
+     parser.add_argument(
+         '--init-model', type=str, default='gpt2', help='language model used for policy.')
+     parser.add_argument(
+         '--label_path', type=str, default='./data/esper_demo/labels_all.json', help='style label info file path')
+     parser.add_argument(
+         '--checkpoint', type=str, default='./data/esper_demo/ckpt/gpt2_style', help='checkpoint file path')
+
+     parser.add_argument(
+         '--prefix_length', type=int, default=10, help='prefix length for the visual mapper')
+     parser.add_argument(
+         '--clipcap_num_layers', type=int, default=1, help='num_layers for the visual mapper')
+     parser.add_argument(
+         '--use_transformer_mapper', action='store_true', default=False, help='use transformer mapper instead of mlp')
+     parser.add_argument(
+         '--use_label_prefix', action='store_true', default=False, help='label as prefixes')
+     parser.add_argument(
+         '--clip_model_type', type=str, default='ViT-B/32', help='clip backbone type')
+
+     parser.add_argument(
+         '--infer_no_repeat_size', type=int, default=2, help="no repeat ngram size for inference")
+     parser.add_argument(
+         '--response-length', type=int, default=20, help='number of tokens to generate for each prompt.')
+     parser.add_argument(
+         '--num-gpus', type=int, default=None, help='number of gpus. use all available if none')
+     parser.add_argument(
+         '--port', type=int, default=None, help="port for the demo server")
+
+     args = parser.parse_args()
+     args.cuda = torch.cuda.is_available()
+
+     if args.use_label_prefix:
+         log.info(f'using label prefix')
+     num_gpus = torch.cuda.device_count()
+     if args.num_gpus is None:
+         args.num_gpus = num_gpus
+     else:
+         args.num_gpus = min(num_gpus, args.num_gpus)
+
+     if args.checkpoint is not None:
+         args.checkpoint = str(Path(args.checkpoint).resolve())
+     return args
clipcap.py ADDED
@@ -0,0 +1,385 @@
+ import os
+ import math
+ import logging
+ import json
+ from pathlib import Path
+ from typing import Tuple, Optional, Union
+
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from transformers import AutoModelForCausalLM, AutoTokenizer, GPTJForCausalLM
+
+
+ logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
+ log = logging.getLogger(__name__)
+
+
+ def load_weights(self, Module, path, name, default_name, prev_name=None, **kwargs):
+     hparams = None
+     assert isinstance(default_name, str), f'invalid default transformer name: {default_name}'
+     model = get_transformer_module(Module, default_name, **kwargs)
+     setattr(self, name, model)
+     return hparams
+
+
+ def get_transformer_module(Module, default_name, **kwargs):
+     if default_name == 'EleutherAI/gpt-j-6B':
+         kwargs = {**kwargs, **dict(revision="float16", torch_dtype=torch.float16, low_cpu_mem_usage=True)}
+     model = Module.from_pretrained(default_name, **kwargs)
+     return model
+
+
+ class MLP(nn.Module):
+     def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
+         super(MLP, self).__init__()
+
+         self.divider = math.sqrt(sizes[-1] / sizes[0])
+         layers = []
+         for i in range(len(sizes) - 1):
+             layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
+             if i < len(sizes) - 2:
+                 layers.append(act())
+         self.model = nn.Sequential(*layers)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = x / self.divider  # scaling for the initial stability
+         x = self.model(x)
+         return x
+
+
+ class MlpTransformer(nn.Module):
+     def __init__(self, in_dim, h_dim, out_d: Optional[int] = None, act=F.relu, dropout=0.):
+         super().__init__()
+         out_d = out_d if out_d is not None else in_dim
+         self.fc1 = nn.Linear(in_dim, h_dim)
+         self.act = act
+         self.fc2 = nn.Linear(h_dim, out_d)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.dropout(x)
+         x = self.fc2(x)
+         x = self.dropout(x)
+         return x
+
+
+ class MultiHeadAttention(nn.Module):
+     def __init__(self, dim_self, dim_ref, num_heads, bias=True, dropout=0.):
+         super().__init__()
+         self.num_heads = num_heads
+         head_dim = dim_self // num_heads
+         self.scale = head_dim ** -0.5
+         self.to_queries = nn.Linear(dim_self, dim_self, bias=bias)
+         self.to_keys_values = nn.Linear(dim_ref, dim_self * 2, bias=bias)
+         self.project = nn.Linear(dim_self, dim_self)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x, y=None, mask=None):
+         y = y if y is not None else x
+         b, n, c = x.shape
+         _, m, d = y.shape
+         # b n h dh
+         queries = self.to_queries(x).reshape(b, n, self.num_heads, c // self.num_heads)
+         # b m 2 h dh
+         keys_values = self.to_keys_values(y).reshape(b, m, 2, self.num_heads, c // self.num_heads)
+         keys, values = keys_values[:, :, 0], keys_values[:, :, 1]
+         attention = torch.einsum('bnhd,bmhd->bnmh', queries, keys) * self.scale
+         if mask is not None:
+             if mask.dim() == 2:
+                 mask = mask.unsqueeze(1)
+             attention = attention.masked_fill(mask.unsqueeze(3), float("-inf"))
+         attention = attention.softmax(dim=2)
+         out = torch.einsum('bnmh,bmhd->bnhd', attention, values).reshape(b, n, c)
+         out = self.project(out)
+         return out, attention
+
+
+ class TransformerLayer(nn.Module):
+     def __init__(self, dim_self, dim_ref, num_heads, mlp_ratio=4., bias=False, dropout=0., act=F.relu,
+                  norm_layer: nn.Module = nn.LayerNorm):
+         super().__init__()
+         self.norm1 = norm_layer(dim_self)
+         self.attn = MultiHeadAttention(dim_self, dim_ref, num_heads, bias=bias, dropout=dropout)
+         self.norm2 = norm_layer(dim_self)
+         self.mlp = MlpTransformer(dim_self, int(dim_self * mlp_ratio), act=act, dropout=dropout)
+
+     def forward_with_attention(self, x, y=None, mask=None):
+         x_, attention = self.attn(self.norm1(x), y, mask)
+         x = x + x_
+         x = x + self.mlp(self.norm2(x))
+         return x, attention
+
+     def forward(self, x, y=None, mask=None):
+         x = x + self.attn(self.norm1(x), y, mask)[0]
+         x = x + self.mlp(self.norm2(x))
+         return x
+
+
+ class Transformer(nn.Module):
+     def forward_with_attention(self, x, y=None, mask=None):
+         attentions = []
+         for layer in self.layers:
+             x, att = layer.forward_with_attention(x, y, mask)
+             attentions.append(att)
+         return x, attentions
+
+     def forward(self, x, y=None, mask=None):
+         for i, layer in enumerate(self.layers):
+             if i % 2 == 0 and self.enc_dec:  # cross
+                 x = layer(x, y)
+             elif self.enc_dec:  # self
+                 x = layer(x, x, mask)
+             else:  # self or cross
+                 x = layer(x, y, mask)
+         return x
+
+     def __init__(self, dim_self: int, num_heads: int, num_layers: int, dim_ref: Optional[int] = None,
+                  mlp_ratio: float = 2., act=F.relu, norm_layer: nn.Module = nn.LayerNorm, enc_dec: bool = False):
+         super(Transformer, self).__init__()
+         dim_ref = dim_ref if dim_ref is not None else dim_self
+         self.enc_dec = enc_dec
+         if enc_dec:
+             num_layers = num_layers * 2
+         layers = []
+         for i in range(num_layers):
+             if i % 2 == 0 and enc_dec:  # cross
+                 layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
+             elif enc_dec:  # self
+                 layers.append(TransformerLayer(dim_self, dim_self, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
+             else:  # self or cross
+                 layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
+         self.layers = nn.ModuleList(layers)
+
+
+ class TransformerMapper(nn.Module):
+     def __init__(self, dim_clip: int, dim_embedding: int, prefix_length: int = 10,
+                  clip_length: int = 10, num_layers: int = 8):
+         super(TransformerMapper, self).__init__()
+         self.clip_length = clip_length
+         self.transformer = Transformer(dim_embedding, 8, num_layers)
+         self.linear = nn.Linear(dim_clip, clip_length * dim_embedding)
+         self.prefix_const = nn.Parameter(torch.randn(prefix_length, dim_embedding), requires_grad=True)
+
+     def forward(self, x):
+         x = self.linear(x).view(x.shape[0], self.clip_length, -1)
+         prefix = self.prefix_const.unsqueeze(0).expand(x.shape[0], *self.prefix_const.shape)
+         prefix = torch.cat((x, prefix), dim=1)
+         out = self.transformer(prefix)[:, self.clip_length:]
+         return out
+
+
+ class ClipCap(nn.Module):
+     def __init__(self, model_name, device, prefix_length: int = 10, clip_length: int = 40, prefix_size: int = 512,
+                  num_layers: int = 1, model_path: str = '', fix_gpt: bool = False,
+                  use_label_prefix: bool = False, label_path: str = '', label_length: int = 10,
+                  use_transformer_mapper: bool = False, use_ptuning_v2: bool = False,
+                  dropout: float = 0,
+                  model_weight: str = '', scalar_output: bool = False):
+         super(ClipCap, self).__init__()
+
+         self.prefix_length = prefix_length
+         self.prefix_size = prefix_size
+         self.label_length = label_length
+         self.scalar_output = scalar_output
+         self.num_layers = num_layers
+         self.use_transformer_mapper = use_transformer_mapper
+         self.use_ptuning_v2 = use_ptuning_v2
+
+         self.dropout = nn.Dropout(dropout)
+
+         hparams = load_weights(self, AutoModelForCausalLM, model_weight, 'gpt', model_name,
+                                prev_name='model')
+
+         self.device = device
+         self.gpt = self.gpt.to(self.device)
+
+         config = self.gpt.config
+         self.match_n_layer = getattr(config, 'n_layer', getattr(config, 'num_layers', None))  # gpt2 vs. gpt_neo
+         self.match_n_head = getattr(config, 'n_head', getattr(config, 'num_heads', None))
+         self.n_embd = getattr(config, 'n_embd', getattr(config, 'hidden_size', None))
+         self.match_n_embd = self.n_embd // self.match_n_head
+
+         self.clip_project = self.get_mapper()
+
+         if Path(label_path).is_file():
+             with open(label_path) as f:
+                 labels = json.load(f)
+             self.labels = {i: v for v, i in labels.items()}
+             if not use_label_prefix:
+                 log.info("adding label projections")
+                 self.label_project = nn.Sequential(
+                     nn.Embedding(len(self.labels), self.prefix_size),
+                     self.get_mapper()
+                 )
+
+         if os.path.isfile(model_path):
+             log.info(f"loading model from {model_path}")
+             weight = torch.load(model_path, map_location=torch.device('cpu'))
+             weight = {k[len('clip_project.'):]: v for k, v in weight.items()
+                       if k.startswith('clip_project.')}
+             self.clip_project.load_state_dict(weight)
+
+         if fix_gpt:
+             log.info("fixing gpt parameters")
+             for param in self.gpt.parameters():
+                 param.requires_grad_(False)
+
+         if self.scalar_output:
+             self.gpt.lm_head = nn.Linear(self.gpt.transformer.embed_dim, 1).to(self.device)
+
+         self.clip_project = self.clip_project.to(self.device)
+         if hasattr(self, 'label_project'):
+             self.label_project = self.label_project.to(self.device)
+
+     def get_mapper(self):
+         if self.use_ptuning_v2:
+             total_embd = self.match_n_layer * 2 * self.n_embd
+             module = MLP((self.prefix_size,
+                           *[self.prefix_size
+                             for i in range(self.num_layers)],
+                           total_embd * self.prefix_length))
+         elif self.use_transformer_mapper:
+             log.info("using transformer mapper")
+             module = TransformerMapper(self.prefix_size, self.n_embd,
+                                        self.prefix_length, self.prefix_length, num_layers=self.num_layers)  # 8)
+         else:
+             module = MLP((self.prefix_size,
+                           *[(self.n_embd * self.prefix_length) // 2
+                             for i in range(self.num_layers)],
+                           self.n_embd * self.prefix_length))
+         return module
+
+     def get_encoder_loss(self, input_ids: torch.Tensor, features: torch.Tensor,
+                          device = None):
+         input_ids = input_ids[:, :self.prefix_length].to(device)
+         embedding = self.gpt.transformer.wte(input_ids)
+         features = features.to(device)
+         prefix_projections = self.clip_project(features.type_as(embedding)).reshape(-1, self.prefix_length, self.n_embd)
+         fct = nn.MSELoss()
+         loss = fct(prefix_projections, embedding.detach())
+         return loss
+
+     def forward(self, *args, **kwargs):
+         if self.use_ptuning_v2:
+             return self.forward_prefix(*args, **kwargs)
+         else:
+             return self.forward_embedding(*args, **kwargs)
+
+     def forward_embedding(self, input_ids: torch.Tensor, features: torch.Tensor,
+                           attention_mask: Optional[torch.Tensor] = None,
+                           labels: Optional[torch.Tensor] = None,
+                           past_key_values = None, device = None, **kwargs):
+
+         if device is None:
+             device = self.device
+         input_ids = input_ids.to(device)
+         if features is not None:
+             features = features.to(device)
+         if attention_mask is not None:
+             attention_mask = attention_mask.to(device)
+         if labels is not None:
+             labels = labels.to(device)
+         use_labels = labels is not None and hasattr(self, 'label_project')
+
+         embedding = self.gpt.transformer.wte(input_ids)
+         embed_txt = embedding
+         prefix_length = self.prefix_length
+         if use_labels:
+             prefix_length += self.label_length
+         if past_key_values is None:
+             prefix_projections = self.clip_project(features.type_as(embedding)).reshape(-1, self.prefix_length, self.n_embd)
+             if use_labels:
+                 label_projections = self.label_project(labels.long()).reshape(-1, self.label_length, self.n_embd)
+                 prefix_projections = torch.cat((prefix_projections, label_projections), dim=1)
+             embedding = torch.cat((prefix_projections.to(embedding.dtype), embedding), dim=1)
+             if torch.is_tensor(attention_mask):
+                 prefix_mask = torch.ones_like(attention_mask)[:, :1].repeat(1, prefix_length)
+                 attention_mask = torch.cat([prefix_mask, attention_mask], dim=1)
+         outputs = self.gpt(inputs_embeds=embedding, attention_mask=attention_mask,
+                            past_key_values=past_key_values,
+                            return_dict=True,
+                            output_attentions=False,
+                            output_hidden_states=True)
+         if past_key_values is None:
+             outputs.logits = outputs.logits[:, prefix_length:]
+         return outputs
+
+     def forward_prefix(self, input_ids: torch.Tensor, features: torch.Tensor,
+                        attention_mask: Optional[torch.Tensor] = None,
+                        labels: Optional[torch.Tensor] = None,
+                        past_key_values = None, device = None, **kwargs):
+
+         if device is None:
+             device = self.device
+         input_ids = input_ids.to(device)
+         if features is not None:
+             features = features.to(device)
+         if attention_mask is not None:
+             attention_mask = attention_mask.to(device)
+         if labels is not None:
+             labels = labels.to(device)
+         use_labels = labels is not None and hasattr(self, 'label_project')
+
+         prefix_length = self.prefix_length
+         if use_labels:
+             prefix_length += self.label_length
+         if past_key_values is None:
+             prefix_projections = self.clip_project(features.type_as(self.clip_project.model[0].weight))
+             prefix_projections = prefix_projections.reshape(-1, self.prefix_length,
+                                                             self.match_n_layer * 2, self.match_n_head, self.match_n_embd)
+             if use_labels:
+                 label_projections = self.label_project(labels.long())
+                 label_projections = label_projections.reshape(-1, self.label_length,
+                                                               self.match_n_layer * 2, self.match_n_head, self.match_n_embd)
+                 prefix_projections = torch.cat((prefix_projections, label_projections), dim=1)
+             temp_control = prefix_projections
+             temp_control = self.dropout(temp_control)
+             past_key_values = temp_control.permute([2, 0, 3, 1, 4]).split(2)
+
+         if torch.is_tensor(attention_mask):
+             prefix_mask = torch.ones_like(attention_mask)[:, :1].repeat(1, prefix_length)
+             attention_mask = torch.cat([prefix_mask, attention_mask], dim=1)
+         outputs = self.gpt(input_ids=input_ids, attention_mask=attention_mask,
+                            past_key_values=past_key_values,
+                            return_dict=True,
+                            output_attentions=False,
+                            output_hidden_states=True)
+         if past_key_values is None:
+             outputs.logits = outputs.logits[:, prefix_length:]
+         return outputs
+
+     def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
+         token_type_ids = kwargs.get("token_type_ids", None)
+         # only last token for inputs_ids if past is defined in kwargs
+         if past:
+             input_ids = input_ids[:, -1].unsqueeze(-1)
+             if token_type_ids is not None:
+                 token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
+
+         attention_mask = kwargs.get("attention_mask", None)
+         position_ids = kwargs.get("position_ids", None)
+         features = kwargs.get("features", None)
+         labels = kwargs.get("labels", None)
+
+         if attention_mask is not None and position_ids is None:
+             # create position_ids on the fly for batch generation
+             position_ids = attention_mask.long().cumsum(-1) - 1
+             position_ids.masked_fill_(attention_mask == 0, 1)
+             if past:
+                 position_ids = position_ids[:, -1].unsqueeze(-1)
+         else:
+             position_ids = None
+
+         return {
+             "input_ids": input_ids,
+             "past_key_values": past,
+             "use_cache": kwargs.get("use_cache"),
+             "position_ids": position_ids,
+             "attention_mask": attention_mask,
+             "token_type_ids": token_type_ids,
+             "features": features,
+             "labels": labels,
+         }
load.py ADDED
@@ -0,0 +1,60 @@
+ import os
+ import logging
+ import json
+ from pathlib import Path
+
+ import yaml
+ import torch
+
+ from policy import Policy
+
+
+ logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
+ log = logging.getLogger(__name__)
+
+
+ def load_model_args(args):
+     checkpoint = Path(args.checkpoint + '.ckpt')
+     assert checkpoint.is_file(), f"no checkpoint file: {args.checkpoint}"
+     args_path = Path(args.checkpoint + '.json')
+     if args_path.is_file():
+         with open(args_path) as f:
+             hparams = json.load(f)
+     else:
+         args_path = Path(args.checkpoint + '.yaml')
+         with open(args_path) as f:
+             hparams = yaml.safe_load(f)
+     for key in ['init_model', 'clip_model_type', 'use_caption', 'use_style_reward', 'use_transformer_mapper',
+                 'prefix_length', 'clipcap_num_layers', 'use_ptuning_v2']:
+         if key in hparams:
+             setattr(args, key, hparams[key])
+     args.loaded_init_model = True
+     return args
+
+
+ def load_model(args, device, finetune=False):
+     log.info('loading model')
+     policy = Policy(model_name=args.init_model, temperature=1.0, device=device,
+                     clipcap_path='None', fix_gpt=True,
+                     label_path=args.label_path,
+                     prefix_length=args.prefix_length,
+                     clipcap_num_layers=args.clipcap_num_layers,
+                     use_transformer_mapper=args.use_transformer_mapper,
+                     model_weight='None', use_label_prefix=args.use_label_prefix)
+     ckpt = args.checkpoint + '.ckpt'
+     state = torch.load(ckpt)
+     policy_key = 'policy_model'
+     if policy_key in state:
+         policy.model.load_state_dict(state[policy_key])
+     else:
+         weights = state['state_dict']
+         key = 'policy.model.'
+         if not any(k for k in weights.keys() if k.startswith(key)):
+             key = 'model.model.'
+         weights = {k[len(key):]: v for k, v in weights.items() if k.startswith(key)}
+         # weights = {k: v for k, v in weights.items() if k.startswith('clip_project.')}
+         policy.model.load_state_dict(weights, strict=False)
+     model = policy
+
+     model = model.to(device)
+     return model
policy.py ADDED
@@ -0,0 +1,219 @@
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ from typing import Union, List, Dict, Optional
+ from transformers import AutoModelForCausalLM, AutoTokenizer, GPTJForCausalLM
+ from transformers.generation_logits_process import (
+     LogitsProcessorList,
+     NoBadWordsLogitsProcessor,
+     NoRepeatNGramLogitsProcessor,
+ )
+
+ from utils import (
+     NEGATIVE_INF, HALF_NEGATIVE_INF,
+     logits_to_entropy, mask_pad
+ )
+ from clipcap import ClipCap
+
+
+ class Policy(nn.Module):
+     def __init__(self, model_name, temperature, device, clipcap_path='', fix_gpt=False,
+                  use_transformer_mapper: bool = False, use_ptuning_v2: bool = False,
+                  prefix_length=10, clipcap_num_layers: int = 1,
+                  label_path: str = '', model_weight: str = 'None', use_label_prefix: bool = False):
+         super().__init__()
+
+         self.device = device
+
+         self.model = ClipCap(model_name, device,
+                              model_path=clipcap_path, fix_gpt=fix_gpt,
+                              prefix_length=prefix_length,
+                              num_layers=clipcap_num_layers,
+                              label_path=label_path, model_weight=model_weight,
+                              use_transformer_mapper=use_transformer_mapper,
+                              use_ptuning_v2=use_ptuning_v2,
+                              use_label_prefix=use_label_prefix)
+
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name, pad_token="<|endoftext|>")
+         self.model.gpt.config.pad_token_id = self.tokenizer.pad_token_id
+
+         self.temperature = temperature
+
+     def get_processor(self, no_repeat_ngram_size: int = 3):
+         logits_processor = LogitsProcessorList()
+         if no_repeat_ngram_size > 0:
+             logits_processor.append(NoRepeatNGramLogitsProcessor(ngram_size=no_repeat_ngram_size))
+         '''
+         logits_processor.append(NoBadWordsLogitsProcessor([[self.tokenizer.pad_token_id]],
+                                                           self.tokenizer.pad_token_id))
+         '''
+         return logits_processor
+
+     def sample(self,
+                input_ids: torch.Tensor = None,
+                features: torch.Tensor = None,
+                attention_mask: torch.Tensor = None,
+                labels: Optional[torch.Tensor] = None,
+                max_len: int = 20,
+                sample: bool = True,
+                top_k: int = None,
+                top_p: float = None,
+                temperature: float = None,
+                no_repeat_ngram_size: int = 0,
+                invalidate_eos: bool = True,
+                device = None) -> Dict[str, Union[torch.Tensor, List[str]]]:
+         if device is None:
+             device = self.device
+         if temperature is None:
+             temperature = self.temperature
+
+         input_ids = input_ids.to(device)
+         attention_mask = attention_mask.to(device)
+
+         model_kwargs = {'attention_mask': attention_mask}
+         batch_size, input_seq_len = input_ids.shape
+
+         logits_processor = self.get_processor(no_repeat_ngram_size=no_repeat_ngram_size)
+
+         logits_warper = self.model.gpt._get_logits_warper(
+             top_k=top_k, top_p=top_p, temperature=temperature, num_beams=1
+         )
+
+         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)
+         output_logprob = torch.zeros([batch_size, 0], device=device)
+         eos_logprobs = torch.zeros([batch_size, 0], device=device)
+         output_mask = torch.ones([batch_size, 0], dtype=torch.long, device=device)
+
+         self.model.eval()
+         with torch.no_grad():
+             for step in range(max_len):
+                 # prepare model inputs
+                 model_inputs = self.model.prepare_inputs_for_generation(input_ids,
+                                                                         features=features,
+                                                                         labels=labels,
+                                                                         **model_kwargs)
+
+                 # forward pass to get next token
+                 outputs = self.model(
+                     **model_inputs,
+                     device=device
+                 )
+
+                 # in the first decoding step, we want to use the 'real' last position for each sentence
+                 if step == 0:
+                     last_non_masked_idx = torch.sum(attention_mask, dim=1) - 1
+                     next_token_logits = outputs.logits[range(batch_size), last_non_masked_idx, :]
+                 else:
+                     next_token_logits = outputs.logits[:, -1, :]
+
+                 negative_inf = HALF_NEGATIVE_INF if next_token_logits.dtype == torch.half else NEGATIVE_INF
+                 next_token_scores = logits_processor(input_ids, next_token_logits)
+                 if invalidate_eos:
+                     next_token_scores[:, self.tokenizer.eos_token_id] = negative_inf  # no endoftext
+                 log_prob = F.log_softmax(next_token_scores, dim=-1)  # authentic sampling distribution
+                 next_token_scores = logits_warper(input_ids, next_token_scores)
+                 if sample:
+                     # Temperature (higher temperature => more likely to sample low probability tokens)
+                     probs = F.softmax(next_token_scores, dim=-1)
+                     next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+                 else:
+                     # Greedy decoding
+                     next_tokens = torch.argmax(next_token_scores, dim=-1)
+
+                 # finished sentences should have their next token be a padding token
+                 next_tokens = next_tokens * unfinished_sequences + self.tokenizer.pad_token_id * (1 - unfinished_sequences)
+
+                 # update output mask
+                 output_mask = torch.cat([output_mask, unfinished_sequences[:, None]], dim=-1)
+                 # update output log probability
+                 eos_logprob = log_prob[:, self.tokenizer.eos_token_id]
+                 eos_logprob = eos_logprob * unfinished_sequences + negative_inf * (1 - unfinished_sequences)
+                 eos_logprobs = torch.cat([eos_logprobs, eos_logprob[:, None]], dim=-1)
+
+                 token_logprob = torch.gather(log_prob, 1, next_tokens[:, None]).squeeze(1)
+                 token_logprob = token_logprob * unfinished_sequences + negative_inf * (1 - unfinished_sequences)
+                 output_logprob = torch.cat([output_logprob, token_logprob[:, None]], dim=-1)
+
+                 # update generated ids, model inputs for next step
+                 input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+                 model_kwargs = self.model.gpt._update_model_kwargs_for_generation(
+                     outputs, model_kwargs, is_encoder_decoder=self.model.gpt.config.is_encoder_decoder
+                 )
+
+                 # if eos_token was found in one sentence, set sentence to finished
+                 unfinished_sequences = unfinished_sequences.mul((next_tokens != self.tokenizer.eos_token_id).long())
+
+                 if unfinished_sequences.max() == 0:
+                     break
+
+         response_ids = input_ids[:, input_seq_len:]
+         response_text = [self.tokenizer.decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+                          for output in response_ids]
+
+         prompt_ids = input_ids[:, :input_seq_len]
+         prompts = [self.tokenizer.decode(query, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+                    for query in prompt_ids]
+         eos_probs = eos_logprobs.exp()
+
+         return {
+             'query/input_ids': prompt_ids,
+             'query/text': prompts,
+             'query/mask': attention_mask,
+             'response/input_ids': response_ids,
+             'response/text': response_text,
+             'response/mask': output_mask,
+             'response/log_prob': output_logprob,
+             'response/eos_prob': eos_probs,
+         }
+
+     def forward_pass(self,
+                      query_input_ids: torch.Tensor,
+                      query_mask: torch.Tensor,
+                      response_input_ids: torch.Tensor,
+                      response_mask: torch.Tensor,
+                      features: torch.Tensor,
+                      labels: Optional[torch.Tensor] = None,
+                      invalidate_eos: bool = True,
+                      device = None):
+
+         if device is None:
+             device = self.device
+
+         batch_size, query_seq_len = query_input_ids.shape
+         input_ids = torch.cat([query_input_ids, response_input_ids], dim=-1)
+         attention_mask = torch.cat([query_mask, response_mask], dim=-1)
+
+         # forward pass to get next token
+         outputs = self.model(
+             input_ids,
+             features,
+             attention_mask,
+             labels,
+             device=device
+         )
+         # get the first logit
+         query_logits = outputs.logits[:, :query_seq_len, :]
+         last_non_masked_idx = torch.sum(query_mask, dim=1) - 1
+         first_logits = query_logits[range(batch_size), last_non_masked_idx, :]
+         # get the second to last logit
+         response_logits = outputs.logits[:, query_seq_len:-1, :]
+         logits = torch.cat([first_logits[:, None], response_logits], dim=1)
+
+         negative_inf = HALF_NEGATIVE_INF if logits.dtype == torch.half else NEGATIVE_INF
+         if invalidate_eos:
+             logits[:, :, self.tokenizer.eos_token_id] = negative_inf  # no endoftext
+
+         log_prob = F.log_softmax(logits, dim=-1)
+         output_logprob = torch.gather(log_prob, 2, response_input_ids[:, :, None]).squeeze(2)
+         output_entropy = logits_to_entropy(logits)
+         eos_prob = F.softmax(logits, dim=-1)[:, :, self.tokenizer.eos_token_id]
+
+         pos_logit = torch.gather(logits, 2, response_input_ids[:, :, None]).squeeze(2)
+
+         return {
+             'response/log_prob': mask_pad(output_logprob, response_mask),
+             'response/eos_prob': mask_pad(eos_prob, response_mask),
+             'response/entropy': mask_pad(output_entropy, response_mask),
+             'response/pos_logit': mask_pad(pos_logit, response_mask),
+             'response/logits': logits,
+         }
requirements.txt ADDED
@@ -0,0 +1,19 @@
+ matplotlib
+ more-itertools
+ pyyaml==5.4
+ pillow
+ numpy
+ six
+ tqdm
+ ftfy
+ regex
+ huggingface-hub
+ ipdb
+ toml
+ torch==1.11.0
+ torchvision
+ tensorboard
+ transformers
+ clip-anytorch==2.4.0
+ gdown
+ gradio
run.py ADDED
@@ -0,0 +1,173 @@
+ import os
+ import math
+ import platform
+ import logging
+ from pathlib import Path
+
+ import torch
+ from transformers import AutoModelForCausalLM
+ from PIL import Image
+ import numpy as np
+ from numpy import asarray
+ import gradio as gr
+ import clip
+
+ from arguments import get_args
+ from load import load_model_args, load_model
+ from utils import get_first_sentence
+
+
+ logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
+ log = logging.getLogger(__name__)
+
+
+ def prepare(args):
+     num_gpus = torch.cuda.device_count()
+     log.info(f'Detect {num_gpus} GPUS')
+     device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+     args = load_model_args(args)
+
+     def load_style(args, checkpoint):
+         model = AutoModelForCausalLM.from_pretrained(args.init_model)
+         if checkpoint is not None and Path(checkpoint).is_file():
+             log.info("joint model: loading pretrained style generator")
+             state = torch.load(checkpoint)
+             if 'global_step' in state:
+                 step = state['global_step']
+                 log.info(f'trained for {step} steps')
+             weights = state['state_dict']
+             key = 'model.'
+             weights = {k[len(key):]: v for k, v in weights.items() if k.startswith(key)}
+             model.load_state_dict(weights)
+         else:
+             log.info("joint model: loading vanila gpt")
+         return model
+
+     log.info(f'loading models')
+     joint_model = load_style(args, checkpoint=getattr(args, 'demo_joint_model_weight', 'None'))
+     joint_model = joint_model.to(device)
+     model = load_model(args, device)
+     tokenizer = model.tokenizer
+     log.info(f'loaded models ')
+
+     class Inferer:
+         def __init__(self, args, model, joint_model, tokenizer, device):
+             self.args = args
+             self.model = model
+             self.joint_model = joint_model
+             self.tokenizer = tokenizer
+             self.device = device
+
+             self.clip_model, self.clip_preprocess = clip.load(args.clip_model_type, device=device, jit=False)
+
+         def infer_joint(self, batch, window_size=10, vanilla_length=20, sample=False, temperature=0.7, **kwargs):
+             with torch.no_grad():
+                 rollouts = self.model.sample(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'],
+                                              features=batch['features'], labels=None,
+                                              max_len=self.args.response_length, sample=sample,
+                                              no_repeat_ngram_size=self.args.infer_no_repeat_size,
+                                              invalidate_eos=False)
+                 '''
+                 query = rollouts['query/input_ids']
+                 res = rollouts['response/input_ids']
+                 gen1 = torch.cat([query, res], dim=1)
+                 mask1 = torch.cat([rollouts['query/mask'], rollouts['response/mask']], dim=1)
+                 '''
+                 res = rollouts['response/text']
+                 query = rollouts['query/text']
+                 generations = [f'{q} {v.strip()}' for q, v in zip(query, res)]
+
+                 cur_length = self.args.response_length
+                 if vanilla_length > 0:
+                     for i in range(math.ceil(vanilla_length / window_size)):
+                         cur_length += window_size
+                         generations = self.tokenizer(generations, padding=True, return_tensors='pt').to(self.device)
+                         context = generations['input_ids'][:, :-window_size]
+                         inputs = generations['input_ids'][:, -window_size:]
+                         out = self.joint_model.generate(input_ids=inputs,
+                                                         max_length=cur_length, sample=sample,
+                                                         no_repeat_ngram_size=self.args.infer_no_repeat_size,
+                                                         pad_token_id=self.tokenizer.eos_token_id)
+                         out = torch.cat([context, out], dim=1)
+                         text = [self.tokenizer.decode(v, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+                                 for v in out]
+                         # generations = [get_first_sentence(v) for v in generations]
+                         generations = text
+                 query = rollouts['query/text']
+                 del rollouts
+                 torch.cuda.empty_cache()
+             return query, generations
+
+         def get_feature(self, image):
+             image = self.clip_preprocess(image).unsqueeze(0).to(self.device)
+             feature = self.clip_model.encode_image(image)
+             return feature
+
+         def __call__(self, image, prompt, length=20, window_size=20, **kwargs):
+             window_size = min(window_size, length)
+             vanilla_length = max(0, length - self.args.response_length)
+             if not prompt:
+                 prompt = 'The'
+             feature = self.get_feature(image)
+             feature = feature.unsqueeze(0).to(self.device)
+             batch = self.tokenizer(prompt, padding=True, return_tensors='pt').to(self.device)
+             batch['features'] = feature
+             query, generations = self.infer_joint(batch, window_size=window_size,
+                                                   vanilla_length=vanilla_length, **kwargs)
+             # text = f'{query[0].strip()} {generations[0].strip()}'
+             text = generations[0].strip()
+             return text
+
+     inferer = Inferer(args, model, joint_model, tokenizer, device)
+     return inferer
+
+
+ class Runner:
+     def __init__(self, inferer):
+         self.inferer = inferer
+
+     def __call__(self, inp, prompt, length, window_size, sample):
+         # inp = inp.reshape((224, 224, 3))
+         img = Image.fromarray(np.uint8(inp))
+         text = self.inferer(img, prompt, length, window_size, sample=sample)
+         return prompt, text
+         # return inp, prompt, text
+
+
+ '''
+ # test_run
+ sample_img = asarray(Image.open('../data/coco/images/sample.jpg'))
+ img, _, text = run(sample_img, 'There lies', 50, 20, sample=False)
+ print('test_run:', text)
+ '''
+
+ def launch(examples=None):
+     args = get_args()
+     inferer = prepare(args)
+     runner = Runner(inferer)
+
+     iface = gr.Interface(
+         title="Demo for ESPER",
+         fn=runner.__call__,
+         inputs=[gr.components.Image(shape=(224, 224)),
+                 gr.components.Textbox(label='prompt'),
+                 gr.components.Slider(20, 120, step=1, label='length'),
+                 gr.components.Slider(10, 100, step=1, label='window_size'),
+                 gr.components.Checkbox(label='do sample')],
+         outputs=[gr.components.Textbox(label='prompt'),
+                  gr.components.Textbox(label='generation')],
+         examples=examples
+     )
+     if args.port is not None:
+         print(f"running from {platform.node()}")
+         iface.launch(
+             server_name="0.0.0.0",
+             server_port=args.port
+         )
+     else:
+         iface.launch()
+
+
+ if __name__ == "__main__":
+     print(f"running from {platform.node()}")
+     launch()
utils.py ADDED
@@ -0,0 +1,28 @@
+ import torch
+
+
+ NEGATIVE_INF = -100000.0
+ HALF_NEGATIVE_INF = -60000.0  # half precision
+
+
+ def get_first_sentence(txt, min_len=5):
+     eos = '<|endoftext|>'
+     eos_idx = txt.find(eos)
+     if eos_idx > 0:
+         txt = txt[eos_idx:]
+     txt = txt.replace('\n', ' ')
+     sents = txt.split('. ')
+     if len(sents[0]) >= min_len:
+         sent = f'{sents[0].strip()}.'
+     else:
+         sent = txt
+     return sent
+
+
+ def logits_to_entropy(logits):
+     distribution = torch.distributions.Categorical(logits=logits)
+     return distribution.entropy()
+
+
+ def mask_pad(value, mask):
+     return value * mask + NEGATIVE_INF * (1 - mask)