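"""Spectral-norm reparametrization of linear layers.

Wraps weight matrices with a learnable-scale spectral-norm parametrization
(via ``torch.nn.utils.parametrize``) and provides helpers to enable, configure
and remove the reparametrization, plus a small CLI that strips it from a
saved checkpoint.
"""
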
from logging import getLogger
import math
import os
from typing import Dict, List, Optional, Union, Tuple
from types import MethodType

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils import parametrize
from torch.nn.utils.parametrizations import _SpectralNorm

from timm.models.vision_transformer import Attention, Mlp

_EPS = 1e-5


class _SNReweight(_SpectralNorm):
    def __init__(self, weight: torch.Tensor, *args, init_norm_to_current: bool = False, alpha: float = 0.05, version: int = 2, **kwargs):
        super().__init__(weight, *args, **kwargs)

        self.alpha = alpha
        self.version = version
        self.register_buffer('_sn_version', torch.tensor(version))

        if init_norm_to_current:
            # Initialize the learnable scale so that the reparametrized weight
            # initially reproduces the current weight.
            init_scale = self._get_sigma(weight, n_power_iterations=20).item()
        else:
            init_scale = 1.0

        if version == 1:
            init_value = init_scale
        elif version == 2:
            t = init_scale - alpha
            if t < _EPS:
                getLogger("spectral_reparam").warning(f'The initialized spectral norm {init_scale} is too small to be represented. Setting to {_EPS} instead.')
                t = _EPS

            # Inverse softplus, so that softplus(init_value) == t.
            init_value = math.log(math.exp(t) - 1)
        else:
            raise ValueError(f'Unsupported version: {version}')

        self.scale = nn.Parameter(
            torch.tensor([[init_value]], dtype=torch.float32, device=weight.device),
        )

    def _get_sigma(self, weight: torch.Tensor, n_power_iterations: Optional[int] = None) -> torch.Tensor:
        if not n_power_iterations:
            n_power_iterations = self.n_power_iterations
        if weight.ndim == 1:
            # For 1-D weights the spectral norm reduces to the vector norm.
            sigma = weight.norm()
        else:
            weight_mat = self._reshape_weight_to_matrix(weight)
            if self.training:
                self._power_method(weight_mat, n_power_iterations)

            u = self._u.clone(memory_format=torch.contiguous_format)
            v = self._v.clone(memory_format=torch.contiguous_format)

            # Estimate the largest singular value as u^T W v.
            sigma = torch.dot(u, torch.mv(weight_mat, v))

        return sigma + self.eps

    def forward(self, weight: torch.Tensor, *args, **kwargs):
        dtype = weight.dtype
        sigma = self._get_sigma(weight, *args, **kwargs)

        if self.version == 1:
            scale = self.scale
        elif self.version == 2:
            # Keep the effective scale strictly positive and at least `alpha`.
            scale = F.softplus(self.scale) + self.alpha
        else:
            raise ValueError(f'Unsupported version: {self.version}')

        # Normalize the weight to unit spectral norm, then rescale by the learnable
        # scale. The division is done in float32 for numerical stability.
        scale = scale.float() / sigma.float()

        y = weight * scale

        if dtype in (torch.float16, torch.bfloat16):
            y = y.to(dtype)
        return y

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        version_key = f'{prefix}_sn_version'
        if version_key not in state_dict:
            # Older checkpoints predate the version buffer; treat them as version 1.
            self.version = 1
            state_dict[version_key] = torch.tensor(1)
        return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
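
# Illustrative sketch (not part of the module's public API): the reparametrization is
# attached to a layer with `torch.nn.utils.parametrize`, after which accessing
# `linear.weight` returns the rescaled weight `scale / sigma * raw_weight`.
#
#   linear = nn.Linear(64, 64)
#   parametrize.register_parametrization(
#       linear,
#       'weight',
#       _SNReweight(linear.weight, n_power_iterations=1, dim=0, eps=1e-6, init_norm_to_current=True),
#   )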


class _ChunkedSNReweight(nn.Module):
    def __init__(self, weight: torch.Tensor, num_chunks: int, *args, init_norm_to_current: bool = False, **kwargs):
        super().__init__()

        self.num_chunks = num_chunks
        parts = weight.split(weight.shape[0] // num_chunks, dim=0)

        self.parts = nn.ModuleList([
            _SNReweight(p, *args, init_norm_to_current=init_norm_to_current, **kwargs)
            for p in parts
        ])

    def forward(self, weight: torch.Tensor, *args, **kwargs):
        parts = weight.split(weight.shape[0] // self.num_chunks, dim=0)

        parts = [
            fn(p)
            for fn, p in zip(self.parts, parts)
        ]

        return torch.cat(parts, dim=0)


class _AttnSNReweight(_ChunkedSNReweight):
    def __init__(self, weight: torch.Tensor, *args, init_norm_to_current: bool = False, renorm_values: bool = False, **kwargs):
        # The fused qkv weight is split into three chunks: query, key and value.
        super().__init__(weight, 3, *args, init_norm_to_current=init_norm_to_current, **kwargs)

        if not renorm_values:
            # Leave the value projection untouched.
            self.parts[2] = nn.Identity()
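
# Illustrative sketch: for a fused attention projection with weight of shape
# [3 * dim, dim], the chunks correspond to q, k and v in that order, so with
# `renorm_values=False` only the q and k projections are reparametrized.
#
#   qkv = nn.Linear(dim, 3 * dim)                                    # hypothetical layer
#   reweight = _AttnSNReweight(qkv.weight, n_power_iterations=1, dim=0, eps=1e-6)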


def enable_spectral_reparam(model: Union[nn.Module, List[nn.Module]],
                            n_power_iterations: int = 1,
                            eps: float = 1e-6,
                            init_norm_to_current: bool = False,
                            renorm_values: bool = True,
                            renorm_mlp: bool = True,
                            state_dict_guidance: Optional[Dict[str, torch.Tensor]] = None):
    if isinstance(model, (list, tuple)):
        for i, sub in enumerate(model):
            sub_sd = state_dict_guidance[i] if isinstance(state_dict_guidance, (list, tuple)) else state_dict_guidance
            enable_spectral_reparam(sub, n_power_iterations=n_power_iterations, eps=eps,
                                    init_norm_to_current=init_norm_to_current, renorm_values=renorm_values,
                                    renorm_mlp=renorm_mlp, state_dict_guidance=sub_sd)
        return

    print('Enabling spectral reparametrization')
    args = dict(n_power_iterations=n_power_iterations, dim=0, eps=eps, init_norm_to_current=init_norm_to_current)
    visited_prefixes = set()

    def is_guidance_parametrized(name: str):
        # Without guidance, parametrize every eligible module; otherwise only modules
        # that already have parametrization keys in the guidance state dict.
        if state_dict_guidance is None:
            return True

        p_name = f'{name}.parametrizations'
        return any(k.startswith(p_name) for k in state_dict_guidance)

    def parametrize_linear(linear: nn.Linear):
        parametrize.register_parametrization(
            linear,
            'weight',
            _SNReweight(linear.weight, **args),
        )

    for name, mod in model.named_modules():
        pref = '.'.join(name.split('.')[:-1])
        if pref in visited_prefixes:
            continue

        if isinstance(mod, Attention) or name.endswith('.attn'):
            if is_guidance_parametrized(f'{name}.qkv'):
                parametrize.register_parametrization(
                    mod.qkv,
                    'weight',
                    _AttnSNReweight(mod.qkv.weight, renorm_values=renorm_values, **args),
                )
            if hasattr(mod, 'proj') and is_guidance_parametrized(f'{name}.proj'):
                parametrize_linear(mod.proj)
            visited_prefixes.add(name)
        elif name.endswith('mlp') and renorm_mlp and hasattr(mod, 'w12'):
            if is_guidance_parametrized(f'{name}.w12'):
                parametrize.register_parametrization(
                    mod.w12,
                    'weight',
                    _ChunkedSNReweight(mod.w12.weight, num_chunks=2, **args),
                )
            if is_guidance_parametrized(f'{name}.w3'):
                parametrize_linear(mod.w3)
            visited_prefixes.add(name)
        elif isinstance(mod, nn.Linear) and 'patch_generator' not in name and is_guidance_parametrized(name):
            parametrize_linear(mod)
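

# Illustrative usage sketch (names other than `enable_spectral_reparam` are hypothetical):
#
#   model = create_vit_model()   # any module whose attention blocks expose `.qkv` / `.proj`
#   enable_spectral_reparam(model, n_power_iterations=1, eps=1e-6, init_norm_to_current=True)
#   # With `init_norm_to_current=True`, the reparametrized weights initially match the
#   # existing weights, so training can resume without a jump in the loss.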


def configure_spectral_reparam_from_args(model: nn.Module, args, state_dict_guidance: Optional[Dict[str, torch.Tensor]] = None):
    spectral_reparam = getattr(args, 'spectral_reparam', False)
    if isinstance(spectral_reparam, bool) and spectral_reparam:
        enable_spectral_reparam(model, init_norm_to_current=True, state_dict_guidance=state_dict_guidance)
    elif isinstance(spectral_reparam, dict):
        enable_spectral_reparam(
            model,
            n_power_iterations=spectral_reparam.get('n_power_iterations', 1),
            eps=spectral_reparam.get('eps', 1e-12),
            init_norm_to_current=True,
            state_dict_guidance=state_dict_guidance,
        )
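

# Illustrative sketch: `args.spectral_reparam` may be either a bool or a dict, e.g.
#
#   args.spectral_reparam = {'n_power_iterations': 2, 'eps': 1e-12}
#   configure_spectral_reparam_from_args(model, args)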


def disable_spectral_reparam(model: nn.Module):
    print('Disabling spectral reparametrization')
    for _, mod in model.named_modules():
        if parametrize.is_parametrized(mod):
            # Bake the current effective weight back into a plain parameter.
            parametrize.remove_parametrizations(mod, 'weight')
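

# Illustrative sketch: after disabling, the model holds ordinary weight tensors and its
# state dict no longer contains `parametrizations.*` keys, e.g.
#
#   disable_spectral_reparam(model)
#   torch.save(model.state_dict(), 'model_without_parametrization.pth')   # hypothetical path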


if __name__ == '__main__':
    import argparse
    from . import radio_model as create_model

    parser = argparse.ArgumentParser(description='Remove parametrization from state dict')
    parser.add_argument('--checkpoint', type=str, required=True, help='The checkpoint to load')
    parser.add_argument('--output', type=str, default='', help='Where to store the checkpoint')
    parser.add_argument('--release', default=False, action='store_true', help='Prune extraneous checkpoint fields')
    parser.add_argument('--strict', default=False, action='store_true', help='Strictly load the state dict')

    args = parser.parse_args()
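
    # Example invocation (illustrative; the exact module path depends on the package layout):
    #   python -m <package>.spectral_reparam --checkpoint /path/to/checkpoint.pth.tar --release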

    if not args.output:
        chk_dir, chk_name = os.path.split(args.checkpoint)
        args.output = os.path.join(chk_dir, f'clean_{chk_name}')
        print(f'Set output to "{args.output}"')

    chk = torch.load(args.checkpoint, map_location='cpu', mmap=True)

    model = create_model.create_model_from_args(chk['args'])

    key = 'base_model.'
    mod_state = dict()
    extra_state = dict()
    for k, v in chk['state_dict'].items():
        if k.startswith(key):
            mod_state[k[len(key):]] = v
        else:
            extra_state[k] = v

    chk_load_info = model.load_state_dict(mod_state, strict=args.strict)
    if chk_load_info.unexpected_keys or chk_load_info.missing_keys:
        print(chk_load_info)

    if chk['args'].spectral_reparam:
        disable_spectral_reparam(model)

    if hasattr(chk['args'], 'dtype'):
        model.to(dtype=chk['args'].dtype)

    mod_state = model.state_dict()
    final_state = dict()
    final_state.update({f'{key}{k}': v for k, v in mod_state.items()})
    final_state.update(extra_state)

    chk['state_dict'] = final_state
    chk['args'].spectral_reparam = False

    if args.release:
        chk = {
            'arch': chk['arch'],
            'epoch': chk['epoch'],
            'state_dict': chk['state_dict'],
            'args': chk['args'],
        }

    torch.save(chk, args.output)