File size: 709 Bytes
16188ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
import torch.nn as nn

class Adapter(nn.Module):
    """Bottleneck adapter block with a residual (skip) connection.

    The input is layer-normalised, projected from ``input_dim`` down to
    ``hidden_dim``, passed through a ReLU, projected back up to
    ``input_dim`` and finally added to the untouched input. Both linear
    projections are created without a bias term.
    """

    def __init__(self, input_dim: int, hidden_dim: int) -> None:
        """
        :param input_dim: size of the last dimension of the input (and output)
        :param hidden_dim: size of the intermediate bottleneck representation
        """
        super().__init__()
        self.input_dim, self.hidden_dim = input_dim, hidden_dim
        # Submodule attribute names and creation order are kept as-is so that
        # state_dict keys and default weight initialisation stay unchanged.
        self.layerNorm = nn.LayerNorm(input_dim)
        self.down_proj = nn.Linear(input_dim, hidden_dim, bias=False)
        self.up_proj = nn.Linear(hidden_dim, input_dim, bias=False)

    def forward(self, x):
        """
        Apply the adapter to ``x`` and add the result back onto ``x``.

        :param x: N,L,D
        :return:  N,L,D
        """
        bottleneck = self.down_proj(self.layerNorm(x))
        delta = self.up_proj(torch.relu(bottleneck))
        return x + delta  # residual connection