ibrim committed on
Commit
75db504
1 Parent(s): cff6a73

Upload modules.py

Files changed (1)
  1. modules.py +70 -0
modules.py ADDED
@@ -0,0 +1,70 @@
+ import torch
+ from torch import nn
+ import timm
+ from transformers import DistilBertModel, DistilBertConfig
+ import config as CFG
+
+
+ class ImageEncoder(nn.Module):
+     """
+     Encode images to a fixed size vector
+     """
+
+     def __init__(
+         self, model_name=CFG.model_name, pretrained=CFG.pretrained, trainable=CFG.trainable
+     ):
+         super().__init__()
+         self.model = timm.create_model(
+             model_name, pretrained, num_classes=0, global_pool="avg"
+         )
+         for p in self.model.parameters():
+             p.requires_grad = trainable
+
+     def forward(self, x):
+         return self.model(x)
+
+
+ class TextEncoder(nn.Module):
+     def __init__(self, model_name=CFG.text_encoder_model, pretrained=CFG.pretrained, trainable=CFG.trainable):
+         super().__init__()
+         if pretrained:
+             self.model = DistilBertModel.from_pretrained(model_name)
+         else:
+             self.model = DistilBertModel(config=DistilBertConfig())
+
+         for p in self.model.parameters():
+             p.requires_grad = trainable
+
+         # we are using the CLS token hidden representation as the sentence's embedding
+         self.target_token_idx = 0
+
+     def forward(self, input_ids, attention_mask):
+         output = self.model(input_ids=input_ids, attention_mask=attention_mask)
+         last_hidden_state = output.last_hidden_state
+         return last_hidden_state[:, self.target_token_idx, :]
+
+
+
+ class ProjectionHead(nn.Module):
+     def __init__(
+         self,
+         embedding_dim,
+         projection_dim=CFG.projection_dim,
+         dropout=CFG.dropout
+     ):
+         super().__init__()
+         self.projection = nn.Linear(embedding_dim, projection_dim)
+         self.gelu = nn.GELU()
+         self.fc = nn.Linear(projection_dim, projection_dim)
+         self.dropout = nn.Dropout(dropout)
+         self.layer_norm = nn.LayerNorm(projection_dim)
+
+     def forward(self, x):
+         projected = self.projection(x)
+         x = self.gelu(projected)
+         x = self.fc(x)
+         x = self.dropout(x)
+         x = x + projected
+         x = self.layer_norm(x)
+         return x
+
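For context, a minimal sketch of how these three modules could be wired together in a CLIP-style image/text embedding pipeline. The config module is not part of this commit, so the backbone names, feature widths (2048 for a resnet50 image backbone, 768 for distilbert-base-uncased), tokenizer name, and batch shapes below are illustrative assumptions rather than the repository's actual settings.

# sketch.py -- usage sketch only; assumes config.py provides values such as
#   model_name = "resnet50", text_encoder_model = "distilbert-base-uncased",
#   pretrained = True, trainable = True, projection_dim = 256, dropout = 0.1
import torch
from transformers import DistilBertTokenizer

from modules import ImageEncoder, TextEncoder, ProjectionHead

image_encoder = ImageEncoder()    # timm backbone with pooled features, classifier removed
text_encoder = TextEncoder()      # DistilBERT; returns the CLS-token hidden state

# embedding_dim must match the encoder output width
# (assumed: 2048 for resnet50, 768 for distilbert-base-uncased)
image_projection = ProjectionHead(embedding_dim=2048)
text_projection = ProjectionHead(embedding_dim=768)

images = torch.randn(4, 3, 224, 224)  # dummy image batch
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
batch = tokenizer(["a photo of a cat"] * 4, padding=True, return_tensors="pt")

with torch.no_grad():
    image_embeddings = image_projection(image_encoder(images))
    text_embeddings = text_projection(
        text_encoder(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
    )

print(image_embeddings.shape, text_embeddings.shape)  # both (4, projection_dim)

Both projection heads map into the same dimensionality, so the resulting image and text embeddings can be compared directly (for example with a cosine-similarity contrastive loss), which is the usual role of this kind of projection head.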