File size: 15,936 Bytes
c664d09 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 |
# Copyright 2022 MosaicML Examples authors
# SPDX-License-Identifier: Apache-2.0
import math
import warnings
from collections.abc import Sequence
from functools import partial
from typing import Optional, Tuple, Union
import torch
from torch import nn
def torch_default_param_init_fn_(
    module: nn.Module,
    verbose: int = 0,
    **kwargs,
):
    """Re-initialize ``module`` by calling its own ``reset_parameters``.

    Modules that do not expose a ``reset_parameters`` attribute are left
    untouched.

    Args:
        module: Module to initialize in place.
        verbose: When greater than 1, emit a warning describing the scheme.
        **kwargs: Ignored; accepted so extra config entries can be passed in.
    """
    del kwargs  # unused, just to capture any extra args from the config

    if verbose > 1:
        warnings.warn(
            "Initializing network using module's reset_parameters attribute")

    # The attribute lookup doubles as the existence check.
    try:
        reset = module.reset_parameters
    except AttributeError:
        return
    reset()  # type: ignore
def fused_init_helper_(module: nn.Module, init_fn_):
    """Apply ``init_fn_`` separately to each segment of a fused weight.

    Parameter initialization is often based on the parameter's shape. If a
    layer is fused, initialization should be based on the shapes of the
    original tensors instead of the shape of the fused tensor.

    Fused layers must define a ``_fused`` attribute: a pair whose first
    element is the dimension along which the weight is fused and whose second
    element is an iterable of split indices along that dimension.

    Args:
        module: Module with a ``weight`` tensor and a ``_fused`` attribute.
        init_fn_: Callable applied in place to every segment of the weight.

    Raises:
        RuntimeError: If ``module`` has no ``_fused`` attribute (or it is
            ``None``).
    """
    fused = getattr(module, '_fused', None)
    if fused is None:
        raise RuntimeError('Internal logic error')

    dim, splits = fused
    weight = module.weight  # type: ignore
    # Segment boundaries: start of the tensor, every split point, then the end.
    bounds = (0, *splits, weight.size(dim))
    for start, end in zip(bounds[:-1], bounds[1:]):
        index = [slice(None)] * weight.ndim
        index[dim] = slice(start, end)
        # Index with a tuple: passing a list of slices as a multi-dimensional
        # index is deprecated and may be interpreted differently in future.
        init_fn_(weight[tuple(index)])
def generic_param_init_fn_(
    module: nn.Module,
    init_fn_,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    verbose: int = 0,
    **kwargs,
):
    """Initialize a single module in place using ``init_fn_`` for weights.

    Handles ``nn.Linear``, ``nn.Embedding``, ``nn.LayerNorm`` and
    ``nn.MultiheadAttention``. Biases are set to 0 and LayerNorm weights to 1.
    Modules without parameters of their own are ignored.

    Args:
        module: Module to initialize.
        init_fn_: Callable applied in place to each weight tensor.
        n_layers: Number of layers; used for the default residual divisor
            ``sqrt(2 * n_layers)``.
        d_model: Width used to split the packed ``in_proj_weight`` of
            ``nn.MultiheadAttention``. When ``None`` the module's own
            ``embed_dim`` is used.
        init_div_is_residual: Divisor applied to the weights of layers
            flagged with a truthy ``_is_residual`` attribute. ``True`` selects
            ``sqrt(2 * n_layers)``, ``False`` disables the division, and a
            number (or a string holding a finite number) is used as given.
        emb_init_std: If set, embeddings are drawn from a normal distribution
            with mean 0 and this standard deviation.
        emb_init_uniform_lim: If set (and ``emb_init_std`` is not), embeddings
            are drawn uniformly from ``(min, max)``; a scalar ``x`` means
            ``(-x, x)``.
        verbose: When greater than 1, emit warnings describing the choices.
        **kwargs: Ignored; accepted so extra config entries can be passed in.

    Raises:
        ValueError: If ``init_div_is_residual`` is not boolean or numeric, or
            if ``emb_init_uniform_lim`` is a sequence whose length is not 2.
        NotImplementedError: If ``module`` is an unsupported type that
            directly owns parameters.
    """
    del kwargs  # unused, just to capture any extra args from the config

    if verbose > 1:
        warnings.warn('If model has bias parameters they are initialized to 0.')

    # Enable the user to divide _is_residual weights by a value which
    # defaults to math.sqrt(2 * n_layers).
    bad_div_msg = (
        f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}'
    )
    if init_div_is_residual is False:
        div_is_residual = 1.0  # never applied; keeps the name bound
    elif init_div_is_residual is True:
        div_is_residual = math.sqrt(2 * n_layers)
    elif isinstance(init_div_is_residual, (int, float)):
        div_is_residual = init_div_is_residual
    elif isinstance(init_div_is_residual, str):
        # Do not trust YAML parsing to always convert numbers to numbers.
        # float() (unlike str.isnumeric) also accepts decimals such as '2.5'.
        try:
            parsed = float(init_div_is_residual)
        except ValueError:
            parsed = None
        if parsed is None or not math.isfinite(parsed):
            raise ValueError(bad_div_msg)
        div_is_residual = parsed
    else:
        raise ValueError(bad_div_msg)

    divide_residual = init_div_is_residual is not False
    if divide_residual and verbose > 1:
        warnings.warn(
            f'Initializing _is_residual layers then dividing them by {div_is_residual}. '
            'Set `init_div_is_residual: false` in model config to disable this.'
        )

    if isinstance(module, nn.Linear):
        # Linear
        if hasattr(module, '_fused'):
            fused_init_helper_(module, init_fn_)
        else:
            init_fn_(module.weight)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)

        if divide_residual and getattr(module, '_is_residual', False):
            with torch.no_grad():
                module.weight.div_(div_is_residual)

    elif isinstance(module, nn.Embedding):
        # Embedding
        if emb_init_std is not None:
            std = emb_init_std
            if std == 0:
                warnings.warn('Embedding layer initialized to 0.')
            emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)
            if verbose > 1:
                warnings.warn(
                    f'Embedding layer initialized using normal distribution with mean=0 and {std=}.'
                )
        elif emb_init_uniform_lim is not None:
            lim = emb_init_uniform_lim
            if isinstance(lim, Sequence):
                # Exactly two limits are needed; shorter input would otherwise
                # fail below with an unhelpful IndexError.
                if len(lim) != 2:
                    raise ValueError(
                        f'Uniform init requires a min and a max limit. User input: {lim}.'
                    )
                if lim[0] == lim[1]:
                    warnings.warn(f'Embedding layer initialized to {lim[0]}.')
            else:
                if lim == 0:
                    warnings.warn('Embedding layer initialized to 0.')
                lim = [-lim, lim]
            a, b = lim
            emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
            if verbose > 1:
                warnings.warn(
                    f'Embedding layer initialized using uniform distribution in range {lim}.'
                )
        else:
            emb_init_fn_ = init_fn_

        emb_init_fn_(module.weight)

    elif isinstance(module, nn.LayerNorm):
        # LayerNorm
        if verbose > 1:
            warnings.warn(
                'LayerNorm gamma weights are set to 1. If the layer has a bias it is initialized to 0.'
            )
        # weight is None when the layer has no learnable affine parameters.
        if module.weight is not None:
            torch.nn.init.ones_(module.weight)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)

    elif isinstance(module, nn.MultiheadAttention):
        # torch's MultiheadAttention
        if module._qkv_same_embed_dim:
            assert module.in_proj_weight is not None
            assert module.q_proj_weight is None and module.k_proj_weight is None and module.v_proj_weight is None
            # in_proj_weight is actually 3 layers and should be split up for
            # width based init.
            _d = d_model if d_model is not None else module.embed_dim
            splits = (0, _d, 2 * _d, 3 * _d)
            for s, e in zip(splits[:-1], splits[1:]):
                init_fn_(module.in_proj_weight[s:e])
        else:
            assert module.q_proj_weight is not None and module.k_proj_weight is not None and module.v_proj_weight is not None
            assert module.in_proj_weight is None
            init_fn_(module.q_proj_weight)
            init_fn_(module.k_proj_weight)
            init_fn_(module.v_proj_weight)

        # bias
        if module.in_proj_bias is not None:
            torch.nn.init.zeros_(module.in_proj_bias)
        if module.bias_k is not None:
            torch.nn.init.zeros_(module.bias_k)
        if module.bias_v is not None:
            torch.nn.init.zeros_(module.bias_v)

        # out proj
        init_fn_(module.out_proj.weight)
        if divide_residual and getattr(module.out_proj, '_is_residual', False):
            with torch.no_grad():
                module.out_proj.weight.div_(div_is_residual)
        if module.out_proj.bias is not None:
            torch.nn.init.zeros_(module.out_proj.bias)

    else:
        # Raise an error if an unhandled module directly owns any parameters.
        if next(module.parameters(recurse=False), None) is not None:
            raise NotImplementedError(
                f'{module.__class__.__name__} parameters are not initialized by param_init_fn.'
            )
def _normal_init_(std, mean=0.0):
    """Return ``torch.nn.init.normal_`` with ``mean`` and ``std`` pre-bound."""
    bound = {'mean': mean, 'std': std}
    return partial(torch.nn.init.normal_, **bound)
def _normal_param_init_fn_(
    module: nn.Module,
    std: float,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    verbose: int = 0,
    **kwargs,
):
    """Initialize ``module`` with a zero-mean normal of standard deviation ``std``.

    Builds the normal initializer via ``_normal_init_`` and forwards every
    other option unchanged to ``generic_param_init_fn_``.

    Args:
        module: Module to initialize in place.
        std: Standard deviation of the normal distribution.
        n_layers, d_model, init_div_is_residual, emb_init_std,
        emb_init_uniform_lim, verbose: Passed through to
            ``generic_param_init_fn_``.
        **kwargs: Ignored; accepted so extra config entries can be passed in.
    """
    del kwargs  # unused, just to capture any extra args from the config

    normal_fn_ = _normal_init_(std=std)

    if verbose > 1:
        warnings.warn(
            f'Using torch.nn.init.normal_ init fn mean=0.0, std={std}')

    forwarded = dict(
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )
    generic_param_init_fn_(module=module, init_fn_=normal_fn_, **forwarded)
def baseline_param_init_fn_(
    module: nn.Module,
    init_std: float,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    verbose: int = 0,
    **kwargs,
):
    """Initialize ``module`` from a normal distribution with std ``init_std``.

    Args:
        module: Module to initialize in place.
        init_std: Standard deviation for the normal initializer; required.
        n_layers, d_model, init_div_is_residual, emb_init_std,
        emb_init_uniform_lim, verbose: Passed through to
            ``_normal_param_init_fn_``.
        **kwargs: Ignored; accepted so extra config entries can be passed in.

    Raises:
        ValueError: If ``init_std`` is ``None``.
    """
    del kwargs  # unused, just to capture any extra args from the config

    # Guard first: nothing below makes sense without a standard deviation.
    if init_std is None:
        raise ValueError(
            'You must set model.init_std to a float value to use the default initialization scheme.'
        )

    forwarded = dict(
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )
    _normal_param_init_fn_(module=module, std=init_std, **forwarded)
def small_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: int,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    verbose: int = 0,
    **kwargs,
):
    """Initialize ``module`` from a normal with std ``sqrt(2 / (5 * d_model))``.

    Very close to kaiming normal; from Transformers without Tears (2019) -
    Nguyen & Salazar.

    Args:
        module: Module to initialize in place.
        n_layers: Passed through to ``_normal_param_init_fn_``.
        d_model: Model width; determines the standard deviation.
        init_div_is_residual, emb_init_std, emb_init_uniform_lim, verbose:
            Passed through to ``_normal_param_init_fn_``.
        **kwargs: Ignored; accepted so extra config entries can be passed in.
    """
    del kwargs  # unused, just to capture any extra args from the config

    small_std = math.sqrt(2 / (5 * d_model))

    forwarded = dict(
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )
    _normal_param_init_fn_(module=module, std=small_std, **forwarded)
def neox_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: int,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    verbose: int = 0,
    **kwargs,
):
    """Initialize ``module`` with the GPT-NeoX scheme.

    From section 2.3.1 of GPT-NeoX-20B: An Open-Source Autoregressive
    Language Model - Black et al. (2022).

    see https://github.com/EleutherAI/gpt-neox/blob/9610391ab319403cef079b438edd016a2443af54/megatron/model/init_functions.py#L151
    and https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/transformer.py

    Delegates to ``small_param_init_fn_`` with the residual divisor fixed to
    ``n_layers / sqrt(10)``. An ``init_div_is_residual`` entry supplied by
    the caller lands in ``**kwargs`` and is therefore discarded.

    Args:
        module: Module to initialize in place.
        n_layers: Number of layers; determines the residual divisor.
        d_model: Model width, passed through to ``small_param_init_fn_``.
        emb_init_std, emb_init_uniform_lim, verbose: Passed through to
            ``small_param_init_fn_``.
        **kwargs: Ignored; accepted so extra config entries can be passed in.
    """
    del kwargs  # unused, just to capture any extra args from the config

    # small std / wang std
    neox_divisor = n_layers / math.sqrt(10)

    if verbose > 1:
        warnings.warn(f'setting init_div_is_residual to {neox_divisor}')

    forwarded = dict(
        d_model=d_model,
        n_layers=n_layers,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )
    small_param_init_fn_(
        module=module,
        init_div_is_residual=neox_divisor,
        **forwarded,
    )
def kaiming_uniform_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    init_gain: float = 0,
    fan_mode: str = 'fan_in',
    init_nonlinearity: str = 'leaky_relu',
    verbose: int = 0,
    **kwargs,
):
    """Initialize ``module`` using ``nn.init.kaiming_uniform_`` for weights.

    Args:
        module: Module to initialize in place.
        n_layers, d_model, init_div_is_residual, emb_init_std,
        emb_init_uniform_lim, verbose: Passed through to
            ``generic_param_init_fn_``.
        init_gain: Forwarded to ``kaiming_uniform_`` as its ``a`` argument.
        fan_mode: Forwarded to ``kaiming_uniform_`` as ``mode``.
        init_nonlinearity: Forwarded to ``kaiming_uniform_`` as
            ``nonlinearity``.
        **kwargs: Ignored; accepted so extra config entries can be passed in.
    """
    del kwargs  # unused, just to capture any extra args from the config

    if verbose > 1:
        warnings.warn(
            'Using nn.init.kaiming_uniform_ init fn with parameters: '
            f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}'
        )

    kaiming_options = {
        'a': init_gain,
        'mode': fan_mode,
        'nonlinearity': init_nonlinearity,
    }
    weight_init_ = partial(nn.init.kaiming_uniform_, **kaiming_options)

    forwarded = dict(
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )
    generic_param_init_fn_(module=module, init_fn_=weight_init_, **forwarded)
def kaiming_normal_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    init_gain: float = 0,
    fan_mode: str = 'fan_in',
    init_nonlinearity: str = 'leaky_relu',
    verbose: int = 0,
    **kwargs,
):
    """Initialize ``module`` using ``nn.init.kaiming_normal_`` for weights.

    Args:
        module: Module to initialize in place.
        n_layers, d_model, init_div_is_residual, emb_init_std,
        emb_init_uniform_lim, verbose: Passed through to
            ``generic_param_init_fn_``.
        init_gain: Forwarded to ``kaiming_normal_`` as its ``a`` argument.
        fan_mode: Forwarded to ``kaiming_normal_`` as ``mode``.
        init_nonlinearity: Forwarded to ``kaiming_normal_`` as
            ``nonlinearity``.
        **kwargs: Ignored; accepted so extra config entries can be passed in.
    """
    del kwargs  # unused, just to capture any extra args from the config

    if verbose > 1:
        warnings.warn(
            'Using nn.init.kaiming_normal_ init fn with parameters: '
            f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}'
        )

    kaiming_options = {
        'a': init_gain,
        'mode': fan_mode,
        'nonlinearity': init_nonlinearity,
    }
    weight_init_ = partial(torch.nn.init.kaiming_normal_, **kaiming_options)

    forwarded = dict(
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )
    generic_param_init_fn_(module=module, init_fn_=weight_init_, **forwarded)
def xavier_uniform_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    init_gain: float = 0,
    verbose: int = 0,
    **kwargs,
):
    """Initialize ``module`` using ``torch.nn.init.xavier_uniform_`` for weights.

    Args:
        module: Module to initialize in place.
        n_layers, d_model, init_div_is_residual, emb_init_std,
        emb_init_uniform_lim, verbose: Passed through to
            ``generic_param_init_fn_``.
        init_gain: Forwarded to ``xavier_uniform_`` as ``gain``.
            NOTE(review): the default of 0 presumably scales every weight to
            zero - confirm callers always override it.
        **kwargs: Ignored; accepted so extra config entries can be passed in.
    """
    del kwargs  # unused, just to capture any extra args from the config

    weight_init_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)

    if verbose > 1:
        warnings.warn(
            'Using torch.nn.init.xavier_uniform_ init fn with parameters: '
            f'gain={init_gain}'
        )

    forwarded = dict(
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )
    generic_param_init_fn_(module=module, init_fn_=weight_init_, **forwarded)
def xavier_normal_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    init_gain: float = 0,
    verbose: int = 0,
    **kwargs,
):
    """Initialize ``module`` using ``torch.nn.init.xavier_normal_`` for weights.

    Args:
        module: Module to initialize in place.
        n_layers, d_model, init_div_is_residual, emb_init_std,
        emb_init_uniform_lim, verbose: Passed through to
            ``generic_param_init_fn_``.
        init_gain: Forwarded to ``xavier_normal_`` as ``gain``.
            NOTE(review): the default of 0 presumably scales every weight to
            zero - confirm callers always override it.
        **kwargs: Ignored; accepted so extra config entries can be passed in.
    """
    # Discard extras explicitly, as every other init function in this file does.
    del kwargs  # unused, just to capture any extra args from the config

    xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)

    if verbose > 1:
        warnings.warn(
            'Using torch.nn.init.xavier_normal_ init fn with parameters: '
            f'gain={init_gain}'
        )

    generic_param_init_fn_(
        module=module,
        init_fn_=xavier_normal_,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )
# Maps the name of an initialization scheme to the function implementing it.
# Every function takes the module to initialize as its first argument and
# accepts extra keyword arguments.
# NOTE(review): keys are presumably the values used in the model config to
# select a scheme - confirm against the caller.
MODEL_INIT_REGISTRY = {
'default_': torch_default_param_init_fn_,
'baseline_': baseline_param_init_fn_,
'kaiming_uniform_': kaiming_uniform_param_init_fn_,
'kaiming_normal_': kaiming_normal_param_init_fn_,
'neox_init_': neox_param_init_fn_,
'small_init_': small_param_init_fn_,
'xavier_uniform_': xavier_uniform_param_init_fn_,
'xavier_normal_': xavier_normal_param_init_fn_,
}
|