Create Models/MoE_model.py
Browse files- Models/MoE_model.py +26 -0
Models/MoE_model.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from models.gating_network import GatingNetwork
|
4 |
+
from models.vision_expert import VisionExpert
|
5 |
+
from models.audio_expert import AudioExpert
|
6 |
+
from models.sensor_expert import SensorExpert
|
7 |
+
|
8 |
+
class MoEModel(nn.Module):
|
9 |
+
def __init__(self, input_dim, num_experts):
|
10 |
+
super(MoEModel, self).__init__()
|
11 |
+
self.gating_network = GatingNetwork(input_dim=input_dim, num_experts=num_experts)
|
12 |
+
self.experts = nn.ModuleList([VisionExpert(), AudioExpert(), SensorExpert()])
|
13 |
+
self.fc_final = nn.Linear(128, 10) # Assuming 10 possible actions
|
14 |
+
|
15 |
+
def forward(self, vision_input, audio_input, sensor_input):
|
16 |
+
vision_features = self.experts[0](vision_input)
|
17 |
+
audio_features = self.experts[1](audio_input)
|
18 |
+
sensor_features = self.experts[2](sensor_input)
|
19 |
+
|
20 |
+
combined_features = torch.cat((vision_features, audio_features, sensor_features), dim=1)
|
21 |
+
gating_weights = self.gating_network(combined_features)
|
22 |
+
|
23 |
+
expert_outputs = torch.stack([expert(combined_features) for expert in self.experts], dim=1)
|
24 |
+
final_output = torch.einsum('ij,ijk->ik', gating_weights, expert_outputs)
|
25 |
+
|
26 |
+
return self.fc_final(final_output)
|