File size: 8,264 Bytes
a93e458 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 |
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Megatron Module"""
import torch
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from megatron import get_args
from megatron.core import mpu, tensor_parallel
_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
_HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
_BF16_TYPES = (torch.BFloat16Tensor, torch.cuda.BFloat16Tensor)
def param_is_not_shared(param):
return not hasattr(param, 'shared') or not param.shared
class MegatronModule(torch.nn.Module):
"""Megatron specific extensions of torch Module with support
for pipelining."""
def __init__(self, share_word_embeddings=True):
super(MegatronModule, self).__init__()
self.share_word_embeddings = share_word_embeddings
def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
"""Use this function to override the state dict for
saving checkpoints."""
return self.state_dict(prefix=prefix, keep_vars=keep_vars)
def word_embeddings_weight(self):
if self.pre_process:
if self.language_model.tie_embed_logits:
return self.language_model.embedding.word_embeddings.weight
return self.language_model.lm_head
else:
if not self.language_model.tie_embed_logits:
return self.language_model.lm_head
if not self.share_word_embeddings:
raise Exception('word_embeddings_weight() called for last '
'stage, but share_word_embeddings is false')
return self.word_embeddings.weight
def initialize_word_embeddings(self, init_method_normal, args):
if not self.share_word_embeddings:
raise Exception('initialize_word_embeddings() was called but '
'share_word_embeddings is false')
# This function just initializes the word embeddings in the final stage
# when we are using pipeline parallelism. Nothing to do if we aren't
# using pipeline parallelism.
if args.pipeline_model_parallel_size == 1:
return
# Parameters are shared between the word embeddings layers, and the
# heads at the end of the model. In a pipelined setup with more than
# one stage, the initial embedding layer and the head are on different
# workers, so we do the following:
# 1. Create a second copy of word_embeddings on the last stage, with
# initial parameters of 0.0.
# 2. Do an all-reduce between the first and last stage to ensure that
# the two copies of word_embeddings start off with the same
# parameter values.
# 3. In the training loop, before an all-reduce between the grads of
# the two word_embeddings layers to ensure that every applied weight
# update is the same on both stages.
if mpu.is_pipeline_last_stage() and \
not self.pre_process:
assert not mpu.is_pipeline_first_stage()
self._word_embeddings_for_head_key = 'word_embeddings_for_head'
# set word_embeddings weights to 0 here, then copy first
# stage's weights using all_reduce below.
self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
args.padded_vocab_size, args.hidden_size,
init_method=init_method_normal(args.init_method_std),
params_dtype=args.params_dtype,
use_cpu_initialization=args.use_cpu_initialization,
perform_initialization=args.perform_initialization)
self.word_embeddings.weight.data.fill_(0)
self.word_embeddings.weight.shared = True
# Zero out initial weights for decoder embedding.
# NOTE: We don't currently support T5 with the interleaved schedule.
if not mpu.is_pipeline_first_stage(ignore_virtual=True) and \
self.pre_process:
self.language_model.embedding.zero_parameters()
if not torch.distributed.is_initialized():
if not getattr(MegatronModule, "embedding_warning_printed", False):
print("WARNING! Distributed processes aren't initialized, so "
"word embeddings in the last layer are not initialized. "
"If you are just manipulating a model this is fine, but "
"this needs to be handled manually. If you are training "
"something is definitely wrong.")
MegatronModule.embedding_warning_printed = True
return
# Ensure that first and last stages have the same initial parameter
# values.
if mpu.is_rank_in_embedding_group():
torch.distributed.all_reduce(self.word_embeddings_weight().data,
group=mpu.get_embedding_group())
# Ensure that encoder(first stage) and decoder(split stage) position
# embeddings have the same initial parameter values
# NOTE: We don't currently support T5 with the interleaved schedule.
if mpu.is_rank_in_position_embedding_group() and \
args.pipeline_model_parallel_split_rank is not None:
# TODO: Support tokentype embedding.
self.language_model.embedding.cuda()
position_embeddings = self.language_model.embedding.position_embeddings
torch.distributed.all_reduce(position_embeddings.weight.data,
group=mpu.get_position_embedding_group())
def conversion_helper(val, conversion):
"""Apply conversion to val. Recursively apply conversion if `val`
#is a nested tuple/list structure."""
if not isinstance(val, (tuple, list)):
return conversion(val)
rtn = [conversion_helper(v, conversion) for v in val]
if isinstance(val, tuple):
rtn = tuple(rtn)
return rtn
def fp32_to_float16(val, float16_convertor):
"""Convert fp32 `val` to fp16/bf16"""
def half_conversion(val):
val_typecheck = val
if isinstance(val_typecheck, (Parameter, Variable)):
val_typecheck = val.data
if isinstance(val_typecheck, _FLOAT_TYPES):
val = float16_convertor(val)
return val
return conversion_helper(val, half_conversion)
def float16_to_fp32(val):
"""Convert fp16/bf16 `val` to fp32"""
def float_conversion(val):
val_typecheck = val
if isinstance(val_typecheck, (Parameter, Variable)):
val_typecheck = val.data
if isinstance(val_typecheck, (_BF16_TYPES, _HALF_TYPES)):
val = val.float()
return val
return conversion_helper(val, float_conversion)
class Float16Module(MegatronModule):
def __init__(self, module, args):
super(Float16Module, self).__init__()
if args.fp16:
self.add_module('module', module.half())
def float16_convertor(val):
return val.half()
elif args.bf16:
self.add_module('module', module.bfloat16())
def float16_convertor(val):
return val.bfloat16()
else:
raise Exception('should not be here')
self.float16_convertor = float16_convertor
def set_input_tensor(self, input_tensor):
return self.module.set_input_tensor(input_tensor)
def forward(self, *inputs, **kwargs):
if mpu.is_pipeline_first_stage():
inputs = fp32_to_float16(inputs, self.float16_convertor)
outputs = self.module(*inputs, **kwargs)
if mpu.is_pipeline_last_stage():
outputs = float16_to_fp32(outputs)
return outputs
def state_dict(self, prefix='', keep_vars=False):
return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)
def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
return self.module.state_dict_for_save_checkpoint(prefix=prefix,
keep_vars=keep_vars)
def load_state_dict(self, state_dict, strict=True):
self.module.load_state_dict(state_dict, strict=strict)
|