|
from __future__ import annotations |
|
|
|
import os |
|
import pathlib |
|
import typing |
|
from pathlib import Path |
|
from typing import Optional |
|
|
|
import numpy as np |
|
import torch |
|
import torch.distributed as dist |
|
import torch.nn as nn |
|
|
|
# Maps the short precision names accepted by this module to torch dtypes.
str_type_map = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}
|
|
|
|
|
class BaseBelleWeights:
    """Flat list of BELLE decoder weights in the order expected by the FasterTransformer op.

    ``self.w`` holds, in order:

    * twelve per-layer groups of ``layer_num`` tensors each: input layernorm
      weight/bias, fused QKV weight/bias, attention output dense weight/bias,
      post-attention layernorm weight/bias, MLP h->4h weight/bias and MLP
      4h->h weight/bias;
    * the optional pre-decoder layernorm (weight, bias);
    * the optional post-decoder layernorm (weight, bias);
    * the optional positional-embedding table;
    * the input embedding table (index ``pre_embed_idx``);
    * the output projection (index ``post_embed_idx``);
    * ``layer_num`` MoE gate tensors;
    * when ``has_adapters`` is set, eight per-layer adapter groups.

    When ``int8_mode == 1``, ``load()`` additionally quantizes the four GEMM
    weights of every layer (plus the four adapter GEMM weights) into
    ``self.int8_w``. It stores the matching scales in ``self.scale``, with one
    entry per element of the weight's last axis.
    """

    def __init__(self, head_num, size_per_head, layer_num, vocab_size, max_seq_len, tensor_para_size, pipeline_para_size,
                 weights_data_type: typing.Union[str, np.dtype],
                 inference_data_type: str,
                 has_adapters: bool = False,
                 adapter_inter_size: int = 0,
                 gpt_with_moe: bool = False,
                 has_positional_encoding: bool = True,
                 has_pre_decoder_layernorm: bool = False,
                 has_post_decoder_layernorm: bool = True,
                 int8_mode: int = 0,
                 inter_size: int = 0):
        """Allocate randomly initialised placeholder tensors with the expected shapes.

        Args:
            head_num: total number of attention heads; must be divisible by
                ``tensor_para_size``.
            size_per_head: hidden size of a single head.
            layer_num: number of decoder layers.
            vocab_size: number of rows of the embedding and output tables.
            max_seq_len: rows of the positional table (only used when
                ``has_positional_encoding`` is set).
            tensor_para_size: tensor-parallel world size.
            pipeline_para_size: pipeline-parallel world size.
            weights_data_type: one of "fp16"/"fp32"/"float16"/"float32", or
                ``np.float16``/``np.float32``. It is validated and stored as
                ``self.weights_data_type``.
            inference_data_type: key of ``str_type_map``; the dtype of every
                placeholder and loaded tensor.
            has_adapters: allocate the adapter weight groups.
            adapter_inter_size: adapter intermediate size across all TP ranks.
            gpt_with_moe: stored only.
            has_positional_encoding: allocate a positional-embedding table.
            has_pre_decoder_layernorm: allocate a layernorm before the decoder.
            has_post_decoder_layernorm: allocate a layernorm after the decoder.
            int8_mode: 0 (off) or 1 (weight-only int8, fp16/bf16 inference only).
            inter_size: MLP intermediate size across all TP ranks. 0 means
                ``4 * local_hidden_units`` per rank.

        Raises:
            ValueError: if ``weights_data_type`` is an unrecognised string.
            AssertionError: on inconsistent sizes or an unsupported int8 mode.
        """
        assert head_num % tensor_para_size == 0, "head_num must be a multiple of tensor_para_size."

        if int8_mode == 1:
            torch_infer_dtype = str_type_map[inference_data_type]
            assert torch_infer_dtype in (torch.float16, torch.bfloat16), \
                "Weight only quant only supported for infer type fp16 or bf16."
            quant = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix
            self.weight_transpose_calibrate_quantize = lambda x: quant(x, torch.int8)
        else:
            assert int8_mode == 0, "Invalid int8 mode for BELLE. Must be 0 or 1"

        self.head_num = head_num
        self.size_per_head = size_per_head
        self.layer_num = layer_num
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.tensor_para_size = tensor_para_size
        self.pipeline_para_size = pipeline_para_size
        self.layers_per_device = layer_num // pipeline_para_size

        self.has_adapters = has_adapters
        self.adapter_inter_size = adapter_inter_size
        self.gpt_with_moe = gpt_with_moe
        self.has_positional_encoding = has_positional_encoding
        self.has_pre_decoder_layernorm = has_pre_decoder_layernorm
        self.has_post_decoder_layernorm = has_post_decoder_layernorm

        local_head_num = head_num // tensor_para_size
        global_head_num = head_num
        local_hidden_units = local_head_num * size_per_head
        global_hidden_units = global_head_num * size_per_head
        local_inter_size = local_hidden_units * 4
        if inter_size != 0:
            assert inter_size % tensor_para_size == 0, \
                f"inter_size({inter_size}) % tensor_para_size({tensor_para_size}) must be 0"
            local_inter_size = inter_size // tensor_para_size
        local_adapter_inter_size = self.adapter_inter_size // tensor_para_size

        self.local_head_num = local_head_num
        self.global_head_num = global_head_num
        self.local_hidden_units = local_hidden_units
        self.global_hidden_units = global_hidden_units
        self.local_inter_size = local_inter_size

        self.int8_mode = int8_mode
        self.share_embed = False

        if isinstance(weights_data_type, str):
            try:
                weights_data_type = {
                    "fp16": np.float16,
                    "fp32": np.float32,
                    "float16": np.float16,
                    "float32": np.float32,
                }[weights_data_type]
            except KeyError:
                raise ValueError(
                    f"Don't know how to interpret weights_data_type: {weights_data_type}")

        assert weights_data_type in [np.float32, np.float16]
        self.weights_data_type = weights_data_type
        self.inference_data_type = inference_data_type

        dtype = str_type_map[self.inference_data_type]

        def per_layer(*shape):
            # One tensor object repeated layer_num times. These are only
            # placeholders; load() replaces them entry by entry.
            return [torch.zeros(*shape, dtype=dtype)] * layer_num

        self.w = []
        self.int8_w = []
        self.scale = []

        # Per-layer decoder weights: 12 groups of layer_num tensors.
        self.w.extend(per_layer(global_hidden_units))                          # input layernorm weight
        self.w.extend(per_layer(global_hidden_units))                          # input layernorm bias
        self.w.extend(per_layer(global_hidden_units, local_hidden_units * 3))  # fused QKV weight
        self.w.extend(per_layer(local_hidden_units * 3))                       # fused QKV bias
        self.w.extend(per_layer(local_hidden_units, global_hidden_units))      # attention output weight
        self.w.extend(per_layer(global_hidden_units))                          # attention output bias
        self.w.extend(per_layer(global_hidden_units))                          # post-attention layernorm weight
        self.w.extend(per_layer(global_hidden_units))                          # post-attention layernorm bias
        self.w.extend(per_layer(global_hidden_units, local_inter_size))        # MLP h->4h weight
        self.w.extend(per_layer(local_inter_size))                             # MLP h->4h bias
        self.w.extend(per_layer(local_inter_size, global_hidden_units))        # MLP 4h->h weight
        self.w.extend(per_layer(global_hidden_units))                          # MLP 4h->h bias

        # Count of optional non-layer tensors, used to locate the adapters.
        optional_adapter_offset = 0
        if self.has_pre_decoder_layernorm:
            self.w.append(torch.zeros(global_hidden_units, dtype=dtype))
            self.w.append(torch.zeros(global_hidden_units, dtype=dtype))
            optional_adapter_offset += 2
        if self.has_post_decoder_layernorm:
            self.w.append(torch.zeros(global_hidden_units, dtype=dtype))
            self.w.append(torch.zeros(global_hidden_units, dtype=dtype))
            optional_adapter_offset += 2
        if self.has_positional_encoding:
            self.w.append(torch.zeros(max_seq_len, global_hidden_units, dtype=dtype))
            optional_adapter_offset += 1

        self.pre_embed_idx = len(self.w)
        self.w.append(torch.zeros(vocab_size, global_hidden_units, dtype=dtype))
        self.post_embed_idx = len(self.w)
        self.w.append(torch.zeros(vocab_size, global_hidden_units, dtype=dtype))
        self.adapter_offset = 2 + optional_adapter_offset

        # MoE gate placeholders, one empty tensor per layer.
        self.w.extend([torch.empty(0, dtype=dtype)] * layer_num)
        self.adapter_offset += layer_num

        if self.has_adapters:
            # After-attention adapter first, then after-FFN adapter.
            for _ in range(2):
                self.w.extend(per_layer(global_hidden_units, local_adapter_inter_size))
                self.w.extend(per_layer(local_adapter_inter_size))
                self.w.extend(per_layer(local_adapter_inter_size, global_hidden_units))
                self.w.extend(per_layer(global_hidden_units))

        # Fill the placeholders with random values.
        self._map(lambda w: torch.nn.init.normal_(w, mean=0., std=1.))

        if self.int8_mode != 0:
            def per_layer_int8(*shape):
                return [torch.zeros(*shape, dtype=torch.int8)] * layer_num

            def per_layer_scale(size):
                return [torch.zeros(size, dtype=torch.float)] * layer_num

            self.int8_w.extend(per_layer_int8(global_hidden_units, local_hidden_units * 3))
            self.scale.extend(per_layer_scale(local_hidden_units * 3))
            self.int8_w.extend(per_layer_int8(local_hidden_units, global_hidden_units))
            self.scale.extend(per_layer_scale(global_hidden_units))
            self.int8_w.extend(per_layer_int8(global_hidden_units, local_inter_size))
            self.scale.extend(per_layer_scale(local_inter_size))
            self.int8_w.extend(per_layer_int8(local_inter_size, global_hidden_units))
            self.scale.extend(per_layer_scale(global_hidden_units))
            if self.has_adapters:
                for _ in range(2):
                    self.int8_w.extend(per_layer_int8(global_hidden_units, local_adapter_inter_size))
                    self.scale.extend(per_layer_scale(local_adapter_inter_size))
                    self.int8_w.extend(per_layer_int8(local_adapter_inter_size, global_hidden_units))
                    self.scale.extend(per_layer_scale(global_hidden_units))

    def __getitem__(self, idx):
        """Return the weight tensor at position ``idx`` of ``self.w``."""
        return self.w[idx]

    def __setitem__(self, idx, val):
        """Replace the weight tensor at position ``idx`` of ``self.w``."""
        self.w[idx] = val

    def __len__(self):
        """Return the number of entries in ``self.w``."""
        return len(self.w)

    @staticmethod
    def _apply_in_place(tensors, func):
        """Replace every tensor in ``tensors`` (one level of nested lists allowed) with ``func(tensor)``."""
        for i, item in enumerate(tensors):
            if isinstance(item, list):
                for j, sub in enumerate(item):
                    item[j] = func(sub)
            else:
                tensors[i] = func(item)

    def _map(self, func):
        """Apply ``func`` to every tensor in ``self.w`` in place.

        When ``share_embed`` is set, the output projection slot is aliased to
        the already-mapped input embedding instead of being mapped itself.
        """
        # A real assert: the previous assert(cond, msg) tested a non-empty tuple
        # and could never fail.
        assert self.pre_embed_idx < self.post_embed_idx, \
            "Pre decoder embedding index should be lower than post decoder embedding index."
        for i, item in enumerate(self.w):
            if isinstance(item, list):
                for j, sub in enumerate(item):
                    item[j] = func(sub)
            elif self.share_embed and i == self.post_embed_idx:
                # Tie the output projection to the input embedding. The input
                # embedding has a lower index, so it has already been mapped.
                self.w[i] = self.w[self.pre_embed_idx]
            else:
                self.w[i] = func(item)

    def _map_int8(self, func):
        """Apply ``func`` in place to every quantized weight and every scale tensor."""
        self._apply_in_place(self.int8_w, func)
        self._apply_in_place(self.scale, func)

    def _map_int8_scales(self, func):
        """Apply ``func`` in place to every scale tensor."""
        self._apply_in_place(self.scale, func)

    def load(self, ckpt_path, tp_rank, pipeline_para_rank):
        """Replace the placeholder weights with tensors read from an HDF5 checkpoint.

        Each tensor is read from ``<key>["weights"]``. For per-layer tensors,
        ``<key>`` is ``model.layers.<i>.<name>.<weight|bias>[.<tp_rank>]``;
        for the others it is ``model.<name>[.<weight|bias>]``. Per-layer
        tensors of layers outside this pipeline stage become empty tensors.
        All loaded tensors are cast to ``inference_data_type``.

        Args:
            ckpt_path: path of the HDF5 checkpoint file.
            tp_rank: tensor-parallel rank, used as the key suffix of sharded tensors.
            pipeline_para_rank: pipeline-parallel rank; selects the layers to load.

        Returns:
            True once loading (and optional int8 quantization) is complete.

        Raises:
            FileNotFoundError: if ``ckpt_path`` does not exist.
            AssertionError: if the number of loaded tensors differs from the
                number of placeholders (for example with ``has_adapters``, since
                adapter weights are not loaded here), or if ``max_seq_len``
                exceeds the checkpoint's positional table.
            RuntimeError: if a loaded tensor cannot be reshaped to its placeholder shape.
        """
        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(f"Failed to find {ckpt_path}")

        import h5py

        infer_dtype = str_type_map[self.inference_data_type]
        first_layer = self.layers_per_device * pipeline_para_rank
        end_layer = self.layers_per_device * (pipeline_para_rank + 1)

        def is_load(i):
            # True when layer i belongs to this pipeline stage.
            return first_layer <= i < end_layer

        def load_to_torch(npdata, should_load: bool):
            # Convert a NumPy array to the inference dtype, or return an empty
            # placeholder when this rank does not need the tensor.
            if should_load:
                return torch.from_numpy(npdata).to(infer_dtype)
            return torch.empty(0).to(infer_dtype)

        def get_np_data(h5f, layername, layer_num, weight_type, tp_rank=None):
            # One tensor per layer. Datasets of layers owned by other pipeline
            # stages are not read at all.
            suffix = "" if tp_rank is None else f".{tp_rank}"
            return [
                load_to_torch(h5f[f"model.layers.{i}.{layername}.{weight_type}{suffix}"]["weights"][:], True)
                if is_load(i) else load_to_torch(None, False)
                for i in range(layer_num)
            ]

        def get_np_data_single(h5f, layername, weight_type, is_loaded, tp_rank=None):
            key = f"model.{layername}"
            if weight_type is not None:
                key += f".{weight_type}" if tp_rank is None else f".{weight_type}.{tp_rank}"
            return load_to_torch(h5f[key]["weights"][:], is_loaded)

        w = []
        # The arrays are copied into memory by [:], so the file can be closed afterwards.
        with h5py.File(ckpt_path, "r") as ckpt_f:
            # Per-layer weights, in the same order as the placeholders built in __init__.
            # Tensors loaded without tp_rank use the same key on every TP rank.
            w.extend(get_np_data(ckpt_f, "input_layernorm", self.layer_num, "weight"))
            w.extend(get_np_data(ckpt_f, "input_layernorm", self.layer_num, "bias"))
            w.extend(get_np_data(ckpt_f, "attention.query_key_value", self.layer_num, "weight", tp_rank))
            w.extend(get_np_data(ckpt_f, "attention.query_key_value", self.layer_num, "bias", tp_rank))
            w.extend(get_np_data(ckpt_f, "attention.dense", self.layer_num, "weight", tp_rank))
            w.extend(get_np_data(ckpt_f, "attention.dense", self.layer_num, "bias"))
            w.extend(get_np_data(ckpt_f, "post_attention_layernorm", self.layer_num, "weight"))
            w.extend(get_np_data(ckpt_f, "post_attention_layernorm", self.layer_num, "bias"))
            w.extend(get_np_data(ckpt_f, "mlp.dense_h_to_4h", self.layer_num, "weight", tp_rank))
            w.extend(get_np_data(ckpt_f, "mlp.dense_h_to_4h", self.layer_num, "bias", tp_rank))
            w.extend(get_np_data(ckpt_f, "mlp.dense_4h_to_h", self.layer_num, "weight", tp_rank))
            w.extend(get_np_data(ckpt_f, "mlp.dense_4h_to_h", self.layer_num, "bias"))

            if self.has_pre_decoder_layernorm:
                w.append(get_np_data_single(ckpt_f, "pre_decoder_layernorm", "weight", True))
                w.append(get_np_data_single(ckpt_f, "pre_decoder_layernorm", "bias", True))

            if self.has_post_decoder_layernorm:
                w.append(get_np_data_single(ckpt_f, "final_layernorm", "weight", True))
                w.append(get_np_data_single(ckpt_f, "final_layernorm", "bias", True))

            if self.has_positional_encoding:
                # get_np_data_single already returns a tensor; reshape it directly.
                wpe = get_np_data_single(ckpt_f, "wpe", weight_type=None, is_loaded=True).reshape(
                    -1, self.global_hidden_units)
                assert self.max_seq_len <= wpe.size(0), (
                    f"max_seq_len ({self.max_seq_len}) must not exceed "
                    f"the value of maximum sequence length during training ({wpe.size(0)})."
                )
                w.append(wpe)

            w.append(get_np_data_single(ckpt_f, "wte", weight_type=None, is_loaded=True))

            # Without a separate lm_head, the output projection is tied to wte (see _map).
            if "model.lm_head.weight" in ckpt_f.keys():
                self.share_embed = False
                w.append(get_np_data_single(ckpt_f, "lm_head", "weight", True))
            else:
                self.share_embed = True
                w.append(load_to_torch(None, False))

            # One MoE gate per layer; layers without a gate get an empty tensor.
            for i in range(self.layer_num):
                gate_key = f"model.layers.{i}.mlp.moe.gate.wg.weight"
                if gate_key in ckpt_f.keys():
                    w.append(load_to_torch(ckpt_f[gate_key]["weights"][:], True))
                else:
                    w.append(load_to_torch(None, False))

        assert len(self.w) == len(w), (
            f"Loaded {len(w)} tensors from {ckpt_path} but {len(self.w)} were expected"
            + (" (adapter weights are not loaded by this method)" if self.has_adapters else "")
        )

        for i, loaded in enumerate(w):
            placeholder = self.w[i]
            if loaded.nelement() != placeholder.nelement():
                # Size differs from the placeholder (e.g. empty tensors for other
                # pipeline stages or a tied lm_head): store the tensor as loaded.
                self.w[i] = loaded
                continue
            try:
                self.w[i] = loaded.reshape(placeholder.shape)
            except RuntimeError as err:
                raise RuntimeError(
                    f"head_num, size_per_head, vocab_size, and max_seq_len must be the same as the ones during training "
                    f"(idx: {i} expected shape: {placeholder.shape} got shape: {loaded.shape})."
                ) from err

        layer_num = self.layer_num
        if self.int8_mode != 0:
            # Group index (in units of layer_num) of each quantized GEMM weight in
            # self.w: fused QKV, attention output, MLP h->4h, MLP 4h->h.
            sources = [(slot, group * layer_num) for slot, group in enumerate((2, 4, 8, 10))]
            if self.has_adapters:
                # The four adapter GEMM weights, located after adapter_offset extra tensors.
                sources += [(4 + slot, group * layer_num + self.adapter_offset)
                            for slot, group in enumerate((12, 14, 16, 18))]
            for i in range(layer_num):
                for slot, base in sources:
                    src = base + i
                    dst = slot * layer_num + i
                    self.int8_w[dst], self.scale[dst] = self.weight_transpose_calibrate_quantize(self.w[src])
                    if self.int8_mode == 1:
                        # Weight-only quantization: the floating-point copy is no longer needed.
                        self.w[src] = torch.empty(0).to(infer_dtype)
        return True
|
|
|
|
|
class BaseBelleModel(nn.Module):
    """Wrapper around the ``torch.classes.FasterTransformer.GptOp`` custom op for BELLE.

    Holds a ``BaseBelleWeights`` container, joins the MPI process group (or
    reuses an existing one), selects this rank's CUDA device and builds the op
    lazily in ``cuda()``.
    """

    def __init__(self,
                 head_num, size_per_head,
                 vocab_size, start_id, end_id, layer_num,
                 max_seq_len: int,
                 tensor_para_size: int,
                 pipeline_para_size: int,
                 lib_path: typing.Union[str, pathlib.Path],
                 inference_data_type: str,
                 inter_size: int = 0,
                 layernorm_eps: float = 1e-6,
                 layernorm_type: typing.Literal['pre_layernorm',
                                                'post_layernorm'] = "pre_layernorm",
                 activation_type: str = "Gelu",
                 gpt_with_moe: bool = False,
                 expert_num: int = 0,
                 moe_k: int = 0,
                 moe_layer_index: typing.Optional[typing.List[int]] = None,
                 has_positional_encoding: bool = True,
                 has_pre_decoder_layernorm: bool = False,
                 has_post_decoder_layernorm: bool = True,
                 has_adapters: bool = False,
                 adapter_inter_size: int = 0,
                 use_attention_linear_bias: bool = False,
                 int8_mode: int = 0,
                 weights_data_type: typing.Union[str, np.dtype] = np.float32,
                 shared_contexts_ratio: float = 1.0):
        """Create the weight container and set up distributed state.

        Args:
            inter_size: MLP intermediate size; 0 means ``4 * head_num * size_per_head``.
            moe_layer_index: indices of MoE layers passed to the op. ``None``
                (the default) means no MoE layers.
            lib_path: path of the FasterTransformer torch extension library.
            The remaining arguments are stored and forwarded to the weight
            container and/or the FasterTransformer op.

        Raises:
            AssertionError: if CUDA is unavailable or the parallel sizes do not
                divide ``head_num``/``layer_num`` or match the world size.
        """
        super().__init__()
        self.head_num = head_num
        self.size_per_head = size_per_head
        self.vocab_size = vocab_size
        self.start_id = start_id
        self.end_id = end_id
        self.layer_num = layer_num
        self.inter_size = inter_size if inter_size != 0 else 4 * self.head_num * self.size_per_head

        self.layernorm_eps = layernorm_eps
        self.layernorm_type = layernorm_type
        self.activation_type = activation_type
        self.gpt_with_moe = gpt_with_moe
        self.expert_num = expert_num
        self.moe_k = moe_k
        # None instead of a shared mutable [] default.
        self.moe_layer_index = [] if moe_layer_index is None else list(moe_layer_index)
        self.has_positional_encoding = has_positional_encoding
        self.has_pre_decoder_layernorm = has_pre_decoder_layernorm
        self.has_post_decoder_layernorm = has_post_decoder_layernorm
        self.has_adapters = has_adapters
        self.adapter_inter_size = adapter_inter_size
        self.use_attention_linear_bias = use_attention_linear_bias

        self.tensor_para_size = tensor_para_size
        self.pipeline_para_size = pipeline_para_size
        self.use_sparse_gemm = False
        self.build_model = False
        self.int8_mode = int8_mode
        self.weights_data_type = weights_data_type
        self.shared_contexts_ratio = shared_contexts_ratio

        assert torch.cuda.is_available(), "CUDA is required for this model."

        assert head_num % tensor_para_size == 0, "head_num must be a multiple of tensor_para_size."
        assert layer_num % pipeline_para_size == 0, "layer_num must be a multiple of pipeline_para_size."

        # Registers the FasterTransformer custom classes under torch.classes.
        torch.classes.load_library(os.path.abspath(lib_path))

        self.weights = BaseBelleWeights(head_num, size_per_head, layer_num, vocab_size,
                                        max_seq_len, tensor_para_size, pipeline_para_size,
                                        weights_data_type=weights_data_type,
                                        inference_data_type=inference_data_type,
                                        gpt_with_moe=self.gpt_with_moe,
                                        has_positional_encoding=self.has_positional_encoding,
                                        has_pre_decoder_layernorm=self.has_pre_decoder_layernorm,
                                        has_post_decoder_layernorm=self.has_post_decoder_layernorm,
                                        has_adapters=self.has_adapters,
                                        adapter_inter_size=self.adapter_inter_size,
                                        int8_mode=int8_mode,
                                        inter_size=inter_size)

        # Reuse an existing process group. Otherwise, create the MPI one and let
        # genuine initialisation failures propagate instead of swallowing them.
        if dist.is_initialized():
            print("[INFO] WARNING: Have initialized the process group")
        else:
            dist.init_process_group(backend='mpi')
        self.rank = dist.get_rank()
        self.device_count = torch.cuda.device_count()
        self.device = self.rank % self.device_count
        torch.cuda.set_device(self.device)

        world_size = dist.get_world_size()
        assert world_size == tensor_para_size * pipeline_para_size, \
            "tensor_para_size * pipeline_para_size must be equal to world_size."

        # Ranks are laid out tensor-parallel-major: consecutive ranks share a pipeline stage.
        self.tensor_para_rank = self.rank % self.tensor_para_size
        self.pipeline_para_rank = self.rank // self.tensor_para_size

    def load(self, ckpt_path):
        """Load this rank's weights from ``ckpt_path``, move them to the GPU and build the op.

        Returns:
            The value returned by ``BaseBelleWeights.load``.
        """
        is_load = self.weights.load(ckpt_path, tp_rank=self.tensor_para_rank,
                                    pipeline_para_rank=self.pipeline_para_rank)
        self.cuda()
        torch.cuda.empty_cache()
        return is_load

    def sparse(self):
        """Enable sparse GEMM for the next op build."""
        if not self.use_sparse_gemm:
            self.use_sparse_gemm = True

    def cuda(self):
        """Move all weights to ``self.device`` and (re)build the ``GptOp``.

        This overrides ``nn.Module.cuda`` without arguments; the device is the
        one chosen in ``__init__``.
        """
        self.weights._map(lambda w: w.cuda(self.device))
        if self.int8_mode != 0:
            self.weights._map_int8(lambda w: w.cuda(self.device))

        if self.build_model:
            del self.model
            self.build_model = False

        self.model = torch.classes.FasterTransformer.GptOp(
            self.head_num, self.size_per_head, self.inter_size,
            self.layer_num,
            self.expert_num,
            self.moe_k,
            self.moe_layer_index,
            self.vocab_size, self.start_id, self.end_id,
            self.use_sparse_gemm,
            self.layernorm_eps,
            self.layernorm_type,
            self.activation_type,
            self.has_positional_encoding,
            self.has_pre_decoder_layernorm,
            self.has_post_decoder_layernorm,
            self.has_adapters,
            self.adapter_inter_size,
            self.use_attention_linear_bias,
            self.weights.w)
        self.build_model = True

    def forward(self,
                start_ids: torch.IntTensor,
                start_lengths: torch.IntTensor,
                output_len: int,
                beam_width: int = 1,
                top_k: typing.Optional[torch.IntTensor] = None,
                top_p: typing.Optional[torch.FloatTensor] = None,
                beam_search_diversity_rate: typing.Optional[torch.FloatTensor] = None,
                temperature: typing.Optional[torch.FloatTensor] = None,
                len_penalty: typing.Optional[torch.FloatTensor] = None,
                repetition_penalty: typing.Optional[torch.FloatTensor] = None,
                presence_penalty: typing.Optional[torch.FloatTensor] = None,
                min_length: typing.Optional[torch.IntTensor] = None,
                random_seed: typing.Optional[torch.LongTensor] = None,
                bad_words_list: typing.Optional[torch.IntTensor] = None,
                return_output_length: bool = False,
                return_cum_log_probs: int = 0):
        """Run generation through the FasterTransformer op.

        ``start_ids`` must have at least one column (``size(1) > 0``). The
        sampling arguments are passed to the op unchanged.

        Returns:
            ``output_ids`` alone, unless ``return_output_length`` is set. In that
            case it returns ``(output_ids, output_lengths)``, with the cumulative
            log probabilities appended when ``return_cum_log_probs > 0``.
        """
        if not self.build_model:
            self.cuda()
            torch.cuda.empty_cache()
        input_len = start_ids.size(1)
        assert input_len > 0, "input len must be larger than zero. For an unconditional case, use start_id as the first token."

        start_ids = start_ids.cuda(self.device)
        start_lengths = start_lengths.cuda(self.device)

        outputs = self.model.forward(start_ids,
                                     start_lengths,
                                     output_len,
                                     beam_width,
                                     top_k,
                                     top_p,
                                     beam_search_diversity_rate,
                                     temperature,
                                     len_penalty,
                                     repetition_penalty,
                                     presence_penalty,
                                     min_length,
                                     random_seed,
                                     bad_words_list,
                                     return_cum_log_probs)
        if return_cum_log_probs == 0:
            output_ids, output_lengths = outputs
            output_cum_log_probs = None
        else:
            output_ids, output_lengths, output_cum_log_probs = outputs

        if not return_output_length:
            return output_ids
        if return_cum_log_probs > 0:
            return output_ids, output_lengths, output_cum_log_probs
        return output_ids, output_lengths

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func"""
        self.input_tensor = input_tensor
|
|
|
|
|
class BaseParallelBelleModel(BaseBelleModel):
    """Variant of ``BaseBelleModel`` backed by ``torch.classes.FasterTransformer.ParallelGptOp``."""

    def cuda(self):
        """Move all weights to ``self.device`` and (re)build the parallel GPT op.

        Unlike the base class, the op also receives the parallel sizes, the
        int8 mode, the quantized weights and scales, and ``shared_contexts_ratio``.
        """
        def to_device(tensor):
            return tensor.cuda(self.device)

        self.weights._map(to_device)
        if self.int8_mode != 0:
            self.weights._map_int8(to_device)

        # Release a previously built op before constructing a replacement.
        if self.build_model:
            del self.model
            self.build_model = False

        op_args = (
            self.head_num,
            self.size_per_head,
            self.inter_size,
            self.layer_num,
            self.expert_num,
            self.moe_k,
            self.moe_layer_index,
            self.vocab_size,
            self.start_id,
            self.end_id,
            self.tensor_para_size,
            self.pipeline_para_size,
            self.int8_mode,
            self.layernorm_eps,
            self.layernorm_type,
            self.activation_type,
            self.has_positional_encoding,
            self.has_pre_decoder_layernorm,
            self.has_post_decoder_layernorm,
            self.has_adapters,
            self.adapter_inter_size,
            self.use_attention_linear_bias,
            self.weights.w,
            self.weights.int8_w,
            self.weights.scale,
            self.shared_contexts_ratio,
        )
        self.model = torch.classes.FasterTransformer.ParallelGptOp(*op_args)
        self.build_model = True
|
|
|
|
|
class BelleWeight(BaseBelleWeights):
    """BELLE weight layout: no adapters, no positional table, and both decoder layernorms."""

    def __init__(self, head_num, size_per_head, layer_num, vocab_size,
                 tensor_para_size, pipeline_para_size, weights_data_type, inference_data_type,
                 int8_mode=0):
        # max_seq_len only sizes the positional-embedding table, which this
        # layout disables, so 0 is passed.
        super().__init__(
            head_num=head_num,
            size_per_head=size_per_head,
            layer_num=layer_num,
            vocab_size=vocab_size,
            max_seq_len=0,
            tensor_para_size=tensor_para_size,
            pipeline_para_size=pipeline_para_size,
            weights_data_type=weights_data_type,
            inference_data_type=inference_data_type,
            has_adapters=False,
            adapter_inter_size=0,
            has_positional_encoding=False,
            has_pre_decoder_layernorm=True,
            has_post_decoder_layernorm=True,
            int8_mode=int8_mode,
        )
|
|
|
|
|
class BelleModel(BaseParallelBelleModel):
    """BELLE model on the parallel GPT op, with the architecture options fixed.

    The options are: pre-layernorm, Gelu activation, no positional table, both
    decoder layernorms, no adapters and attention linear bias enabled.
    """

    def __init__(self,
                 head_num, size_per_head,
                 vocab_size, start_id, end_id, layer_num,
                 tensor_para_size: int,
                 pipeline_para_size: int,
                 lib_path: str | Path,
                 inference_data_type: str,
                 weights_data_type: str | np.dtype = np.float32,
                 layernorm_eps: float = 1e-5,
                 shared_contexts_ratio: float = 1.0,
                 int8_mode: int = 0):
        # Architecture switches that are not configurable for BELLE.
        architecture = dict(
            layernorm_type="pre_layernorm",
            activation_type="Gelu",
            has_positional_encoding=False,
            has_pre_decoder_layernorm=True,
            has_post_decoder_layernorm=True,
            has_adapters=False,
            adapter_inter_size=0,
            use_attention_linear_bias=True,
        )
        # max_seq_len is 0 because there is no positional-embedding table.
        super().__init__(
            head_num, size_per_head, vocab_size, start_id, end_id, layer_num,
            max_seq_len=0,
            tensor_para_size=tensor_para_size,
            pipeline_para_size=pipeline_para_size,
            lib_path=lib_path,
            inference_data_type=inference_data_type,
            layernorm_eps=layernorm_eps,
            int8_mode=int8_mode,
            weights_data_type=weights_data_type,
            shared_contexts_ratio=shared_contexts_ratio,
            **architecture,
        )

    def set_input_tensor(self, input_tensor: Optional[torch.Tensor]):
        """Store ``input_tensor`` so it is used in place of forward()'s input.

        With pipeline parallelism, a stage's input arrives from the previous
        stage via communication rather than as a forward() argument. Internal
        code uses this hook to supply it.
        """
        super().set_input_tensor(input_tensor)
|
|