|
import torch |
|
import torch.nn as nn |
|
from models.gating_network import GatingNetwork |
|
from models.vision_expert import VisionExpert |
|
from models.audio_expert import AudioExpert |
|
from models.sensor_expert import SensorExpert |
|
|
|
class MoEModel(nn.Module):
    """Mixture-of-experts model fusing vision, audio and sensor inputs.

    Each modality is encoded by its own expert. The concatenated expert
    features are fed to a gating network, whose weights blend the per-expert
    features into one vector; a final linear layer maps that vector to 10
    outputs.

    Args:
        input_dim: Input width forwarded to ``GatingNetwork``. The gate
            receives the concatenation of the three expert outputs, so this
            presumably has to equal their combined width (TODO confirm
            against the expert definitions).
        num_experts: Expert count forwarded to ``GatingNetwork``. Exactly
            three experts are always built here, so the gate presumably has
            to emit three weights per sample (TODO confirm).
    """

    def __init__(self, input_dim, num_experts):
        super().__init__()
        self.gating_network = GatingNetwork(
            input_dim=input_dim, num_experts=num_experts
        )
        # Order matters: forward() pairs these positionally with the
        # (vision, audio, sensor) inputs.
        self.experts = nn.ModuleList(
            [VisionExpert(), AudioExpert(), SensorExpert()]
        )
        # The blended expert vector must be 128 wide to feed this layer.
        self.fc_final = nn.Linear(128, 10)

    def forward(self, vision_input, audio_input, sensor_input):
        """Encode each modality, gate the expert features, and classify.

        Args:
            vision_input: Input for the vision expert.
            audio_input: Input for the audio expert.
            sensor_input: Input for the sensor expert.

        Returns:
            Output of the final linear layer applied to the gate-weighted
            sum of the three expert feature tensors.
        """
        modality_inputs = (vision_input, audio_input, sensor_input)
        expert_features = [
            expert(modality)
            for expert, modality in zip(self.experts, modality_inputs)
        ]

        # The gate sees every modality's features side by side.
        combined_features = torch.cat(expert_features, dim=1)
        gating_weights = self.gating_network(combined_features)

        # Blend the per-modality features computed above. The experts are not
        # re-run on the concatenated tensor: each one is built for its own
        # modality's input, not for the fused feature vector.
        # Stacking requires all experts to emit identically shaped 2-D
        # features: (batch, feat) -> (batch, expert, feat).
        expert_outputs = torch.stack(expert_features, dim=1)
        # Weighted sum over the expert axis: (batch, expert) x
        # (batch, expert, feat) -> (batch, feat).
        final_output = torch.einsum('ij,ijk->ik', gating_weights, expert_outputs)

        return self.fc_final(final_output)
|
|