gheinrich committed
Commit 91daee9
1 Parent(s): 6871353

Upload model

enable_spectral_reparam.py ADDED
@@ -0,0 +1,227 @@
+from logging import getLogger
+import math
+import os
+from typing import Union, Tuple
+from types import MethodType
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.utils import parametrize
+from torch.nn.utils.parametrizations import _SpectralNorm
+
+from timm.models.vision_transformer import Attention, Mlp
+
+_EPS = 1e-5
+
+
+class _SNReweight(_SpectralNorm):
+    def __init__(self, weight: torch.Tensor, *args, init_norm_to_current: bool = False, alpha: float = 0.05, version: int = 2, **kwargs):
+        super().__init__(weight, *args, **kwargs)
+
+        self.alpha = alpha
+        self.version = version
+        self.register_buffer('_sn_version', torch.tensor(version))
+
+        if init_norm_to_current:
+            # This will set the numerator to match the denominator, which should preserve the original values
+            init_scale = self._get_sigma(weight).item()
+        else:
+            init_scale = 1.0
+
+        if version == 1:
+            init_value = init_scale
+        elif version == 2:
+            t = init_scale - alpha
+            if t < _EPS:
+                getLogger("spectral_reparam").warn(f'The initialized spectral norm {init_scale} is too small to be represented. Setting to {_EPS} instead.')
+                t = _EPS
+
+            init_value = math.log(math.exp(t) - 1)
+        else:
+            raise ValueError(f'Unsupported version: {version}')
+
+        # Make 2D so that weight decay gets applied
+        self.scale = nn.Parameter(torch.tensor([[init_value]], dtype=torch.float32, device=weight.device))
+
+    # Re-implementing this because we need to make division by sigma safe
+    def _get_sigma(self, weight: torch.Tensor) -> torch.Tensor:
+        if weight.ndim == 1:
+            # Faster and more exact path, no need to approximate anything
+            sigma = weight.norm()
+        else:
+            weight_mat = self._reshape_weight_to_matrix(weight)
+            if self.training:
+                self._power_method(weight_mat, self.n_power_iterations)
+            # See above on why we need to clone
+            u = self._u.clone(memory_format=torch.contiguous_format)
+            v = self._v.clone(memory_format=torch.contiguous_format)
+            # The proper way of computing this should be through F.bilinear, but
+            # it seems to have some efficiency issues:
+            # https://github.com/pytorch/pytorch/issues/58093
+            sigma = torch.dot(u, torch.mv(weight_mat, v))
+
+        return sigma + self.eps
+
+    def forward(self, weight: torch.Tensor, *args, **kwargs):
+        dtype = weight.dtype
+        sigma = self._get_sigma(weight, *args, **kwargs)
+
+        if self.version == 1:
+            scale = self.scale
+        elif self.version == 2:
+            scale = F.softplus(self.scale) + self.alpha
+        else:
+            raise ValueError(f'Unsupported version: {self.version}')
+
+        scale = scale.float() / sigma.float()
+
+        y = weight * scale
+
+        if dtype in (torch.float16, torch.bfloat16):
+            y = y.to(dtype)
+        return y
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
+        version_key = f'{prefix}_sn_version'
+        if version_key not in state_dict:
+            self.version = 1
+            state_dict[version_key] = torch.tensor(1)
+        return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+
+class _AttnSNReweight(nn.Module):
+    def __init__(self, weight: torch.Tensor, *args, init_norm_to_current: bool = False, renorm_values: bool = False, **kwargs):
+        super().__init__()
+
+        parts = weight.split(weight.shape[0] // 3, dim=0)
+
+        ct = 2 if not renorm_values else 3
+
+        self.parts = nn.ModuleList([
+            _SNReweight(p, *args, init_norm_to_current=init_norm_to_current, **kwargs) if i < ct else nn.Identity()
+            for i, p in enumerate(parts)
+        ])
+
+    def forward(self, weight: torch.Tensor, *args, **kwargs):
+        parts = weight.split(weight.shape[0] // 3, dim=0)
+
+        parts = [
+            fn(p)
+            for fn, p in zip(self.parts, parts)
+        ]
+
+        return torch.cat(parts, dim=0)
+
+
+def enable_spectral_reparam(model: nn.Module,
+                            n_power_iterations: int = 1,
+                            eps: float = 1e-6,
+                            init_norm_to_current: bool = False,
+                            renorm_values: bool = True,
+                            renorm_mlp: bool = True):
+    # print('Enabling spectral reparametrization')
+    for mod in model.modules():
+        if isinstance(mod, Attention):
+            parametrize.register_parametrization(
+                mod.qkv,
+                'weight',
+                _AttnSNReweight(mod.qkv.weight, n_power_iterations, dim=0, eps=eps, init_norm_to_current=init_norm_to_current, renorm_values=renorm_values),
+            )
+            pass
+        elif isinstance(mod, Mlp) and renorm_mlp:
+            parametrize.register_parametrization(
+                mod.fc1,
+                'weight',
+                _SNReweight(mod.fc1.weight, n_power_iterations, dim=0, eps=eps, init_norm_to_current=init_norm_to_current),
+            )
+            parametrize.register_parametrization(
+                mod.fc2,
+                'weight',
+                _SNReweight(mod.fc2.weight, n_power_iterations, dim=0, eps=eps, init_norm_to_current=init_norm_to_current),
+            )
+            pass
+
+
+def configure_spectral_reparam_from_args(model: nn.Module, args):
+    spectral_reparam = getattr(args, 'spectral_reparam', False)
+    if isinstance(spectral_reparam, bool) and spectral_reparam:
+        enable_spectral_reparam(model, init_norm_to_current=args.pretrained)
+    elif isinstance(spectral_reparam, dict):
+        enable_spectral_reparam(
+            model,
+            n_power_iterations=spectral_reparam.get('n_power_iterations', 1),
+            eps=spectral_reparam.get('eps', 1e-12),
+            init_norm_to_current=args.pretrained,
+        )
+
+
+def disable_spectral_reparam(model: nn.Module):
+    for mod in model.modules():
+        if isinstance(mod, Attention):
+            parametrize.remove_parametrizations(mod.qkv, 'weight')
+            pass
+        elif isinstance(mod, Mlp):
+            parametrize.remove_parametrizations(mod.fc1, 'weight')
+            parametrize.remove_parametrizations(mod.fc2, 'weight')
+            pass
+
+
+if __name__ == '__main__':
+    import argparse
+    from . import radio_model as create_model
+
+    parser = argparse.ArgumentParser(description='Remove parametrization from state dict')
+    parser.add_argument('--checkpoint', type=str, required=True, help='The checkpoint to load')
+    parser.add_argument('--output', type=str, default='', help='Where to store the checkpoint')
+    parser.add_argument('--release', default=False, action='store_true', help='Prune extraneous checkpoint fields')
+    parser.add_argument('--strict', default=False, action='store_true', help='Strictly load the state dict')
+
+    args = parser.parse_args()
+
+    if not args.output:
+        chk_dir, chk_name = os.path.split(args.checkpoint)
+        args.output = os.path.join(chk_dir, f'clean_{chk_name}')
+        print(f'Set output to "{args.output}"')
+
+    chk = torch.load(args.checkpoint, map_location='cpu', mmap=True)
+
+    model = create_model.create_model_from_args(chk['args'])
+
+    key = 'base_model.'
+    mod_state = dict()
+    extra_state = dict()
+    for k, v in chk['state_dict'].items():
+        if k.startswith(key):
+            mod_state[k[len(key):]] = v
+        else:
+            extra_state[k] = v
+
+    chk_load_info = model.load_state_dict(mod_state, strict=args.strict)
+    if chk_load_info.unexpected_keys or chk_load_info.missing_keys:
+        print(chk_load_info)
+
+    if chk['args'].spectral_reparam:
+        disable_spectral_reparam(model)
+
+    if hasattr(chk['args'], 'dtype'):
+        model.to(dtype=chk['args'].dtype)
+
+    mod_state = model.state_dict()
+    final_state = dict()
+    final_state.update({f'{key}{k}': v for k, v in mod_state.items()})
+    final_state.update(extra_state)
+
+    chk['state_dict'] = final_state
+    chk['args'].spectral_reparam = False
+
+    if args.release:
+        chk = {
+            'arch': chk['arch'],
+            'epoch': chk['epoch'],
+            'state_dict': chk['state_dict'],
+            'args': chk['args'],
+        }
+
+    torch.save(chk, args.output)
+    pass
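
For context, a minimal sketch of how the new module is typically exercised on a stock timm ViT. This is illustrative only and not part of the commit: the model name, input shape, and standalone import path are assumptions, and it presumes the file is importable as a top-level module.

# Illustrative sketch (not part of this commit): install the spectral
# reparametrization, run a forward pass, then bake it back into plain weights.
import timm
import torch

from enable_spectral_reparam import enable_spectral_reparam, disable_spectral_reparam

model = timm.create_model('vit_base_patch16_224', pretrained=False)

# Wrap qkv/fc1/fc2 weights with the _AttnSNReweight / _SNReweight parametrizations.
# init_norm_to_current=True initializes the learned scale to the current spectral
# norm, so the effective weights are unchanged at install time.
enable_spectral_reparam(model, init_norm_to_current=True)
model.eval()  # keeps u/v fixed, so sigma stays at its initialized value

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    y_reparam = model(x)

# remove_parametrizations defaults to leave_parametrized=True, so this folds
# weight * scale / sigma back into ordinary weight tensors, mirroring what the
# __main__ block above does for whole checkpoints.
disable_spectral_reparam(model)
with torch.no_grad():
    y_plain = model(x)

print((y_reparam - y_plain).abs().max())  # expected to be ~0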
eradio_model.py CHANGED
@@ -1162,6 +1162,9 @@ class FasterViT(nn.Module):
         return {'rpb'}
 
     def forward_features(self, x):
+        _, _, H, W = x.shape
+        if H % 32 != 0 or W % 32 != 0:
+            raise ValueError(f"E-RADIO requires input dimensions to be divisible by 32 but got H x W: {H} x {W}")
         x = self.patch_embed(x)
         full_features = None
         for il, level in enumerate(self.levels):
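
The new guard in forward_features means E-RADIO callers must pass spatial dimensions that are multiples of 32. Below is a small illustrative sketch of one way to satisfy the constraint; the helper name and shapes are assumptions, not part of this diff.

# Hypothetical helper (not part of this commit): zero-pad an image batch up to
# the next multiple of 32 in height and width before running E-RADIO.
import torch
import torch.nn.functional as F

def pad_to_multiple_of_32(x: torch.Tensor) -> torch.Tensor:
    _, _, H, W = x.shape
    pad_h = (32 - H % 32) % 32
    pad_w = (32 - W % 32) % 32
    # F.pad pads the last two dims as (left, right, top, bottom)
    return F.pad(x, (0, pad_w, 0, pad_h))

x = torch.randn(1, 3, 500, 700)   # 500 x 700 would now raise ValueError
x = pad_to_multiple_of_32(x)      # -> torch.Size([1, 3, 512, 704])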
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:9079d79a8948849416e84a25d9318e020e719dbe6f8c16a13d674f8e1f5e6b88
-size 1614710336
+oid sha256:32b9c1a630de485378c235a695e7bf735e4f29657ccedb7daeb1dbc4aed73f4c
+size 1608537016
radio_model.py CHANGED
@@ -18,6 +18,7 @@ from .input_conditioner import InputConditioner
 from . import extra_timm_models
 from .adaptor_base import AdaptorBase, RadioOutput, AdaptorInput
 from . import eradio_model
+from .enable_spectral_reparam import configure_spectral_reparam_from_args
 
 
 class Resolution(NamedTuple):
@@ -180,6 +181,11 @@ def create_model_from_args(args) -> nn.Module:
         **args.model_kwargs,
     )
 
+    if hasattr(model, 'norm') and not getattr(args, 'model_norm', False):
+        model.norm = nn.Identity()
+
+    model.head = nn.Identity()
+
     assert (
         not args.cls_token_per_teacher or args.cpe_max_size is not None
     ), "CPE must be enabled for multiple CLS tokens!"
@@ -192,4 +198,7 @@ def create_model_from_args(args) -> nn.Module:
         register_multiple=args.register_multiple,
     )
 
+    if args.spectral_reparam:
+        configure_spectral_reparam_from_args(model, args)
+
     return model
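
For reference, the hook added at the end of create_model_from_args handles both flag shapes accepted by configure_spectral_reparam_from_args. A minimal sketch of the two argument forms; the Namespace objects below are illustrative assumptions, not values used by this repository.

# Illustrative only: the two arg shapes understood by configure_spectral_reparam_from_args.
from argparse import Namespace

# Plain boolean flag: enable with defaults; init_norm_to_current follows args.pretrained.
args_bool = Namespace(spectral_reparam=True, pretrained=True)

# Dict form: override the number of power iterations and epsilon.
args_dict = Namespace(
    spectral_reparam={'n_power_iterations': 2, 'eps': 1e-12},
    pretrained=True,
)

# configure_spectral_reparam_from_args(model, args_bool)
# configure_spectral_reparam_from_args(model, args_dict)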