import torch
import torch.nn as nn

from transformers import Qwen2AudioForConditionalGeneration


class Qwen2AudioEncoderWrapper(nn.Module):
    """Wraps the Qwen2Audio audio encoder and multi-modal projector for ONNX export."""

    def __init__(self, model):
        super().__init__()
        self.audio_tower = model.audio_tower
        self.projector = model.multi_modal_projector

    def forward(self, input_features, feature_attention_mask):
        # Convert mel-frame lengths into encoder output lengths after the convolutional
        # downsampling, mirroring the reference Hugging Face implementation.
        audio_feat_lengths, _ = self.audio_tower._get_feat_extract_output_lengths(
            feature_attention_mask.sum(-1)
        )
        batch_size, _, max_mel_seq_len = input_features.shape

        # Length of the encoder sequence after the stride-2 convolution.
        max_seq_len = (max_mel_seq_len - 2) // 2 + 1
        seq_range = torch.arange(0, max_seq_len, device=input_features.device).unsqueeze(0)
        seq_range = seq_range.expand(batch_size, max_seq_len)

        # Mark padded positions and turn them into an additive 4D attention mask:
        # -inf at padded positions, 0 elsewhere.
        lengths_expand = audio_feat_lengths.unsqueeze(1).expand(batch_size, max_seq_len)
        padding_mask = seq_range >= lengths_expand
        audio_attention_mask = padding_mask.view(batch_size, 1, 1, max_seq_len)
        audio_attention_mask = audio_attention_mask.expand(batch_size, 1, max_seq_len, max_seq_len)
        audio_attention_mask = audio_attention_mask.float()
        audio_attention_mask = audio_attention_mask.masked_fill(audio_attention_mask.bool(), float("-inf"))

        # Encode the mel features with the Whisper-style audio tower.
        audio_outputs = self.audio_tower(input_features, attention_mask=audio_attention_mask)
        audio_features = audio_outputs.last_hidden_state

        # Project encoder states into the language model's embedding space.
        projected_features = self.projector(audio_features)

        return projected_features


def export_qwen2audio_encoder(model, save_path, input_shape=(1, 128, 3000)):
    """
    Export the Qwen2Audio encoder to ONNX format.

    Args:
        model: the Qwen2AudioForConditionalGeneration model
        save_path: path where the ONNX model is saved
        input_shape: shape of the dummy audio features (batch_size, n_mels, seq_len)
    """
    wrapper = Qwen2AudioEncoderWrapper(model)
    wrapper.eval()

    # Dummy inputs: mel features plus an all-ones (no padding) attention mask.
    batch_size, n_mels, seq_len = input_shape
    dummy_input = torch.randn(input_shape)
    dummy_mask = torch.ones((batch_size, seq_len), dtype=torch.long)

    # The encoder downsamples the mel sequence, so the output length is tied to,
    # but not equal to, the input length; give it its own symbolic axis name.
    dynamic_axes = {
        'input_features': {0: 'batch_size', 2: 'sequence_length'},
        'feature_attention_mask': {0: 'batch_size', 1: 'sequence_length'},
        'output': {0: 'batch_size', 1: 'output_sequence_length'}
    }

    torch.onnx.export(
        wrapper,
        (dummy_input, dummy_mask),
        save_path,
        input_names=['input_features', 'feature_attention_mask'],
        output_names=['output'],
        dynamic_axes=dynamic_axes,
        opset_version=17,
        do_constant_folding=True
    )
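

# Optional sanity check: run the exported graph with onnxruntime and compare it with
# the PyTorch wrapper. This is a minimal sketch, not part of the original export path;
# it assumes `onnxruntime` and `numpy` are installed, and the max-diff printout is
# only meant for eyeballing numerical agreement.
def verify_onnx_export(model, onnx_path, input_shape=(1, 128, 3000)):
    import numpy as np
    import onnxruntime as ort

    wrapper = Qwen2AudioEncoderWrapper(model)
    wrapper.eval()

    batch_size, _, seq_len = input_shape
    features = torch.randn(input_shape)
    mask = torch.ones((batch_size, seq_len), dtype=torch.long)

    # Reference output from the PyTorch wrapper.
    with torch.no_grad():
        torch_out = wrapper(features, mask).float().numpy()

    # Output from the exported ONNX graph (CPU provider keeps the check portable).
    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    onnx_out = session.run(
        ["output"],
        {
            "input_features": features.numpy(),
            "feature_attention_mask": mask.numpy(),
        },
    )[0]

    max_diff = np.abs(torch_out - onnx_out).max()
    print(f"max abs diff between PyTorch and ONNX outputs: {max_diff:.6f}")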


if __name__ == "__main__":
    model = Qwen2AudioForConditionalGeneration.from_pretrained("../Qwen2-Audio-7B-Instruct/")
    model.eval()

    export_qwen2audio_encoder(
        model,
        "audio_encoder.onnx",
        input_shape=(1, 128, 3000)
    )
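
    # Optional: compare the exported graph against the PyTorch wrapper (sketch above).
    verify_onnx_export(model, "audio_encoder.onnx", input_shape=(1, 128, 3000))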