import torch
from torch import nn
import torch.nn.functional as F

class MLPProberBase(nn.Module):
    """MLP probe that maps frozen transformer features to task outputs.

    When `layer='all'`, features from all transformer layers are combined with a
    learned softmax-weighted sum before being fed to the MLP.
    """

    def __init__(self, d=768, layer='all', num_outputs=87):
        super().__init__()
        
        self.hidden_layer_sizes = [512]  # hidden widths of the probe MLP (originally eval(self.cfg.hidden_layer_sizes) from a config)
        
        self.num_layers = len(self.hidden_layer_sizes)

        self.layer = layer

        # build the MLP: one Linear per entry in hidden_layer_sizes, then an output head
        for i, ld in enumerate(self.hidden_layer_sizes):
            setattr(self, f"hidden_{i}", nn.Linear(d, ld))
            d = ld
        self.output = nn.Linear(d, num_outputs)

        self.n_transformer_layer = 12  # number of transformer layers in the upstream model
        
        self.init_aggregator()


    def init_aggregator(self):
        """Initialize the aggregator for weighted sum over different layers of features
        """
        if self.layer == "all":
            # learn one scalar weight per transformer layer; the trailing singleton
            # dims let the weights broadcast over (B, L, T, H) inputs
            self.aggregator = nn.Parameter(torch.randn((1, self.n_transformer_layer, 1, 1)))


    def forward(self, x):
        """
        x: (B, L, T, H)
        B=batch, L=#transformer layers, T=#chunks (1 or more), H=feature dim
        """
        
        if self.layer == "all":
            # softmax over the layer dimension, then a weighted sum that collapses L
            weights = F.softmax(self.aggregator, dim=1)
            x = (x * weights).sum(dim=1)  # (B, L, T, H) -> (B, T, H)

        for i in range(self.num_layers):
            x = getattr(self, f"hidden_{i}")(x)
            # x = self.dropout(x)  # optional dropout, disabled in this probe
            x = F.relu(x)
        output = self.output(x)
        return output
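

# Minimal usage sketch (assumption: dummy random features standing in for real
# extracted representations; B=2 examples, L=12 transformer layers, T=1 chunk,
# H=768 feature dims, probing 87 output classes).
if __name__ == "__main__":
    model = MLPProberBase(d=768, layer="all", num_outputs=87)
    feats = torch.randn(2, 12, 1, 768)  # (B, L, T, H)
    logits = model(feats)               # layer aggregation -> (B, T, num_outputs)
    print(logits.shape)                 # torch.Size([2, 1, 87])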