# Qwen2-Audio-rkllm / audio_encoder_export_onnx.py
# Uploaded by happyme531 ("Upload 20 files", commit 2ef3e1d, verified; 3.17 kB).
import torch
import torch.nn as nn
from transformers import Qwen2AudioForConditionalGeneration
class Qwen2AudioEncoderWrapper(nn.Module):
    """Bundle Qwen2-Audio's audio encoder and multimodal projector for ONNX export.

    The wrapped graph takes mel features plus a per-frame validity mask and
    returns the projector output for the encoder's last hidden state, i.e. it
    maps audio features directly into the text embedding space.
    """

    def __init__(self, model):
        """
        Args:
            model: A ``Qwen2AudioForConditionalGeneration`` (or any object that
                exposes ``audio_tower`` and ``multi_modal_projector`` modules).
        """
        super().__init__()
        self.audio_tower = model.audio_tower
        self.projector = model.multi_modal_projector

    def forward(self, input_features, feature_attention_mask):
        """Encode audio features and project them into the text space.

        Args:
            input_features: Mel features of shape ``(batch, n_mels, n_frames)``.
            feature_attention_mask: ``(batch, n_frames)`` mask whose row sums are
                the number of valid frames per sample (assumes 1 = valid,
                0 = padding, left-aligned; TODO confirm against the processor).

        Returns:
            ``self.projector`` applied to the audio tower's ``last_hidden_state``.
        """
        batch_size, _, max_mel_seq_len = input_features.shape

        # Valid mel frames per sample. Cast to an integer dtype so the floor
        # division below is exact and the comparison with ``seq_range`` does
        # not mix int/float (the mask may be fed as float32).
        mel_lengths = feature_attention_mask.sum(-1).to(torch.long)
        # Self-attention runs on a time axis downsampled by 2 (see max_seq_len),
        # so lengths must be converted to post-conv positions before masking.
        # This mirrors the first output of transformers'
        # ``Qwen2AudioEncoder._get_feat_extract_output_lengths``; using raw
        # frame counts left padded positions unmasked.
        audio_feat_lengths = (mel_lengths - 1) // 2 + 1
        max_seq_len = (max_mel_seq_len - 2) // 2 + 1

        seq_range = torch.arange(
            max_seq_len,
            dtype=audio_feat_lengths.dtype,
            device=audio_feat_lengths.device,
        )
        # True where a position lies beyond the sample's valid length:
        # shape (batch, max_seq_len).
        padding_mask = seq_range.unsqueeze(0) >= audio_feat_lengths.unsqueeze(1)

        # Broadcast to (batch, 1, query, key): every query row masks the same keys.
        key_padding = padding_mask[:, None, None, :].expand(-1, 1, max_seq_len, -1)
        # Additive mask: 0 for valid keys, -inf for padded keys. Built in the
        # feature dtype, which presumably matches the encoder weights (the
        # features feed the encoder directly), so half-precision models also
        # get a half-precision mask.
        audio_attention_mask = key_padding.to(input_features.dtype).masked_fill(
            key_padding, float("-inf")
        )

        audio_outputs = self.audio_tower(input_features, attention_mask=audio_attention_mask)
        # Project the encoder's hidden states into the text embedding space.
        return self.projector(audio_outputs.last_hidden_state)
def export_qwen2audio_encoder(model, save_path, input_shape=(1, 80, 3000)):
    """Export Qwen2-Audio's audio encoder + projector to ONNX.

    Args:
        model: A ``Qwen2AudioForConditionalGeneration`` model.
        save_path: Path where the ONNX file is written.
        input_shape: Example feature shape ``(batch_size, n_mels, seq_len)`` used
            for tracing. NOTE: the default of 80 mel bins is kept only for
            interface compatibility; this file's ``__main__`` block exports
            with 128 mel bins, so pass the shape explicitly.
    """
    wrapper = Qwen2AudioEncoderWrapper(model)
    wrapper.eval()

    batch_size, _, seq_len = input_shape
    # Create the example features with the encoder weights' dtype and device so
    # tracing also works for half-precision or GPU-resident models.
    first_param = next(wrapper.audio_tower.parameters(), None)
    if first_param is not None:
        dtype, device = first_param.dtype, first_param.device
    else:
        dtype, device = torch.float32, torch.device("cpu")
    dummy_input = torch.randn(input_shape, dtype=dtype, device=device)
    # The mask stays float32 (the original default), so the exported graph's
    # ``feature_attention_mask`` input keeps the same element type.
    dummy_mask = torch.ones((batch_size, seq_len), device=device)

    # Dynamic axes. The output's time axis is shorter than the input frame axis
    # (the wrapper downsamples it), so it needs its own symbol instead of
    # reusing "sequence_length". Reusing it would declare the two lengths equal.
    # NOTE(review): the upstream encoder may only accept a fixed frame count
    # (e.g. 3000); confirm a dynamic sequence axis actually works at runtime.
    dynamic_axes = {
        'input_features': {0: 'batch_size', 2: 'sequence_length'},
        'feature_attention_mask': {0: 'batch_size', 1: 'sequence_length'},
        'output': {0: 'batch_size', 1: 'output_sequence_length'},
    }

    # Tracing does not need autograd; disabling it avoids building a graph.
    with torch.no_grad():
        torch.onnx.export(
            wrapper,
            (dummy_input, dummy_mask),
            save_path,
            input_names=['input_features', 'feature_attention_mask'],
            output_names=['output'],
            dynamic_axes=dynamic_axes,
            opset_version=17,
            do_constant_folding=True,
        )
if __name__ == "__main__":
    # Export configuration: local checkpoint directory, output file, and the
    # tracing shape (batch_size=1, n_mels=128, seq_len=3000).
    checkpoint_dir = "../Qwen2-Audio-7B-Instruct/"
    onnx_path = "audio_encoder.onnx"
    trace_shape = (1, 128, 3000)

    # Load the full model and put it in evaluation mode.
    model = Qwen2AudioForConditionalGeneration.from_pretrained(checkpoint_dir)
    model.eval()

    # Write the audio encoder + projector graph to ONNX.
    export_qwen2audio_encoder(model, onnx_path, input_shape=trace_shape)