Spaces: Runtime error
jiwan-chung committed
Commit 0bf81ba
Parent(s): 989194f
demo init
Browse files
- .gitignore +3 -0
- app.py +43 -0
- arguments.py +58 -0
- clipcap.py +385 -0
- load.py +60 -0
- policy.py +219 -0
- requirements.txt +19 -0
- run.py +173 -0
- utils.py +28 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
data
flagged
__pycache__
app.py
ADDED
@@ -0,0 +1,43 @@
from pathlib import Path

import gdown
# from PIL import Image
# from numpy import asarray

from run import launch

# download
if not Path('./data').is_dir():
    url = 'https://drive.google.com/drive/folders/1hfHWDn5iXsdjB63E5zdZBAoRLWHQC3LD'
    gdown.download_folder(url, quiet=True, use_cookies=False, output="./data/")

# example image from COCO data
image_urls = {
    '108953': 'https://farm8.staticflickr.com/7160/6484651991_9d1eaa557a_z.jpg'
}
images = {}
for k, url in image_urls.items():
    ext = Path(url).suffix
    output = Path(f"data/images/{k}{ext}")
    if not output.is_file():
        output.parent.mkdir(exist_ok=True)
        gdown.download(url, quiet=True, use_cookies=False, output=str(output))
    images[k] = str(output)

'''
for k, v in images.items():
    with Image.open(v) as image:
        # image = asarray(image)
        images[k] = image
'''


examples = [[
    v,
    'My favourite recipe:',
    20,
    10,
    False
] for v in images.values()]

launch(examples)
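Editorial note: each row of `examples` must line up, in order, with the five input components that run.py registers on the Gradio interface. A minimal sketch of that mapping (the image path is the one app.py downloads; the comments are illustrative):

# Editorial sketch: how one row of `examples` maps onto the demo inputs.
example = [
    'data/images/108953.jpg',   # image component
    'My favourite recipe:',     # 'prompt' textbox
    20,                         # 'length' slider
    10,                         # 'window_size' slider
    False,                      # 'do sample' checkbox
]
assert len(example) == 5  # one value per input component in run.py's gr.Interface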
arguments.py
ADDED
@@ -0,0 +1,58 @@
import os
import json
import argparse
import logging
from pathlib import Path

import torch


logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
log = logging.getLogger(__name__)


def get_args():
    parser = argparse.ArgumentParser(description='ESPER')

    parser.add_argument(
        '--init-model', type=str, default='gpt2', help='language model used for policy.')
    parser.add_argument(
        '--label_path', type=str, default='./data/esper_demo/labels_all.json', help='style label info file path')
    parser.add_argument(
        '--checkpoint', type=str, default='./data/esper_demo/ckpt/gpt2_style', help='checkpoint file path')

    parser.add_argument(
        '--prefix_length', type=int, default=10, help='prefix length for the visual mapper')
    parser.add_argument(
        '--clipcap_num_layers', type=int, default=1, help='num_layers for the visual mapper')
    parser.add_argument(
        '--use_transformer_mapper', action='store_true', default=False, help='use transformer mapper instead of mlp')
    parser.add_argument(
        '--use_label_prefix', action='store_true', default=False, help='label as prefixes')
    parser.add_argument(
        '--clip_model_type', type=str, default='ViT-B/32', help='clip backbone type')

    parser.add_argument(
        '--infer_no_repeat_size', type=int, default=2, help="no repeat ngram size for inference")
    parser.add_argument(
        '--response-length', type=int, default=20, help='number of tokens to generate for each prompt.')
    parser.add_argument(
        '--num-gpus', type=int, default=None, help='number of gpus. use all available if none')
    parser.add_argument(
        '--port', type=int, default=None, help="port for the demo server")

    args = parser.parse_args()
    args.cuda = torch.cuda.is_available()

    if args.use_label_prefix:
        log.info('using label prefix')
    num_gpus = torch.cuda.device_count()
    if args.num_gpus is None:
        args.num_gpus = num_gpus
    else:
        args.num_gpus = min(num_gpus, args.num_gpus)

    if args.checkpoint is not None:
        args.checkpoint = str(Path(args.checkpoint).resolve())
    return args
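Editorial note: get_args() mixes dashed flag names (`--init-model`, `--response-length`) with underscored ones (`--label_path`). This is harmless because argparse normalizes dashes to underscores on the parsed namespace, so both styles are read back as attributes. A self-contained sketch of that behavior:

import argparse

# argparse converts dashes in option names to underscores on the namespace,
# which is why the code above can read args.init_model and args.label_path.
parser = argparse.ArgumentParser()
parser.add_argument('--init-model', type=str, default='gpt2')
parser.add_argument('--label_path', type=str, default='./data/esper_demo/labels_all.json')
args = parser.parse_args([])
print(args.init_model, args.label_path)  # -> gpt2 ./data/esper_demo/labels_all.json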
clipcap.py
ADDED
@@ -0,0 +1,385 @@
import os
import math
import logging
import json
from pathlib import Path
from typing import Tuple, Optional, Union

import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTJForCausalLM


logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
log = logging.getLogger(__name__)


def load_weights(self, Module, path, name, default_name, prev_name=None, **kwargs):
    hparams = None
    assert isinstance(default_name, str), f'invalid default transformer name: {default_name}'
    model = get_transformer_module(Module, default_name, **kwargs)
    setattr(self, name, model)
    return hparams


def get_transformer_module(Module, default_name, **kwargs):
    if default_name == 'EleutherAI/gpt-j-6B':
        kwargs = {**kwargs, **dict(revision="float16", torch_dtype=torch.float16, low_cpu_mem_usage=True)}
    model = Module.from_pretrained(default_name, **kwargs)
    return model


class MLP(nn.Module):
    def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
        super(MLP, self).__init__()

        self.divider = math.sqrt(sizes[-1] / sizes[0])
        layers = []
        for i in range(len(sizes) - 1):
            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
            if i < len(sizes) - 2:
                layers.append(act())
        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x / self.divider  # scaling for the initial stability
        x = self.model(x)
        return x


class MlpTransformer(nn.Module):
    def __init__(self, in_dim, h_dim, out_d: Optional[int] = None, act=F.relu, dropout=0.):
        super().__init__()
        out_d = out_d if out_d is not None else in_dim
        self.fc1 = nn.Linear(in_dim, h_dim)
        self.act = act
        self.fc2 = nn.Linear(h_dim, out_d)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class MultiHeadAttention(nn.Module):
    def __init__(self, dim_self, dim_ref, num_heads, bias=True, dropout=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim_self // num_heads
        self.scale = head_dim ** -0.5
        self.to_queries = nn.Linear(dim_self, dim_self, bias=bias)
        self.to_keys_values = nn.Linear(dim_ref, dim_self * 2, bias=bias)
        self.project = nn.Linear(dim_self, dim_self)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, y=None, mask=None):
        y = y if y is not None else x
        b, n, c = x.shape
        _, m, d = y.shape
        # b n h dh
        queries = self.to_queries(x).reshape(b, n, self.num_heads, c // self.num_heads)
        # b m 2 h dh
        keys_values = self.to_keys_values(y).reshape(b, m, 2, self.num_heads, c // self.num_heads)
        keys, values = keys_values[:, :, 0], keys_values[:, :, 1]
        attention = torch.einsum('bnhd,bmhd->bnmh', queries, keys) * self.scale
        if mask is not None:
            if mask.dim() == 2:
                mask = mask.unsqueeze(1)
            attention = attention.masked_fill(mask.unsqueeze(3), float("-inf"))
        attention = attention.softmax(dim=2)
        out = torch.einsum('bnmh,bmhd->bnhd', attention, values).reshape(b, n, c)
        out = self.project(out)
        return out, attention


class TransformerLayer(nn.Module):
    def __init__(self, dim_self, dim_ref, num_heads, mlp_ratio=4., bias=False, dropout=0., act=F.relu,
                 norm_layer: nn.Module = nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim_self)
        self.attn = MultiHeadAttention(dim_self, dim_ref, num_heads, bias=bias, dropout=dropout)
        self.norm2 = norm_layer(dim_self)
        self.mlp = MlpTransformer(dim_self, int(dim_self * mlp_ratio), act=act, dropout=dropout)

    def forward_with_attention(self, x, y=None, mask=None):
        x_, attention = self.attn(self.norm1(x), y, mask)
        x = x + x_
        x = x + self.mlp(self.norm2(x))
        return x, attention

    def forward(self, x, y=None, mask=None):
        x = x + self.attn(self.norm1(x), y, mask)[0]
        x = x + self.mlp(self.norm2(x))
        return x


class Transformer(nn.Module):
    def forward_with_attention(self, x, y=None, mask=None):
        attentions = []
        for layer in self.layers:
            x, att = layer.forward_with_attention(x, y, mask)
            attentions.append(att)
        return x, attentions

    def forward(self, x, y=None, mask=None):
        for i, layer in enumerate(self.layers):
            if i % 2 == 0 and self.enc_dec:  # cross
                x = layer(x, y)
            elif self.enc_dec:  # self
                x = layer(x, x, mask)
            else:  # self or cross
                x = layer(x, y, mask)
        return x

    def __init__(self, dim_self: int, num_heads: int, num_layers: int, dim_ref: Optional[int] = None,
                 mlp_ratio: float = 2., act=F.relu, norm_layer: nn.Module = nn.LayerNorm, enc_dec: bool = False):
        super(Transformer, self).__init__()
        dim_ref = dim_ref if dim_ref is not None else dim_self
        self.enc_dec = enc_dec
        if enc_dec:
            num_layers = num_layers * 2
        layers = []
        for i in range(num_layers):
            if i % 2 == 0 and enc_dec:  # cross
                layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
            elif enc_dec:  # self
                layers.append(TransformerLayer(dim_self, dim_self, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
            else:  # self or cross
                layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
        self.layers = nn.ModuleList(layers)


class TransformerMapper(nn.Module):
    def __init__(self, dim_clip: int, dim_embedding: int, prefix_length: int = 10,
                 clip_length: int = 10, num_layers: int = 8):
        super(TransformerMapper, self).__init__()
        self.clip_length = clip_length
        self.transformer = Transformer(dim_embedding, 8, num_layers)
        self.linear = nn.Linear(dim_clip, clip_length * dim_embedding)
        self.prefix_const = nn.Parameter(torch.randn(prefix_length, dim_embedding), requires_grad=True)

    def forward(self, x):
        x = self.linear(x).view(x.shape[0], self.clip_length, -1)
        prefix = self.prefix_const.unsqueeze(0).expand(x.shape[0], *self.prefix_const.shape)
        prefix = torch.cat((x, prefix), dim=1)
        out = self.transformer(prefix)[:, self.clip_length:]
        return out


class ClipCap(nn.Module):
    def __init__(self, model_name, device, prefix_length: int = 10, clip_length: int = 40, prefix_size: int = 512,
                 num_layers: int = 1, model_path: str = '', fix_gpt: bool = False,
                 use_label_prefix: bool = False, label_path: str = '', label_length: int = 10,
                 use_transformer_mapper: bool = False, use_ptuning_v2: bool = False,
                 dropout: float = 0,
                 model_weight: str = '', scalar_output: bool = False):
        super(ClipCap, self).__init__()

        self.prefix_length = prefix_length
        self.prefix_size = prefix_size
        self.label_length = label_length
        self.scalar_output = scalar_output
        self.num_layers = num_layers
        self.use_transformer_mapper = use_transformer_mapper
        self.use_ptuning_v2 = use_ptuning_v2

        self.dropout = nn.Dropout(dropout)

        hparams = load_weights(self, AutoModelForCausalLM, model_weight, 'gpt', model_name,
                               prev_name='model')

        self.device = device
        self.gpt = self.gpt.to(self.device)

        config = self.gpt.config
        self.match_n_layer = getattr(config, 'n_layer', getattr(config, 'num_layers', None))  # gpt2 vs. gpt_neo
        self.match_n_head = getattr(config, 'n_head', getattr(config, 'num_heads', None))
        self.n_embd = getattr(config, 'n_embd', getattr(config, 'hidden_size', None))
        self.match_n_embd = self.n_embd // self.match_n_head

        self.clip_project = self.get_mapper()

        if Path(label_path).is_file():
            with open(label_path) as f:
                labels = json.load(f)
            self.labels = {i: v for v, i in labels.items()}
            if not use_label_prefix:
                log.info("adding label projections")
                self.label_project = nn.Sequential(
                    nn.Embedding(len(self.labels), self.prefix_size),
                    self.get_mapper()
                )

        if os.path.isfile(model_path):
            log.info(f"loading model from {model_path}")
            weight = torch.load(model_path, map_location=torch.device('cpu'))
            weight = {k[len('clip_project.'):]: v for k, v in weight.items()
                      if k.startswith('clip_project.')}
            self.clip_project.load_state_dict(weight)

        if fix_gpt:
            log.info("fixing gpt parameters")
            for param in self.gpt.parameters():
                param.requires_grad_(False)

        if self.scalar_output:
            self.gpt.lm_head = nn.Linear(self.gpt.transformer.embed_dim, 1).to(self.device)

        self.clip_project = self.clip_project.to(self.device)
        if hasattr(self, 'label_project'):
            self.label_project = self.label_project.to(self.device)

    def get_mapper(self):
        if self.use_ptuning_v2:
            total_embd = self.match_n_layer * 2 * self.n_embd
            module = MLP((self.prefix_size,
                          *[self.prefix_size
                            for i in range(self.num_layers)],
                          total_embd * self.prefix_length))
        elif self.use_transformer_mapper:
            log.info("using transformer mapper")
            module = TransformerMapper(self.prefix_size, self.n_embd,
                                       self.prefix_length, self.prefix_length, num_layers=self.num_layers)  # 8)
        else:
            module = MLP((self.prefix_size,
                          *[(self.n_embd * self.prefix_length) // 2
                            for i in range(self.num_layers)],
                          self.n_embd * self.prefix_length))
        return module

    def get_encoder_loss(self, input_ids: torch.Tensor, features: torch.Tensor,
                         device=None):
        input_ids = input_ids[:, :self.prefix_length].to(device)
        embedding = self.gpt.transformer.wte(input_ids)
        features = features.to(device)
        prefix_projections = self.clip_project(features.type_as(embedding)).reshape(-1, self.prefix_length, self.n_embd)
        fct = nn.MSELoss()
        loss = fct(prefix_projections, embedding.detach())
        return loss

    def forward(self, *args, **kwargs):
        if self.use_ptuning_v2:
            return self.forward_prefix(*args, **kwargs)
        else:
            return self.forward_embedding(*args, **kwargs)

    def forward_embedding(self, input_ids: torch.Tensor, features: torch.Tensor,
                          attention_mask: Optional[torch.Tensor] = None,
                          labels: Optional[torch.Tensor] = None,
                          past_key_values=None, device=None, **kwargs):

        if device is None:
            device = self.device
        input_ids = input_ids.to(device)
        if features is not None:
            features = features.to(device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)
        if labels is not None:
            labels = labels.to(device)
        use_labels = labels is not None and hasattr(self, 'label_project')

        embedding = self.gpt.transformer.wte(input_ids)
        embed_txt = embedding
        prefix_length = self.prefix_length
        if use_labels:
            prefix_length += self.label_length
        if past_key_values is None:
            prefix_projections = self.clip_project(features.type_as(embedding)).reshape(-1, self.prefix_length, self.n_embd)
            if use_labels:
                label_projections = self.label_project(labels.long()).reshape(-1, self.label_length, self.n_embd)
                prefix_projections = torch.cat((prefix_projections, label_projections), dim=1)
            embedding = torch.cat((prefix_projections.to(embedding.dtype), embedding), dim=1)
        if torch.is_tensor(attention_mask):
            prefix_mask = torch.ones_like(attention_mask)[:, :1].repeat(1, prefix_length)
            attention_mask = torch.cat([prefix_mask, attention_mask], dim=1)
        outputs = self.gpt(inputs_embeds=embedding, attention_mask=attention_mask,
                           past_key_values=past_key_values,
                           return_dict=True,
                           output_attentions=False,
                           output_hidden_states=True)
        if past_key_values is None:
            outputs.logits = outputs.logits[:, prefix_length:]
        return outputs

    def forward_prefix(self, input_ids: torch.Tensor, features: torch.Tensor,
                       attention_mask: Optional[torch.Tensor] = None,
                       labels: Optional[torch.Tensor] = None,
                       past_key_values=None, device=None, **kwargs):

        if device is None:
            device = self.device
        input_ids = input_ids.to(device)
        if features is not None:
            features = features.to(device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)
        if labels is not None:
            labels = labels.to(device)
        use_labels = labels is not None and hasattr(self, 'label_project')

        prefix_length = self.prefix_length
        if use_labels:
            prefix_length += self.label_length
        if past_key_values is None:
            prefix_projections = self.clip_project(features.type_as(self.clip_project.model[0].weight))
            prefix_projections = prefix_projections.reshape(-1, self.prefix_length,
                                                            self.match_n_layer * 2, self.match_n_head, self.match_n_embd)
            if use_labels:
                label_projections = self.label_project(labels.long())
                label_projections = label_projections.reshape(-1, self.label_length,
                                                              self.match_n_layer * 2, self.match_n_head, self.match_n_embd)
                prefix_projections = torch.cat((prefix_projections, label_projections), dim=1)
            temp_control = prefix_projections
            temp_control = self.dropout(temp_control)
            past_key_values = temp_control.permute([2, 0, 3, 1, 4]).split(2)

        if torch.is_tensor(attention_mask):
            prefix_mask = torch.ones_like(attention_mask)[:, :1].repeat(1, prefix_length)
            attention_mask = torch.cat([prefix_mask, attention_mask], dim=1)
        outputs = self.gpt(input_ids=input_ids, attention_mask=attention_mask,
                           past_key_values=past_key_values,
                           return_dict=True,
                           output_attentions=False,
                           output_hidden_states=True)
        if past_key_values is None:
            outputs.logits = outputs.logits[:, prefix_length:]
        return outputs

    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
        token_type_ids = kwargs.get("token_type_ids", None)
        # only last token for inputs_ids if past is defined in kwargs
        if past:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)
        features = kwargs.get("features", None)
        labels = kwargs.get("labels", None)

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past:
                position_ids = position_ids[:, -1].unsqueeze(-1)
        else:
            position_ids = None

        return {
            "input_ids": input_ids,
            "past_key_values": past,
            "use_cache": kwargs.get("use_cache"),
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
            "features": features,
            "labels": labels,
        }
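Editorial note: the default (non-transformer) mapper turns one CLIP feature vector into prefix_length pseudo-token embeddings, which forward_embedding then prepends to the text embeddings. A shape-check sketch, assuming the gpt2 defaults (prefix_size=512, prefix_length=10, n_embd=768) and the single hidden layer get_mapper() builds with num_layers=1:

import torch

# Shape-check sketch (not part of the commit): with num_layers=1 the default
# mapper is MLP((512, 3840, 7680)); its output reshapes into
# (batch, prefix_length, n_embd) pseudo-token embeddings.
prefix_size, prefix_length, n_embd = 512, 10, 768
mapper = MLP((prefix_size, (n_embd * prefix_length) // 2, n_embd * prefix_length))
features = torch.randn(2, prefix_size)  # a batch of two CLIP image features
prefix = mapper(features).reshape(-1, prefix_length, n_embd)
print(prefix.shape)  # torch.Size([2, 10, 768])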
load.py
ADDED
@@ -0,0 +1,60 @@
import os
import logging
import json
from pathlib import Path

import yaml
import torch

from policy import Policy


logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
log = logging.getLogger(__name__)


def load_model_args(args):
    checkpoint = Path(args.checkpoint + '.ckpt')
    assert checkpoint.is_file(), f"no checkpoint file: {args.checkpoint}"
    args_path = Path(args.checkpoint + '.json')
    if args_path.is_file():
        with open(args_path) as f:
            hparams = json.load(f)
    else:
        args_path = Path(args.checkpoint + '.yaml')
        with open(args_path) as f:
            hparams = yaml.safe_load(f)
    for key in ['init_model', 'clip_model_type', 'use_caption', 'use_style_reward', 'use_transformer_mapper',
                'prefix_length', 'clipcap_num_layers', 'use_ptuning_v2']:
        if key in hparams:
            setattr(args, key, hparams[key])
    args.loaded_init_model = True
    return args


def load_model(args, device, finetune=False):
    log.info('loading model')
    policy = Policy(model_name=args.init_model, temperature=1.0, device=device,
                    clipcap_path='None', fix_gpt=True,
                    label_path=args.label_path,
                    prefix_length=args.prefix_length,
                    clipcap_num_layers=args.clipcap_num_layers,
                    use_transformer_mapper=args.use_transformer_mapper,
                    model_weight='None', use_label_prefix=args.use_label_prefix)
    ckpt = args.checkpoint + '.ckpt'
    state = torch.load(ckpt)
    policy_key = 'policy_model'
    if policy_key in state:
        policy.model.load_state_dict(state[policy_key])
    else:
        weights = state['state_dict']
        key = 'policy.model.'
        if not any(k for k in weights.keys() if k.startswith(key)):
            key = 'model.model.'
        weights = {k[len(key):]: v for k, v in weights.items() if k.startswith(key)}
        # weights = {k: v for k, v in weights.items() if k.startswith('clip_project.')}
        policy.model.load_state_dict(weights, strict=False)
    model = policy

    model = model.to(device)
    return model
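Editorial note: args.checkpoint is a path stem, with weights at <stem>.ckpt and saved hyperparameters at <stem>.json (or <stem>.yaml as a fallback); load_model_args copies a whitelist of saved keys over the CLI defaults. A minimal sketch of that merge, with illustrative values:

from types import SimpleNamespace

# Sketch of the merge in load_model_args: saved hparams win over CLI defaults
# for the whitelisted keys; anything else in the file is ignored.
args = SimpleNamespace(init_model='gpt2', prefix_length=10)
hparams = {'init_model': 'gpt2-medium', 'prefix_length': 10, 'unused_key': 1}
for key in ['init_model', 'prefix_length']:
    if key in hparams:
        setattr(args, key, hparams[key])
print(args.init_model)  # gpt2-medium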
policy.py
ADDED
@@ -0,0 +1,219 @@
import torch
from torch import nn
import torch.nn.functional as F
from typing import Union, List, Dict, Optional
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTJForCausalLM
from transformers.generation_logits_process import (
    LogitsProcessorList,
    NoBadWordsLogitsProcessor,
    NoRepeatNGramLogitsProcessor,
)

from utils import (
    NEGATIVE_INF, HALF_NEGATIVE_INF,
    logits_to_entropy, mask_pad
)
from clipcap import ClipCap


class Policy(nn.Module):
    def __init__(self, model_name, temperature, device, clipcap_path='', fix_gpt=False,
                 use_transformer_mapper: bool = False, use_ptuning_v2: bool = False,
                 prefix_length=10, clipcap_num_layers: int = 1,
                 label_path: str = '', model_weight: str = 'None', use_label_prefix: bool = False):
        super().__init__()

        self.device = device

        self.model = ClipCap(model_name, device,
                             model_path=clipcap_path, fix_gpt=fix_gpt,
                             prefix_length=prefix_length,
                             num_layers=clipcap_num_layers,
                             label_path=label_path, model_weight=model_weight,
                             use_transformer_mapper=use_transformer_mapper,
                             use_ptuning_v2=use_ptuning_v2,
                             use_label_prefix=use_label_prefix)

        self.tokenizer = AutoTokenizer.from_pretrained(model_name, pad_token="<|endoftext|>")
        self.model.gpt.config.pad_token_id = self.tokenizer.pad_token_id

        self.temperature = temperature

    def get_processor(self, no_repeat_ngram_size: int = 3):
        logits_processor = LogitsProcessorList()
        if no_repeat_ngram_size > 0:
            logits_processor.append(NoRepeatNGramLogitsProcessor(ngram_size=no_repeat_ngram_size))
        '''
        logits_processor.append(NoBadWordsLogitsProcessor([[self.tokenizer.pad_token_id]],
                                                          self.tokenizer.pad_token_id))
        '''
        return logits_processor

    def sample(self,
               input_ids: torch.Tensor = None,
               features: torch.Tensor = None,
               attention_mask: torch.Tensor = None,
               labels: Optional[torch.Tensor] = None,
               max_len: int = 20,
               sample: bool = True,
               top_k: int = None,
               top_p: float = None,
               temperature: float = None,
               no_repeat_ngram_size: int = 0,
               invalidate_eos: bool = True,
               device=None) -> Dict[str, Union[torch.Tensor, List[str]]]:
        if device is None:
            device = self.device
        if temperature is None:
            temperature = self.temperature

        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        model_kwargs = {'attention_mask': attention_mask}
        batch_size, input_seq_len = input_ids.shape

        logits_processor = self.get_processor(no_repeat_ngram_size=no_repeat_ngram_size)

        logits_warper = self.model.gpt._get_logits_warper(
            top_k=top_k, top_p=top_p, temperature=temperature, num_beams=1
        )

        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)
        output_logprob = torch.zeros([batch_size, 0], device=device)
        eos_logprobs = torch.zeros([batch_size, 0], device=device)
        output_mask = torch.ones([batch_size, 0], dtype=torch.long, device=device)

        self.model.eval()
        with torch.no_grad():
            for step in range(max_len):
                # prepare model inputs
                model_inputs = self.model.prepare_inputs_for_generation(input_ids,
                                                                        features=features,
                                                                        labels=labels,
                                                                        **model_kwargs)

                # forward pass to get next token
                outputs = self.model(
                    **model_inputs,
                    device=device
                )

                # in the first decoding step, we want to use the 'real' last position for each sentence
                if step == 0:
                    last_non_masked_idx = torch.sum(attention_mask, dim=1) - 1
                    next_token_logits = outputs.logits[range(batch_size), last_non_masked_idx, :]
                else:
                    next_token_logits = outputs.logits[:, -1, :]

                negative_inf = HALF_NEGATIVE_INF if next_token_logits.dtype == torch.half else NEGATIVE_INF
                next_token_scores = logits_processor(input_ids, next_token_logits)
                if invalidate_eos:
                    next_token_scores[:, self.tokenizer.eos_token_id] = negative_inf  # no endoftext
                log_prob = F.log_softmax(next_token_scores, dim=-1)  # authentic sampling distribution
                next_token_scores = logits_warper(input_ids, next_token_scores)
                if sample:
                    # Temperature (higher temperature => more likely to sample low probability tokens)
                    probs = F.softmax(next_token_scores, dim=-1)
                    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
                else:
                    # Greedy decoding
                    next_tokens = torch.argmax(next_token_scores, dim=-1)

                # finished sentences should have their next token be a padding token
                next_tokens = next_tokens * unfinished_sequences + self.tokenizer.pad_token_id * (1 - unfinished_sequences)

                # update output mask
                output_mask = torch.cat([output_mask, unfinished_sequences[:, None]], dim=-1)
                # update output log probability
                eos_logprob = log_prob[:, self.tokenizer.eos_token_id]
                eos_logprob = eos_logprob * unfinished_sequences + negative_inf * (1 - unfinished_sequences)
                eos_logprobs = torch.cat([eos_logprobs, eos_logprob[:, None]], dim=-1)

                token_logprob = torch.gather(log_prob, 1, next_tokens[:, None]).squeeze(1)
                token_logprob = token_logprob * unfinished_sequences + negative_inf * (1 - unfinished_sequences)
                output_logprob = torch.cat([output_logprob, token_logprob[:, None]], dim=-1)

                # update generated ids, model inputs for next step
                input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
                model_kwargs = self.model.gpt._update_model_kwargs_for_generation(
                    outputs, model_kwargs, is_encoder_decoder=self.model.gpt.config.is_encoder_decoder
                )

                # if eos_token was found in one sentence, set sentence to finished
                unfinished_sequences = unfinished_sequences.mul((next_tokens != self.tokenizer.eos_token_id).long())

                if unfinished_sequences.max() == 0:
                    break

        response_ids = input_ids[:, input_seq_len:]
        response_text = [self.tokenizer.decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                         for output in response_ids]

        prompt_ids = input_ids[:, :input_seq_len]
        prompts = [self.tokenizer.decode(query, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                   for query in prompt_ids]
        eos_probs = eos_logprobs.exp()

        return {
            'query/input_ids': prompt_ids,
            'query/text': prompts,
            'query/mask': attention_mask,
            'response/input_ids': response_ids,
            'response/text': response_text,
            'response/mask': output_mask,
            'response/log_prob': output_logprob,
            'response/eos_prob': eos_probs,
        }

    def forward_pass(self,
                     query_input_ids: torch.Tensor,
                     query_mask: torch.Tensor,
                     response_input_ids: torch.Tensor,
                     response_mask: torch.Tensor,
                     features: torch.Tensor,
                     labels: Optional[torch.Tensor] = None,
                     invalidate_eos: bool = True,
                     device=None):

        if device is None:
            device = self.device

        batch_size, query_seq_len = query_input_ids.shape
        input_ids = torch.cat([query_input_ids, response_input_ids], dim=-1)
        attention_mask = torch.cat([query_mask, response_mask], dim=-1)

        # forward pass to get next token
        outputs = self.model(
            input_ids,
            features,
            attention_mask,
            labels,
            device=device
        )
        # get the first logit
        query_logits = outputs.logits[:, :query_seq_len, :]
        last_non_masked_idx = torch.sum(query_mask, dim=1) - 1
        first_logits = query_logits[range(batch_size), last_non_masked_idx, :]
        # get the second to last logit
        response_logits = outputs.logits[:, query_seq_len:-1, :]
        logits = torch.cat([first_logits[:, None], response_logits], dim=1)

        negative_inf = HALF_NEGATIVE_INF if logits.dtype == torch.half else NEGATIVE_INF
        if invalidate_eos:
            logits[:, :, self.tokenizer.eos_token_id] = negative_inf  # no endoftext

        log_prob = F.log_softmax(logits, dim=-1)
        output_logprob = torch.gather(log_prob, 2, response_input_ids[:, :, None]).squeeze(2)
        output_entropy = logits_to_entropy(logits)
        eos_prob = F.softmax(logits, dim=-1)[:, :, self.tokenizer.eos_token_id]

        pos_logit = torch.gather(logits, 2, response_input_ids[:, :, None]).squeeze(2)

        return {
            'response/log_prob': mask_pad(output_logprob, response_mask),
            'response/eos_prob': mask_pad(eos_prob, response_mask),
            'response/entropy': mask_pad(output_entropy, response_mask),
            'response/pos_logit': mask_pad(pos_logit, response_mask),
            'response/logits': logits,
        }
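Editorial note: Policy.sample is a hand-rolled decoding loop mirroring Hugging Face generate() while additionally recording per-token and end-of-sequence log-probabilities. A hedged smoke-test sketch follows; it downloads the small gpt2 weights, uses a random tensor as a stand-in for a real CLIP image feature, and assumes the older transformers 4.x release this repo appears to target (sample() calls the private _get_logits_warper and _update_model_kwargs_for_generation helpers, whose signatures have since changed):

import torch

from policy import Policy

# Hedged smoke test (editorial, not part of the commit).
device = torch.device('cpu')
policy = Policy(model_name='gpt2', temperature=1.0, device=device)
batch = policy.tokenizer(['A photo of'], padding=True, return_tensors='pt')
features = torch.randn(1, 512)  # stand-in for clip_model.encode_image output
out = policy.sample(input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask'],
                    features=features, max_len=10, sample=False,
                    no_repeat_ngram_size=2, invalidate_eos=False)
print(out['query/text'][0], '->', out['response/text'][0])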
requirements.txt
ADDED
@@ -0,0 +1,19 @@
matplotlib
more-itertools
pyyaml==5.4
pillow
numpy
six
tqdm
ftfy
regex
huggingface-hub
ipdb
toml
torch==1.11.0
torchvision
tensorboard
transformers
clip-anytorch==2.4.0
gdown
gradio
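Editorial note: transformers is left unpinned even though policy.py imports from transformers.generation_logits_process, a module path that later transformers releases relocated; a 4.x release contemporary with the pinned torch 1.11 is the likely assumption. If pip install -r requirements.txt yields an ImportError there, pinning transformers to such a version should reproduce the intended environment.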
run.py
ADDED
@@ -0,0 +1,173 @@
import os
import math
import platform
import logging
from pathlib import Path

import torch
from transformers import AutoModelForCausalLM
from PIL import Image
import numpy as np
from numpy import asarray
import gradio as gr
import clip

from arguments import get_args
from load import load_model_args, load_model
from utils import get_first_sentence


logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
log = logging.getLogger(__name__)


def prepare(args):
    num_gpus = torch.cuda.device_count()
    log.info(f'Detected {num_gpus} GPUs')
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    args = load_model_args(args)

    def load_style(args, checkpoint):
        model = AutoModelForCausalLM.from_pretrained(args.init_model)
        if checkpoint is not None and Path(checkpoint).is_file():
            log.info("joint model: loading pretrained style generator")
            state = torch.load(checkpoint)
            if 'global_step' in state:
                step = state['global_step']
                log.info(f'trained for {step} steps')
            weights = state['state_dict']
            key = 'model.'
            weights = {k[len(key):]: v for k, v in weights.items() if k.startswith(key)}
            model.load_state_dict(weights)
        else:
            log.info("joint model: loading vanilla gpt")
        return model

    log.info('loading models')
    joint_model = load_style(args, checkpoint=getattr(args, 'demo_joint_model_weight', 'None'))
    joint_model = joint_model.to(device)
    model = load_model(args, device)
    tokenizer = model.tokenizer
    log.info('loaded models')

    class Inferer:
        def __init__(self, args, model, joint_model, tokenizer, device):
            self.args = args
            self.model = model
            self.joint_model = joint_model
            self.tokenizer = tokenizer
            self.device = device

            self.clip_model, self.clip_preprocess = clip.load(args.clip_model_type, device=device, jit=False)

        def infer_joint(self, batch, window_size=10, vanilla_length=20, sample=False, temperature=0.7, **kwargs):
            with torch.no_grad():
                rollouts = self.model.sample(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'],
                                             features=batch['features'], labels=None,
                                             max_len=self.args.response_length, sample=sample,
                                             no_repeat_ngram_size=self.args.infer_no_repeat_size,
                                             invalidate_eos=False)
                '''
                query = rollouts['query/input_ids']
                res = rollouts['response/input_ids']
                gen1 = torch.cat([query, res], dim=1)
                mask1 = torch.cat([rollouts['query/mask'], rollouts['response/mask']], dim=1)
                '''
                res = rollouts['response/text']
                query = rollouts['query/text']
                generations = [f'{q} {v.strip()}' for q, v in zip(query, res)]

                cur_length = self.args.response_length
                if vanilla_length > 0:
                    for i in range(math.ceil(vanilla_length / window_size)):
                        cur_length += window_size
                        generations = self.tokenizer(generations, padding=True, return_tensors='pt').to(self.device)
                        context = generations['input_ids'][:, :-window_size]
                        inputs = generations['input_ids'][:, -window_size:]
                        out = self.joint_model.generate(input_ids=inputs,
                                                        max_length=cur_length, do_sample=sample,  # generate expects do_sample
                                                        no_repeat_ngram_size=self.args.infer_no_repeat_size,
                                                        pad_token_id=self.tokenizer.eos_token_id)
                        out = torch.cat([context, out], dim=1)
                        text = [self.tokenizer.decode(v, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                                for v in out]
                        # generations = [get_first_sentence(v) for v in generations]
                        generations = text
                query = rollouts['query/text']
                del rollouts
                torch.cuda.empty_cache()
                return query, generations

        def get_feature(self, image):
            image = self.clip_preprocess(image).unsqueeze(0).to(self.device)
            feature = self.clip_model.encode_image(image)
            return feature

        def __call__(self, image, prompt, length=20, window_size=20, **kwargs):
            window_size = min(window_size, length)
            vanilla_length = max(0, length - self.args.response_length)
            if not prompt:
                prompt = 'The'
            feature = self.get_feature(image)
            feature = feature.unsqueeze(0).to(self.device)
            batch = self.tokenizer(prompt, padding=True, return_tensors='pt').to(self.device)
            batch['features'] = feature
            query, generations = self.infer_joint(batch, window_size=window_size,
                                                  vanilla_length=vanilla_length, **kwargs)
            # text = f'{query[0].strip()} {generations[0].strip()}'
            text = generations[0].strip()
            return text

    inferer = Inferer(args, model, joint_model, tokenizer, device)
    return inferer


class Runner:
    def __init__(self, inferer):
        self.inferer = inferer

    def __call__(self, inp, prompt, length, window_size, sample):
        # inp = inp.reshape((224, 224, 3))
        img = Image.fromarray(np.uint8(inp))
        text = self.inferer(img, prompt, length, window_size, sample=sample)
        return prompt, text
        # return inp, prompt, text


'''
# test_run
sample_img = asarray(Image.open('../data/coco/images/sample.jpg'))
img, _, text = run(sample_img, 'There lies', 50, 20, sample=False)
print('test_run:', text)
'''

def launch(examples=None):
    args = get_args()
    inferer = prepare(args)
    runner = Runner(inferer)

    iface = gr.Interface(
        title="Demo for ESPER",
        fn=runner.__call__,
        inputs=[gr.components.Image(shape=(224, 224)),
                gr.components.Textbox(label='prompt'),
                gr.components.Slider(20, 120, step=1, label='length'),
                gr.components.Slider(10, 100, step=1, label='window_size'),
                gr.components.Checkbox(label='do sample')],
        outputs=[gr.components.Textbox(label='prompt'),
                 gr.components.Textbox(label='generation')],
        examples=examples
    )
    if args.port is not None:
        print(f"running from {platform.node()}")
        iface.launch(
            server_name="0.0.0.0",
            server_port=args.port
        )
    else:
        iface.launch()


if __name__ == "__main__":
    print(f"running from {platform.node()}")
    launch()
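Editorial note: generation is two-stage here. The ESPER policy writes the first args.response_length visually grounded tokens, and the plain GPT-2 joint model then extends the text window_size tokens at a time until the requested length is reached. A small sketch of that arithmetic:

import math

# Window arithmetic from Inferer.__call__/infer_joint, with illustrative
# slider values from the demo.
response_length = 20            # args.response_length
length, window_size = 50, 20    # 'length' and 'window_size' sliders
vanilla_length = max(0, length - response_length)
steps = math.ceil(vanilla_length / window_size)
print(vanilla_length, steps)    # 30 extra tokens, generated in 2 windows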
utils.py
ADDED
@@ -0,0 +1,28 @@
import torch


NEGATIVE_INF = -100000.0
HALF_NEGATIVE_INF = -60000.0  # half precision


def get_first_sentence(txt, min_len=5):
    eos = '<|endoftext|>'
    eos_idx = txt.find(eos)
    if eos_idx > 0:
        txt = txt[:eos_idx]  # keep the text before the first endoftext marker
    txt = txt.replace('\n', ' ')
    sents = txt.split('. ')
    if len(sents[0]) >= min_len:
        sent = f'{sents[0].strip()}.'
    else:
        sent = txt
    return sent


def logits_to_entropy(logits):
    distribution = torch.distributions.Categorical(logits=logits)
    return distribution.entropy()


def mask_pad(value, mask):
    return value * mask + NEGATIVE_INF * (1 - mask)
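Editorial note: a quick sanity check of these helpers (not part of the commit): get_first_sentence keeps the leading sentence when it is at least min_len characters long, and mask_pad pushes masked positions to NEGATIVE_INF so they drop out of downstream max and softmax reductions.

import torch

print(get_first_sentence('A dog runs. It barks.'))  # -> 'A dog runs.'
value = torch.tensor([[1.0, 2.0, 3.0]])
mask = torch.tensor([[1.0, 1.0, 0.0]])
print(mask_pad(value, mask))  # masked slot becomes -100000.0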