Spaces:
Build error
Build error
import logging | |
import re | |
from typing import * | |
import torch | |
from allennlp.common.from_params import Params, T | |
from allennlp.training.optimizers import Optimizer | |
logger = logging.getLogger('optim') | |
# NOTE(review): the example configs in the docstring select this class with
# `type: "transformer"`, which presumably relies on an
# `Optimizer.register('transformer')` registration that is not visible in this
# block -- confirm it exists before relying on those configs.
class TransformerOptimizer:
    """
    Wrapper for AllenNLP optimizer.

    This is used to fine-tune the pretrained transformer with some layers
    fixed and different learning rates. When some layers are fixed, the
    wrapper sets their `requires_grad` flag to False and does not hand them
    to the base optimizer, which could save training time and memory.

    Please contact Guanghui Qin for bugs.

    Params:
        base: base optimizer.
        embeddings_lr: learning rate for embedding layer. Set as 0.0 to fix it.
        encoder_lr: learning rate for encoder layer. Set as 0.0 to fix it.
        pooler_lr: learning rate for pooler layer. Set as 0.0 to fix it.
        layer_fix: the number of encoder layers that should be fixed,
            counted from layer 0. Defaults to 0 (no layer is fixed).
        parameter_groups: optional extra parameter groups, appended after
            the groups generated from the learning rates above.

    Example json config:

    1. No-op. Do nothing (why do you use me?)

        optimizer: {
            type: "transformer",
            base: {
                type: "adam",
                lr: 0.001
            }
        }

    2. Fix everything in the transformer.

        optimizer: {
            type: "transformer",
            base: {
                type: "adam",
                lr: 0.001
            },
            embeddings_lr: 0.0,
            encoder_lr: 0.0,
            pooler_lr: 0.0
        }

       Or equivalently (suppose we have 24 layers)

        optimizer: {
            type: "transformer",
            base: {
                type: "adam",
                lr: 0.001
            },
            embeddings_lr: 0.0,
            layer_fix: 24,
            pooler_lr: 0.0
        }

    3. Fix embeddings and the lower 12 encoder layers, set a small learning
       rate for the other parts of the transformer

        optimizer: {
            type: "transformer",
            base: {
                type: "adam",
                lr: 0.001
            },
            embeddings_lr: 0.0,
            layer_fix: 12,
            encoder_lr: 1e-5,
            pooler_lr: 1e-5
        }
    """

    @classmethod
    def from_params(
            cls: Type[T],
            params: Params,
            model_parameters: List[Tuple[str, torch.nn.Parameter]],
            **_
    ):
        """
        Build the base optimizer with transformer-specific freezing and
        per-component learning rates applied.

        Args:
            params: optimizer configuration. The keys `layer_fix`,
                `embeddings_lr`, `encoder_lr`, `pooler_lr`,
                `parameter_groups` and `base` are consumed here.
            model_parameters: `(name, parameter)` pairs of the model.
                Parameters selected for fixing get `requires_grad_(False)`
                and are dropped from what the base optimizer receives.
            **_: any other keyword arguments are accepted and ignored.

        Returns:
            The object produced by calling the optimizer class registered
            under `base.type` with the remaining parameters, the collected
            parameter groups and the rest of the `base` config.
        """
        # Work on a private list so the caller's sequence is never rebound
        # and an arbitrary iterable can be traversed more than once.
        model_parameters = list(model_parameters)
        param_groups = []

        def remove_param(keyword_):
            """Freeze every parameter whose name contains `keyword_` and drop it."""
            nonlocal model_parameters
            logger.info('Fix param with name matching %s.', keyword_)
            trainable = []
            for name, param in model_parameters:
                if keyword_ in name:
                    logger.debug('Fix param %s.', name)
                    param.requires_grad_(False)
                else:
                    trainable.append((name, param))
            model_parameters = trainable

        # `layer_fix` is optional: the documented configs 1 and 2 omit it.
        num_fixed_layers = params.pop_int('layer_fix', 0) or 0
        for i_layer in range(num_fixed_layers):
            # The trailing dot keeps layer 1 from also matching layers 10-19.
            remove_param('transformer_model.encoder.layer.{}.'.format(i_layer))

        lr_overrides = (
            (params.pop_float('embeddings_lr', None), 'transformer_model.embeddings'),
            (params.pop_float('encoder_lr', None), 'transformer_model.encoder.layer'),
            (params.pop_float('pooler_lr', None), 'transformer_model.pooler'),
        )
        for specific_lr, keyword in lr_overrides:
            if specific_lr is None:
                # Not configured: the component follows the base optimizer.
                continue
            if specific_lr > 0.:
                pattern = '.*' + keyword.replace('.', r'\.') + '.*'
                # Only register the group if something still matches; layers
                # removed above are no longer in `model_parameters`.
                if any(re.match(pattern, name) for name, _param in model_parameters):
                    param_groups.append([[pattern], {'lr': specific_lr}])
                else:
                    logger.warning(
                        '%s is set to use lr %s but no param matches.', pattern, specific_lr
                    )
            else:
                # A non-positive learning rate means "fix this component".
                remove_param(keyword)

        if 'parameter_groups' in params:
            for pg in params.pop('parameter_groups'):
                param_groups.append([pg[0], pg[1].as_dict()])

        base_params = params.pop('base')
        optimizer_cls = Optimizer.by_name(base_params.pop('type'))
        return optimizer_cls(
            model_parameters=model_parameters,
            parameter_groups=param_groups,
            **base_params.as_flat_dict()
        )