import torch
import torch.nn as nn
from transformers import Qwen2AudioForConditionalGeneration


class Qwen2AudioEncoderWrapper(nn.Module):
    """Wraps the Qwen2Audio audio encoder and multi-modal projector for ONNX export."""

    def __init__(self, model):
        super().__init__()
        self.audio_tower = model.audio_tower
        self.projector = model.multi_modal_projector

    def forward(self, input_features, feature_attention_mask):
        # Per-sample feature length after the conv downsampling, mirroring
        # Qwen2AudioEncoder._get_feat_extract_output_lengths (the raw mask sum
        # would be the mel-frame count and produce a wrong padding mask)
        audio_feat_lengths = (feature_attention_mask.sum(-1) - 1) // 2 + 1
        batch_size, _, max_mel_seq_len = input_features.shape
        # Output sequence length of the encoder
        max_seq_len = (max_mel_seq_len - 2) // 2 + 1
        seq_range = torch.arange(0, max_seq_len, device=input_features.device).unsqueeze(0)
        seq_range = seq_range.expand(batch_size, max_seq_len)

        # Build the additive attention mask: 0 on valid positions, -inf on padding
        lengths_expand = audio_feat_lengths.unsqueeze(1).expand(batch_size, max_seq_len)
        padding_mask = seq_range >= lengths_expand
        padding_mask = padding_mask.view(batch_size, 1, 1, max_seq_len)
        padding_mask = padding_mask.expand(batch_size, 1, max_seq_len, max_seq_len)
        audio_attention_mask = torch.zeros_like(padding_mask, dtype=input_features.dtype)
        audio_attention_mask = audio_attention_mask.masked_fill(padding_mask, float("-inf"))

        # Run the audio tower
        audio_outputs = self.audio_tower(input_features, attention_mask=audio_attention_mask)
        audio_features = audio_outputs.last_hidden_state

        # Project the audio features into the text embedding space
        projected_features = self.projector(audio_features)
        return projected_features


def export_qwen2audio_encoder(model, save_path, input_shape=(1, 128, 3000)):
    """
    Export the Qwen2Audio audio encoder to ONNX.

    Args:
        model: a Qwen2AudioForConditionalGeneration instance
        save_path: path of the output ONNX file
        input_shape: shape of the input audio features (batch_size, n_mels, seq_len);
            Qwen2-Audio's feature extractor produces 128 mel bins
    """
    wrapper = Qwen2AudioEncoderWrapper(model)
    wrapper.eval()

    # Prepare dummy inputs; the mask is integer-typed so the length sum stays integral
    batch_size, n_mels, seq_len = input_shape
    dummy_input = torch.randn(input_shape)
    dummy_mask = torch.ones((batch_size, seq_len), dtype=torch.long)

    # Mark the batch and sequence dimensions as dynamic
    dynamic_axes = {
        'input_features': {0: 'batch_size', 2: 'sequence_length'},
        'feature_attention_mask': {0: 'batch_size', 1: 'sequence_length'},
        'output': {0: 'batch_size', 1: 'sequence_length'},
    }

    # Export to ONNX
    torch.onnx.export(
        wrapper,
        (dummy_input, dummy_mask),
        save_path,
        input_names=['input_features', 'feature_attention_mask'],
        output_names=['output'],
        dynamic_axes=dynamic_axes,
        opset_version=17,
        do_constant_folding=True,
    )


if __name__ == "__main__":
    # Load the model
    model = Qwen2AudioForConditionalGeneration.from_pretrained("../Qwen2-Audio-7B-Instruct/")
    model.eval()

    # Export to ONNX
    export_qwen2audio_encoder(
        model,
        "audio_encoder.onnx",
        input_shape=(1, 128, 3000)  # batch_size=1, n_mels=128, seq_len=3000
    )
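

# Optional post-export sanity check: a minimal sketch (not part of the original
# script) that compares ONNX Runtime output against the PyTorch wrapper on the
# same dummy inputs. Assumes `onnxruntime` and `numpy` are installed; the
# tolerances below are illustrative, not validated thresholds. Move this above
# the __main__ guard if you want to call it there right after the export.
def verify_onnx_export(model, onnx_path, input_shape=(1, 128, 3000)):
    import numpy as np
    import onnxruntime as ort

    wrapper = Qwen2AudioEncoderWrapper(model)
    wrapper.eval()

    batch_size, _, seq_len = input_shape
    dummy_input = torch.randn(input_shape)
    dummy_mask = torch.ones((batch_size, seq_len), dtype=torch.long)

    # PyTorch reference output
    with torch.no_grad():
        torch_out = wrapper(dummy_input, dummy_mask).numpy()

    # ONNX Runtime output on the same inputs
    session = ort.InferenceSession(onnx_path)
    ort_out = session.run(
        ['output'],
        {
            'input_features': dummy_input.numpy(),
            'feature_attention_mask': dummy_mask.numpy(),
        },
    )[0]

    np.testing.assert_allclose(torch_out, ort_out, rtol=1e-3, atol=1e-3)
    print("ONNX output matches PyTorch, max abs diff:", np.abs(torch_out - ort_out).max())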