import torch
import torch.nn as nn

from models.gating_network import GatingNetwork
from models.vision_expert import VisionExpert
from models.audio_expert import AudioExpert
from models.sensor_expert import SensorExpert


class MoEModel(nn.Module):
    """Mixture-of-experts model over vision, audio and sensor inputs.

    Each modality is encoded by its own expert. A gating network receives the
    concatenation of all expert features and produces one weight per expert.
    The expert features are mixed with those weights and projected by a final
    linear layer to ``num_actions`` outputs.
    """

    def __init__(self, input_dim, num_experts, *, num_actions=10):
        """Build the gating network, the three modality experts and the head.

        Args:
            input_dim: Input width of the gating network. In ``forward`` it is
                fed the concatenated expert features, so this should equal
                their combined width (presumably 3 * 128 -- confirm against
                the expert implementations).
            num_experts: Number of gating weights produced per sample. Must
                equal the number of experts built here (3).
            num_actions: Output width of the final linear layer. Defaults to
                10, the value that was previously hard-coded.

        Raises:
            ValueError: If ``num_experts`` differs from the number of experts.
        """
        super().__init__()
        # Construction order is kept as gating -> experts -> head so parameter
        # registration (and random initialisation order) is unchanged.
        self.gating_network = GatingNetwork(input_dim=input_dim, num_experts=num_experts)
        self.experts = nn.ModuleList([VisionExpert(), AudioExpert(), SensorExpert()])
        # The 'ij,ijk->ik' einsum in forward requires one gating weight per
        # expert; fail early with a clear message instead of deep in forward.
        if num_experts != len(self.experts):
            raise ValueError(
                f"num_experts must be {len(self.experts)} (one per expert), got {num_experts}"
            )
        # 128 must match the width of each expert's output features
        # (assumed from the original design -- confirm against the experts).
        self.fc_final = nn.Linear(128, num_actions)

    def forward(self, vision_input, audio_input, sensor_input):
        """Encode each modality, gate the experts and return the head's output.

        Args:
            vision_input: Input for the vision expert.
            audio_input: Input for the audio expert.
            sensor_input: Input for the sensor expert.

        Returns:
            Output of ``fc_final`` applied to the gate-weighted mix of expert
            features; presumably shape (batch, num_actions) -- assumes each
            expert returns (batch, 128) features (TODO confirm).
        """
        # Every expert encodes only its own modality, exactly once.
        vision_features = self.experts[0](vision_input)
        audio_features = self.experts[1](audio_input)
        sensor_features = self.experts[2](sensor_input)
        expert_features = (vision_features, audio_features, sensor_features)

        # The gate sees all modalities together to decide how to weight experts.
        combined_features = torch.cat(expert_features, dim=1)
        gating_weights = self.gating_network(combined_features)  # (batch, num_experts)

        # Use the per-modality features as the expert outputs. Re-running each
        # expert on ``combined_features`` would feed it an input other than its
        # own modality and would double the expert compute.
        expert_outputs = torch.stack(expert_features, dim=1)  # (batch, num_experts, feat)

        # Weighted sum over the expert axis: mixed[b, k] = sum_e w[b, e] * x[b, e, k].
        mixed_features = torch.einsum("ij,ijk->ik", gating_weights, expert_outputs)
        return self.fc_final(mixed_features)