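# Export a Qwen-style causal LM (loaded via transformers AutoModelForCausalLM)
# to per-component ONNX graphs: token embedding, one prefill block and one
# KV-cache decode block per layer, the lm_head, a greedy head, and a
# penalty/top-k/top-p sampling head. GUESS_LENGTH is the number of tokens
# traced per decode step. No dynamic_axes are passed to torch.onnx.export,
# so SEQ_LENGTH and GUESS_LENGTH are baked into the exported graphs.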
import os
import torch
import argparse
from tqdm import tqdm
from transformers import AutoModelForCausalLM

torch.set_grad_enabled(False)

parser = argparse.ArgumentParser(description='export onnx')
parser.add_argument('-m', '--model_path', type=str, help='path to the torch model')
parser.add_argument('-s', '--seq_length', type=int, default=512, help="sequence length")
parser.add_argument('-d', '--device', type=str, choices=["cpu", "cuda"], default="cpu")
parser.add_argument('--save_dir', type=str, default="./tmp/onnx")
parser.add_argument('--guess_length', type=int, default=5)
args = parser.parse_args()

model_path = args.model_path
folder = args.save_dir

device = torch.device(args.device)
origin_model = AutoModelForCausalLM.from_pretrained(
    model_path, trust_remote_code=True,
    torch_dtype=torch.float).eval()

for param in origin_model.parameters():
    param.requires_grad = False

config = origin_model.config
transformer = origin_model.model
layers = transformer.layers

SEQ_LENGTH = args.seq_length
NUM_LAYERS = config.num_hidden_layers
HIDDEN_SIZE = config.hidden_size
NUM_ATTENTION_HEADS = config.num_attention_heads
HEAD_DIM = HIDDEN_SIZE // NUM_ATTENTION_HEADS
VOCAB_SIZE = config.vocab_size
GUESS_LENGTH = args.guess_length

print(f'Layers: {NUM_LAYERS}\nHidden size: {HIDDEN_SIZE}\n')

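# The wrappers below expose each piece of the loaded model as a standalone
# torch.nn.Module so it can be traced and exported to ONNX on its own.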
class Embedding(torch.nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, input_ids):
        out = transformer.embed_tokens(input_ids)
        return out.float()

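# Prefill form of one decoder layer: attends over a full SEQ_LENGTH window
# and returns the layer output plus the K/V tensors for the cache.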
class QwenBlock(torch.nn.Module):

    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.layer = layers[layer_id]

    def forward(self, hidden_states, position_ids, attention_mask):
        hidden_states, past_kv = self.layer(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            use_cache=True)
        present_k, present_v = past_kv
        return hidden_states.float(), present_k.float(), present_v.float()

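# Decode form of the same layer: takes GUESS_LENGTH new tokens plus the
# cached K/V (past_k / past_v) and returns the hidden states and the
# present K/V tensors.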
class QwenBlockCache(torch.nn.Module):

    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.layer = layers[layer_id]

    def forward(self, hidden_states, position_ids, attention_mask, past_k,
                past_v):
        hidden_states, past_kv = self.layer(
            hidden_states,
            past_key_value=(past_k, past_v),
            position_ids=position_ids,
            attention_mask=attention_mask,
            use_cache=True)
        present_k, present_v = past_kv
        return hidden_states.float(), present_k.float(), present_v.float()

class LmHead(torch.nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, hidden_states):
        hidden_states = transformer.norm(hidden_states)
        m_logits = origin_model.lm_head(hidden_states)
        return m_logits

class GreedyHead(torch.nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, m_logits):
        _, token = torch.topk(m_logits.float(), 1, dim=1)
        return token

class PenaltySampleHead(torch.nn.Module):

    def __init__(self, top_k=50, min_tokens_to_keep=5):
        super().__init__()
        self.top_k = top_k
        self.min_tokens_to_keep = min_tokens_to_keep
        self.keep_matrix = torch.zeros((1, self.top_k), dtype=torch.bool)
        self.keep_matrix[0, :self.min_tokens_to_keep] = True

    def forward(self, m_logits, input_ids, top_p, temperature, penalty):
        # repetition penalty on tokens already present in input_ids
        logits = torch.gather(m_logits, 1, input_ids)
        logits = torch.where(logits < 0, logits * penalty, logits / penalty)
        m_logits.scatter_(1, input_ids, logits)

        # keep only the top_k candidates
        logits, token = torch.topk(m_logits.float(), self.top_k, dim=1)

        # temperature scaling
        logits = logits / temperature

        # top_p (nucleus) filtering; the first min_tokens_to_keep candidates
        # are always kept via keep_matrix
        cumulative_probs = logits.softmax(dim=1).cumsum(dim=1)
        mask = cumulative_probs < top_p
        mask = mask + self.keep_matrix
        filtered_logits = torch.where(mask, logits, torch.FloatTensor([-1000.]))
        probs = filtered_logits.softmax(dim=1)
        return probs, token

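# Export one prefill block. The dummy inputs only fix the traced shapes
# (1 x SEQ_LENGTH tokens); their values are irrelevant.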
def convert_block(layer_id):
    model = QwenBlock(layer_id)
    hidden_states = torch.randn(
        (1, SEQ_LENGTH, HIDDEN_SIZE)).to(torch.float).to(device)
    position_ids = torch.tensor(
        [range(SEQ_LENGTH)], dtype=torch.long).to(device)
    attention_mask = torch.randn(
        (1, 1, SEQ_LENGTH, SEQ_LENGTH)).to(torch.float).to(device)

    torch.onnx.export(
        model, (hidden_states, position_ids, attention_mask),
        f'{folder}/block_{layer_id}.onnx',
        verbose=False,
        input_names=['input_states', 'position_ids', 'attention_mask'],
        output_names=['hidden_states', 'past_k', 'past_v'],
        do_constant_folding=True,
        opset_version=15)

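# Export one decode block: GUESS_LENGTH query tokens attending over a cache of
# SEQ_LENGTH positions plus themselves.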
def convert_block_cache(layer_id):
    model = QwenBlockCache(layer_id)
    hidden_states = torch.randn((1, GUESS_LENGTH, HIDDEN_SIZE)).to(torch.float).to(device)
    position_ids = torch.tensor([range(GUESS_LENGTH)], dtype=torch.long).to(device)
    attention_mask = torch.ones(
        (1, 1, GUESS_LENGTH, SEQ_LENGTH + GUESS_LENGTH)).to(torch.float).to(device)
    past_k = torch.randn((1, SEQ_LENGTH, NUM_ATTENTION_HEADS, HEAD_DIM)).to(torch.float).to(device)
    past_v = torch.randn((1, SEQ_LENGTH, NUM_ATTENTION_HEADS, HEAD_DIM)).to(torch.float).to(device)

    torch.onnx.export(
        model, (hidden_states, position_ids, attention_mask, past_k, past_v),
        f'{folder}/block_cache_{layer_id}.onnx',
        verbose=False,
        input_names=[
            'input_states', 'position_ids', 'attention_mask', 'history_k',
            'history_v'
        ],
        output_names=['hidden_states', 'past_k', 'past_v'],
        do_constant_folding=True,
        opset_version=15)

def convert_embedding():
    model = Embedding()
    input_ids = torch.tensor([range(SEQ_LENGTH)]).to(device)

    torch.onnx.export(model, (input_ids),
                      f'{folder}/embedding.onnx',
                      verbose=False,
                      input_names=['input_ids'],
                      output_names=['input_embed'],
                      do_constant_folding=True,
                      opset_version=15)

def convert_lm_head():
    model = LmHead()
    hidden_states = torch.randn(GUESS_LENGTH, HIDDEN_SIZE).to(device)

    torch.onnx.export(model, (hidden_states),
                      f'{folder}/lm_head.onnx',
                      verbose=False,
                      input_names=['hidden_states'],
                      output_names=['m_logits'],
                      do_constant_folding=True,
                      opset_version=15)

def convert_greedy_head():
    model = GreedyHead()
    m_logits = torch.randn(GUESS_LENGTH, VOCAB_SIZE)

    torch.onnx.export(
        model, (m_logits),
        f'{folder}/greedy_head.onnx',
        verbose=False,
        input_names=['m_logits'],
        output_names=['token'],
        do_constant_folding=True,
        opset_version=15)

def convert_penalty_sample_head():
    model = PenaltySampleHead()
    m_logits = torch.randn(GUESS_LENGTH, VOCAB_SIZE)
    input_ids = torch.tensor([range(SEQ_LENGTH)])
    top_p = torch.tensor([0.8])
    temperature = torch.tensor([0.98])
    penalty = torch.tensor([0.98])

    torch.onnx.export(
        model, (m_logits, input_ids, top_p, temperature, penalty),
        f'{folder}/penalty_sample_head.onnx',
        verbose=False,
        input_names=[
            'm_logits', 'input_ids', 'top_p', 'temperature',
            'penalty'
        ],
        output_names=['probs', 'token'],
        do_constant_folding=True,
        opset_version=15)

if not os.path.exists(folder):
    os.makedirs(folder)

print('Convert block & block_cache')
for i in tqdm(range(NUM_LAYERS)):
    convert_block(i)
    convert_block_cache(i)

print('Convert embedding')
convert_embedding()

print('Convert lm_head')
convert_lm_head()
convert_greedy_head()
convert_penalty_sample_head()
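# Example invocation (assuming this file is saved as export_onnx.py; the model
# path is illustrative):
#   python export_onnx.py -m ./Qwen-7B-Chat -s 512 --guess_length 5 --save_dir ./tmp/onnx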