Spaces:
Runtime error
Runtime error
Fabrice-TIERCELIN
commited on
Delete clipseg/models/vitseg.py
Browse files- clipseg/models/vitseg.py +0 -286
clipseg/models/vitseg.py
DELETED
@@ -1,286 +0,0 @@
|
|
1 |
-
import math
|
2 |
-
from posixpath import basename, dirname, join
|
3 |
-
# import clip
|
4 |
-
from clip.model import convert_weights
|
5 |
-
import torch
|
6 |
-
import json
|
7 |
-
from torch import nn
|
8 |
-
from torch.nn import functional as nnf
|
9 |
-
from torch.nn.modules import activation
|
10 |
-
from torch.nn.modules.activation import ReLU
|
11 |
-
from torchvision import transforms
|
12 |
-
|
13 |
-
normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
|
14 |
-
|
15 |
-
from torchvision.models import ResNet
|
16 |
-
|
17 |
-
|
18 |
-
def process_prompts(conditional, prompt_list, conditional_map):
|
19 |
-
# DEPRECATED
|
20 |
-
|
21 |
-
# randomly sample a synonym
|
22 |
-
words = [conditional_map[int(i)] for i in conditional]
|
23 |
-
words = [syns[torch.multinomial(torch.ones(len(syns)), 1, replacement=True).item()] for syns in words]
|
24 |
-
words = [w.replace('_', ' ') for w in words]
|
25 |
-
|
26 |
-
if prompt_list is not None:
|
27 |
-
prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True)
|
28 |
-
prompts = [prompt_list[i] for i in prompt_indices]
|
29 |
-
else:
|
30 |
-
prompts = ['a photo of {}'] * (len(words))
|
31 |
-
|
32 |
-
return [promt.format(w) for promt, w in zip(prompts, words)]
|
33 |
-
|
34 |
-
|
35 |
-
class VITDenseBase(nn.Module):
|
36 |
-
|
37 |
-
def rescaled_pos_emb(self, new_size):
|
38 |
-
assert len(new_size) == 2
|
39 |
-
|
40 |
-
a = self.model.positional_embedding[1:].T.view(1, 768, *self.token_shape)
|
41 |
-
b = nnf.interpolate(a, new_size, mode='bicubic', align_corners=False).squeeze(0).view(768, new_size[0]*new_size[1]).T
|
42 |
-
return torch.cat([self.model.positional_embedding[:1], b])
|
43 |
-
|
44 |
-
def visual_forward(self, x_inp, extract_layers=(), skip=False, mask=None):
|
45 |
-
|
46 |
-
with torch.no_grad():
|
47 |
-
|
48 |
-
x_inp = nnf.interpolate(x_inp, (384, 384))
|
49 |
-
|
50 |
-
x = self.model.patch_embed(x_inp)
|
51 |
-
cls_token = self.model.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
52 |
-
if self.model.dist_token is None:
|
53 |
-
x = torch.cat((cls_token, x), dim=1)
|
54 |
-
else:
|
55 |
-
x = torch.cat((cls_token, self.model.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
|
56 |
-
x = self.model.pos_drop(x + self.model.pos_embed)
|
57 |
-
|
58 |
-
activations = []
|
59 |
-
for i, block in enumerate(self.model.blocks):
|
60 |
-
x = block(x)
|
61 |
-
|
62 |
-
if i in extract_layers:
|
63 |
-
# permute to be compatible with CLIP
|
64 |
-
activations += [x.permute(1,0,2)]
|
65 |
-
|
66 |
-
x = self.model.norm(x)
|
67 |
-
x = self.model.head(self.model.pre_logits(x[:, 0]))
|
68 |
-
|
69 |
-
# again for CLIP compatibility
|
70 |
-
# x = x.permute(1, 0, 2)
|
71 |
-
|
72 |
-
return x, activations, None
|
73 |
-
|
74 |
-
def sample_prompts(self, words, prompt_list=None):
|
75 |
-
|
76 |
-
prompt_list = prompt_list if prompt_list is not None else self.prompt_list
|
77 |
-
|
78 |
-
prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True)
|
79 |
-
prompts = [prompt_list[i] for i in prompt_indices]
|
80 |
-
return [promt.format(w) for promt, w in zip(prompts, words)]
|
81 |
-
|
82 |
-
def get_cond_vec(self, conditional, batch_size):
|
83 |
-
# compute conditional from a single string
|
84 |
-
if conditional is not None and type(conditional) == str:
|
85 |
-
cond = self.compute_conditional(conditional)
|
86 |
-
cond = cond.repeat(batch_size, 1)
|
87 |
-
|
88 |
-
# compute conditional from string list/tuple
|
89 |
-
elif conditional is not None and type(conditional) in {list, tuple} and type(conditional[0]) == str:
|
90 |
-
assert len(conditional) == batch_size
|
91 |
-
cond = self.compute_conditional(conditional)
|
92 |
-
|
93 |
-
# use conditional directly
|
94 |
-
elif conditional is not None and type(conditional) == torch.Tensor and conditional.ndim == 2:
|
95 |
-
cond = conditional
|
96 |
-
|
97 |
-
# compute conditional from image
|
98 |
-
elif conditional is not None and type(conditional) == torch.Tensor:
|
99 |
-
with torch.no_grad():
|
100 |
-
cond, _, _ = self.visual_forward(conditional)
|
101 |
-
else:
|
102 |
-
raise ValueError('invalid conditional')
|
103 |
-
return cond
|
104 |
-
|
105 |
-
def compute_conditional(self, conditional):
|
106 |
-
import clip
|
107 |
-
|
108 |
-
dev = next(self.parameters()).device
|
109 |
-
|
110 |
-
if type(conditional) in {list, tuple}:
|
111 |
-
text_tokens = clip.tokenize(conditional).to(dev)
|
112 |
-
cond = self.clip_model.encode_text(text_tokens)
|
113 |
-
else:
|
114 |
-
if conditional in self.precomputed_prompts:
|
115 |
-
cond = self.precomputed_prompts[conditional].float().to(dev)
|
116 |
-
else:
|
117 |
-
text_tokens = clip.tokenize([conditional]).to(dev)
|
118 |
-
cond = self.clip_model.encode_text(text_tokens)[0]
|
119 |
-
|
120 |
-
return cond
|
121 |
-
|
122 |
-
|
123 |
-
class VITDensePredT(VITDenseBase):
|
124 |
-
|
125 |
-
def __init__(self, extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4, prompt='fixed',
|
126 |
-
depth=3, extra_blocks=0, reduce_cond=None, fix_shift=False,
|
127 |
-
learn_trans_conv_only=False, refine=None, limit_to_clip_only=False, upsample=False,
|
128 |
-
add_calibration=False, process_cond=None, not_pretrained=False):
|
129 |
-
super().__init__()
|
130 |
-
# device = 'cpu'
|
131 |
-
|
132 |
-
self.extract_layers = extract_layers
|
133 |
-
self.cond_layer = cond_layer
|
134 |
-
self.limit_to_clip_only = limit_to_clip_only
|
135 |
-
self.process_cond = None
|
136 |
-
|
137 |
-
if add_calibration:
|
138 |
-
self.calibration_conds = 1
|
139 |
-
|
140 |
-
self.upsample_proj = nn.Conv2d(reduce_dim, 1, kernel_size=1) if upsample else None
|
141 |
-
|
142 |
-
self.add_activation1 = True
|
143 |
-
|
144 |
-
import timm
|
145 |
-
self.model = timm.create_model('vit_base_patch16_384', pretrained=True)
|
146 |
-
self.model.head = nn.Linear(768, 512 if reduce_cond is None else reduce_cond)
|
147 |
-
|
148 |
-
for p in self.model.parameters():
|
149 |
-
p.requires_grad_(False)
|
150 |
-
|
151 |
-
import clip
|
152 |
-
self.clip_model, _ = clip.load('ViT-B/16', device='cpu', jit=False)
|
153 |
-
# del self.clip_model.visual
|
154 |
-
|
155 |
-
|
156 |
-
self.token_shape = (14, 14)
|
157 |
-
|
158 |
-
# conditional
|
159 |
-
if reduce_cond is not None:
|
160 |
-
self.reduce_cond = nn.Linear(512, reduce_cond)
|
161 |
-
for p in self.reduce_cond.parameters():
|
162 |
-
p.requires_grad_(False)
|
163 |
-
else:
|
164 |
-
self.reduce_cond = None
|
165 |
-
|
166 |
-
# self.film = AVAILABLE_BLOCKS['film'](512, 128)
|
167 |
-
self.film_mul = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim)
|
168 |
-
self.film_add = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim)
|
169 |
-
|
170 |
-
# DEPRECATED
|
171 |
-
# self.conditional_map = {c['id']: c['synonyms'] for c in json.load(open(cond_map))}
|
172 |
-
|
173 |
-
assert len(self.extract_layers) == depth
|
174 |
-
|
175 |
-
self.reduces = nn.ModuleList([nn.Linear(768, reduce_dim) for _ in range(depth)])
|
176 |
-
self.blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(len(self.extract_layers))])
|
177 |
-
self.extra_blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(extra_blocks)])
|
178 |
-
|
179 |
-
trans_conv_ks = (16, 16)
|
180 |
-
self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks)
|
181 |
-
|
182 |
-
# refinement and trans conv
|
183 |
-
|
184 |
-
if learn_trans_conv_only:
|
185 |
-
for p in self.parameters():
|
186 |
-
p.requires_grad_(False)
|
187 |
-
|
188 |
-
for p in self.trans_conv.parameters():
|
189 |
-
p.requires_grad_(True)
|
190 |
-
|
191 |
-
if prompt == 'fixed':
|
192 |
-
self.prompt_list = ['a photo of a {}.']
|
193 |
-
elif prompt == 'shuffle':
|
194 |
-
self.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.']
|
195 |
-
elif prompt == 'shuffle+':
|
196 |
-
self.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.',
|
197 |
-
'a cropped photo of a {}.', 'a good photo of a {}.', 'a photo of one {}.',
|
198 |
-
'a bad photo of a {}.', 'a photo of the {}.']
|
199 |
-
elif prompt == 'shuffle_clip':
|
200 |
-
from models.clip_prompts import imagenet_templates
|
201 |
-
self.prompt_list = imagenet_templates
|
202 |
-
|
203 |
-
if process_cond is not None:
|
204 |
-
if process_cond == 'clamp' or process_cond[0] == 'clamp':
|
205 |
-
|
206 |
-
val = process_cond[1] if type(process_cond) in {list, tuple} else 0.2
|
207 |
-
|
208 |
-
def clamp_vec(x):
|
209 |
-
return torch.clamp(x, -val, val)
|
210 |
-
|
211 |
-
self.process_cond = clamp_vec
|
212 |
-
|
213 |
-
elif process_cond.endswith('.pth'):
|
214 |
-
|
215 |
-
shift = torch.load(process_cond)
|
216 |
-
def add_shift(x):
|
217 |
-
return x + shift.to(x.device)
|
218 |
-
|
219 |
-
self.process_cond = add_shift
|
220 |
-
|
221 |
-
import pickle
|
222 |
-
precomp = pickle.load(open('precomputed_prompt_vectors.pickle', 'rb'))
|
223 |
-
self.precomputed_prompts = {k: torch.from_numpy(v) for k, v in precomp.items()}
|
224 |
-
|
225 |
-
|
226 |
-
def forward(self, inp_image, conditional=None, return_features=False, mask=None):
|
227 |
-
|
228 |
-
assert type(return_features) == bool
|
229 |
-
|
230 |
-
# inp_image = inp_image.to(self.model.positional_embedding.device)
|
231 |
-
|
232 |
-
if mask is not None:
|
233 |
-
raise ValueError('mask not supported')
|
234 |
-
|
235 |
-
# x_inp = normalize(inp_image)
|
236 |
-
x_inp = inp_image
|
237 |
-
|
238 |
-
bs, dev = inp_image.shape[0], x_inp.device
|
239 |
-
|
240 |
-
inp_image_size = inp_image.shape[2:]
|
241 |
-
|
242 |
-
cond = self.get_cond_vec(conditional, bs)
|
243 |
-
|
244 |
-
visual_q, activations, _ = self.visual_forward(x_inp, extract_layers=[0] + list(self.extract_layers))
|
245 |
-
|
246 |
-
activation1 = activations[0]
|
247 |
-
activations = activations[1:]
|
248 |
-
|
249 |
-
a = None
|
250 |
-
for i, (activation, block, reduce) in enumerate(zip(activations[::-1], self.blocks, self.reduces)):
|
251 |
-
|
252 |
-
if a is not None:
|
253 |
-
a = reduce(activation) + a
|
254 |
-
else:
|
255 |
-
a = reduce(activation)
|
256 |
-
|
257 |
-
if i == self.cond_layer:
|
258 |
-
if self.reduce_cond is not None:
|
259 |
-
cond = self.reduce_cond(cond)
|
260 |
-
|
261 |
-
a = self.film_mul(cond) * a + self.film_add(cond)
|
262 |
-
|
263 |
-
a = block(a)
|
264 |
-
|
265 |
-
for block in self.extra_blocks:
|
266 |
-
a = a + block(a)
|
267 |
-
|
268 |
-
a = a[1:].permute(1, 2, 0) # rm cls token and -> BS, Feats, Tokens
|
269 |
-
|
270 |
-
size = int(math.sqrt(a.shape[2]))
|
271 |
-
|
272 |
-
a = a.view(bs, a.shape[1], size, size)
|
273 |
-
|
274 |
-
if self.trans_conv is not None:
|
275 |
-
a = self.trans_conv(a)
|
276 |
-
|
277 |
-
if self.upsample_proj is not None:
|
278 |
-
a = self.upsample_proj(a)
|
279 |
-
a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear')
|
280 |
-
|
281 |
-
a = nnf.interpolate(a, inp_image_size)
|
282 |
-
|
283 |
-
if return_features:
|
284 |
-
return a, visual_q, cond, [activation1] + activations
|
285 |
-
else:
|
286 |
-
return a,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|