File size: 6,735 Bytes
634c636 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
import argparse
from itertools import chain
import torch
import torch.nn as nn
from transformers import LlamaConfig, DynamicCache
from midi_model import MIDIModel, config_name_list, MIDIModelConfig
class MIDIModelBase(nn.Module):
    """Export wrapper around the outer transformer (``model.net``) of a MIDIModel.

    The KV cache crosses the module boundary in the legacy per-layer
    ``(key, value)`` layout so that ``torch.onnx.export`` can flatten it into
    individually named graph inputs and outputs.
    """

    def __init__(self, model):
        super().__init__()
        # Reuse (not copy) the wrapped model's outer transformer.
        self.net = model.net

    def forward(self, x, past_kv):
        """Embed ``x``, collapse dim -2 by summation and run ``self.net``.

        Args:
            x: integer token ids. The export script traces this with axes
                (batch, mid_seq, token_seq); other layouts are untested here.
            past_kv: legacy KV cache, a sequence of per-layer (key, value) pairs.

        Returns:
            Tuple of (last hidden state, updated cache in legacy layout).
        """
        kv_cache = DynamicCache.from_legacy_cache(past_kv)
        # Summing over dim -2 merges the embeddings of all tokens of one step.
        summed_embeds = self.net.embed_tokens(x).sum(dim=-2)
        net_out = self.net.forward(
            inputs_embeds=summed_embeds,
            past_key_values=kv_cache,
            use_cache=True,
        )
        return net_out.last_hidden_state, kv_cache.to_legacy_cache()
class MIDIModelToken(nn.Module):
    """Export wrapper around the inner transformer (``model.net_token``) and LM head.

    Like the outer wrapper, it exchanges the KV cache in the legacy per-layer
    ``(key, value)`` layout so ONNX export can name each tensor.
    """

    def __init__(self, model):
        super().__init__()
        # Registration order (net_token, then lm_head) kept as in the original.
        self.net_token = model.net_token
        self.lm_head = model.lm_head

    def forward(self, hidden_state, x, past_kv):
        """Prepend ``hidden_state`` to the embedded ``x`` and produce logits.

        Args:
            hidden_state: float tensor concatenated before the token embeddings
                along dim 1 (traced as (batch, states, n_embd)).
            x: integer token ids, traced as (batch, token_seq).
            past_kv: legacy KV cache, a sequence of per-layer (key, value) pairs.

        Returns:
            Tuple of (``lm_head`` output, updated cache in legacy layout).
        """
        kv_cache = DynamicCache.from_legacy_cache(past_kv)
        token_embeds = self.net_token.embed_tokens(x)
        # Conditioning state(s) first, then token embeddings, along the sequence axis.
        sequence = torch.cat([hidden_state, token_embeds], dim=1)
        last_hidden = self.net_token.forward(
            inputs_embeds=sequence,
            past_key_values=kv_cache,
            use_cache=True,
        ).last_hidden_state
        logits = self.lm_head(last_hidden)
        return logits, kv_cache.to_legacy_cache()
def export_onnx(model, model_inputs, input_names, output_names, dynamic_axes, meta_data, path,
                *, opset_version=14, verbose=True):
    """Export ``model`` to ONNX at ``path``, simplify it and attach metadata.

    Args:
        model: the ``nn.Module`` to trace.
        model_inputs: example inputs (a tuple for multiple inputs) used for tracing.
        input_names: graph input names, in the flattened order of ``model_inputs``.
        output_names: graph output names, in the flattened order of the outputs.
        dynamic_axes: mapping of input/output name -> {axis index: axis name}.
        meta_data: key/value pairs stored as ONNX ``metadata_props``; values
            are converted with ``str``.
        path: destination file; it is written twice (raw export, then the
            simplified model overwrites it).
        opset_version: ONNX opset to target (default 14, as before).
        verbose: forwarded to ``torch.onnx.export`` (default True, as before).

    Raises:
        RuntimeError: if onnxsim cannot validate the simplified model.
    """
    # Imported lazily so the rest of the module works without onnx installed.
    import onnx
    from onnxsim import simplify

    torch.onnx.export(
        model,
        model_inputs,
        path,
        export_params=True,            # store trained weights inside the file
        opset_version=opset_version,
        do_constant_folding=True,      # fold constant subgraphs at export time
        input_names=input_names,
        output_names=output_names,
        verbose=verbose,
        dynamic_axes=dynamic_axes,
    )
    onnx_model = onnx.load(path)
    model_simp, check = simplify(onnx_model)
    # Explicit raise instead of ``assert``: asserts are stripped under
    # ``python -O`` and an invalid graph would then be saved silently.
    if not check:
        raise RuntimeError("Simplified ONNX model could not be validated")
    for key, value in meta_data.items():
        entry = model_simp.metadata_props.add()
        entry.key, entry.value = key, str(value)
    onnx.save(model_simp, path)
    print('finished exporting onnx')
def get_past_kv(config: "LlamaConfig", batch_size=1, past_seq_len=16, torch_dtype=torch.float32, device="cpu"):
    """Build a random dummy KV cache plus the matching ONNX input/output names.

    Args:
        config: transformer config. Reads ``num_hidden_layers``,
            ``num_attention_heads`` and ``hidden_size``, and, when present
            and not None, ``num_key_value_heads`` and ``head_dim``.
        batch_size: size of dim 0 of every cache tensor.
        past_seq_len: size of dim 2 (cached sequence length) of every tensor.
        torch_dtype: dtype of the dummy tensors.
        device: device of the dummy tensors.

    Returns:
        ``(past_kv, input_names, output_names)``:
        ``past_kv`` is a list with one ``(key, value)`` tuple per layer, each
        tensor shaped ``(batch_size, num_kv_heads, past_seq_len, head_dim)``.
        ``input_names`` is ``past_key_values.{i}.key``/``.value`` and
        ``output_names`` is ``present.{i}.key``/``.value``, in layer order.
    """
    num_heads = config.num_attention_heads
    # With grouped-query attention the cache holds fewer K/V heads than query
    # heads; fall back to the query-head count when the config leaves it unset.
    num_kv_heads = getattr(config, "num_key_value_heads", None) or num_heads
    # Newer configs may set head_dim explicitly; otherwise derive it.
    head_dim = getattr(config, "head_dim", None) or config.hidden_size // num_heads
    num_layers = config.num_hidden_layers
    shape = (batch_size, num_kv_heads, past_seq_len, head_dim)

    past_kv = [
        (
            torch.rand(shape, dtype=torch_dtype, device=device),
            torch.rand(shape, dtype=torch_dtype, device=device),
        )
        for _ in range(num_layers)
    ]
    input_names = [
        f"past_key_values.{i}.{part}" for i in range(num_layers) for part in ("key", "value")
    ]
    output_names = [
        f"present.{i}.{part}" for i in range(num_layers) for part in ("key", "value")
    ]
    return past_kv, input_names, output_names
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ckpt", type=str, default="model.ckpt", help="load ckpt"
    )
    parser.add_argument(
        "--config", type=str, default="tv2o-medium", choices=config_name_list, help="model config"
    )
    parser.add_argument(
        "--lora", type=str, default="", help="load lora"
    )
    parser.add_argument(
        "--model-base-out", type=str, default="model_base.onnx", help="model base output path"
    )
    parser.add_argument(
        "--model-token-out", type=str, default="model_token.onnx", help="model token output path"
    )
    opt = parser.parse_args()

    config = MIDIModelConfig.from_name(opt.config)
    tokenizer = config.tokenizer
    model = MIDIModel(config).to(device="cpu")
    # NOTE(review): torch.load unpickles the file and can execute arbitrary
    # code; only export checkpoints from trusted sources.
    ckpt = torch.load(opt.ckpt, map_location="cpu")
    # Training checkpoints nest weights under "state_dict"; plain ones do not.
    state_dict = ckpt.get("state_dict", ckpt)
    # strict=False tolerates key mismatches. Report them instead of dropping
    # them silently, so a wrong --config/--ckpt pairing is noticed early.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys:
        print(f"warning: {len(load_result.missing_keys)} model keys missing from checkpoint, "
              f"e.g. {load_result.missing_keys[:5]}")
    if load_result.unexpected_keys:
        print(f"warning: {len(load_result.unexpected_keys)} unexpected checkpoint keys ignored, "
              f"e.g. {load_result.unexpected_keys[:5]}")
    if opt.lora != "":
        model.load_merge_lora(opt.lora)
    model.eval()
    model_base = MIDIModelBase(model).eval()
    model_token = MIDIModelToken(model).eval()

    meta_data = {"config_name": opt.config, "config": config}
    # Dim 2 of every cache tensor is the (variable) cached sequence length.
    past_kv_shape = {0: "batch", 2: "past_seq"}
    present_kv_shape = {0: "batch", 2: "present_seq"}
    with torch.no_grad():
        # --- outer (per-event) model ---
        dynamic_axes = {
            "x": {0: "batch", 1: "mid_seq", 2: "token_seq"},
            "hidden": {0: "batch", 1: "mid_seq"}
        }
        x = torch.randint(tokenizer.vocab_size, (1, 16, tokenizer.max_token_seq), dtype=torch.int64, device="cpu")
        past_kv, input_names, output_names = get_past_kv(config.net_config, past_seq_len=16,
                                                         torch_dtype=torch.float32)
        for name in input_names:
            dynamic_axes[name] = past_kv_shape
        for name in output_names:
            dynamic_axes[name] = present_kv_shape
        input_names = ["x"] + input_names
        output_names = ["hidden"] + output_names
        export_onnx(model_base, (x, past_kv),
                    input_names, output_names, dynamic_axes, meta_data, opt.model_base_out)

        # --- inner (per-token) model ---
        dynamic_axes = {
            "x": {0: "batch", 1: "token_seq"},
            "hidden": {0: "batch", 1: "states"},
            "y": {0: "batch", 1: "token_seq1"}
        }
        hidden = torch.randn(1, 1, config.n_embd, device="cpu")
        x = torch.randint(tokenizer.vocab_size, (1, tokenizer.max_token_seq // 2), dtype=torch.int64, device="cpu")
        past_kv, input_names, output_names = get_past_kv(config.net_token_config,
                                                         past_seq_len=(tokenizer.max_token_seq // 2),
                                                         torch_dtype=torch.float32)
        for name in input_names:
            dynamic_axes[name] = past_kv_shape
        for name in output_names:
            dynamic_axes[name] = present_kv_shape
        input_names = ["hidden", "x"] + input_names
        output_names = ["y"] + output_names
        export_onnx(model_token, (hidden, x, past_kv),
                    input_names, output_names, dynamic_axes, meta_data, opt.model_token_out)
|