huseinzol05 commited on
Commit
4edacf5
1 Parent(s): b3f6cad

Upload model

Browse files
Files changed (3) hide show
  1. config.json +40 -0
  2. model.safetensors +3 -0
  3. modeling_contrastive.py +60 -0
config.json ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "embedding-model-llama-2b-contrastive/checkpoint-8100",
3
+ "architectures": [
4
+ "LlamaModelEmbedding"
5
+ ],
6
+ "attention_bias": false,
7
+ "auto_map": {
8
+ "AutoModel": "modeling_contrastive.LlamaModelEmbedding"
9
+ },
10
+ "bos_token_id": 1,
11
+ "eos_token_id": 2,
12
+ "hidden_act": "silu",
13
+ "hidden_size": 5120,
14
+ "id2label": {
15
+ "0": "LABEL_0"
16
+ },
17
+ "initializer_range": 0.02,
18
+ "intermediate_size": 13824,
19
+ "label2id": {
20
+ "LABEL_0": 0
21
+ },
22
+ "max_position_embeddings": 32768,
23
+ "model_type": "llama",
24
+ "normalized": true,
25
+ "num_attention_heads": 40,
26
+ "num_hidden_layers": 5,
27
+ "num_key_value_heads": 40,
28
+ "pad_token_id": 0,
29
+ "pretraining_tp": 1,
30
+ "rms_norm_eps": 1e-05,
31
+ "rope_scaling": null,
32
+ "rope_theta": 10000.0,
33
+ "sentence_pooling_method": "mean",
34
+ "temperature": 0.02,
35
+ "tie_word_embeddings": false,
36
+ "torch_dtype": "bfloat16",
37
+ "transformers_version": "4.35.2",
38
+ "use_cache": true,
39
+ "vocab_size": 32000
40
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fa5fda645b43bf269f5d548ccfe9dddbbc17e98c06901f1d3366ef782ecae451
3
+ size 3515472056
modeling_contrastive.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from transformers import LlamaModel, LlamaConfig, LlamaTokenizer
3
+ from typing import Dict
4
+ from transformers.file_utils import ModelOutput
5
+ from typing import List, Optional, Tuple, Union
6
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
7
+ from torch import nn, Tensor
8
+ from dataclasses import dataclass
9
+ from torch import nn
10
+ from typing import Dict
11
+ import torch
12
+ from transformers.file_utils import ModelOutput
13
+ import torch.nn.functional as F
14
+
15
+ COSINE_DISTANCE = lambda x, y: 1-F.cosine_similarity(x, y)
16
+
17
+ @dataclass
18
+ class EncoderOutput(ModelOutput):
19
+ loss: Optional[Tensor] = None
20
+
21
+ class LlamaModelEmbedding(LlamaModel):
22
+ def __init__(self, config: LlamaConfig, **kwargs):
23
+ super().__init__(config, **kwargs)
24
+
25
+ self.dense_layer = nn.Linear(self.config.hidden_size,1536)
26
+
27
+ def sentence_embedding(self, hidden_state, mask):
28
+ if self.config.sentence_pooling_method == 'mean':
29
+ s = torch.sum(hidden_state * mask.unsqueeze(-1).float(), dim=1)
30
+ d = mask.sum(axis=1, keepdim=True).float()
31
+ return s / d
32
+ elif self.config.sentence_pooling_method == 'cls':
33
+ return hidden_state[:,0]
34
+
35
+ def encode(self, features):
36
+ if features is None:
37
+ return None
38
+ psg_out = super().forward(**features,return_dict=True)
39
+ output = self.dense_layer(psg_out.last_hidden_state)
40
+ p_reps = self.sentence_embedding(output, features['attention_mask'])
41
+ if self.config.normalized:
42
+ p_reps = torch.nn.functional.normalize(p_reps, dim=-1)
43
+ return p_reps.contiguous()
44
+
45
+
46
+ def forward(self, query: Dict[str, Tensor] = None,
47
+ passage: Dict[str, Tensor] = None, labels = None, margin = 0.5):
48
+ q_reps = self.encode(query)
49
+ p_reps = self.encode(passage)
50
+
51
+ loss = None
52
+ if labels is not None:
53
+ distances = COSINE_DISTANCE(q_reps, p_reps)
54
+ losses = 0.5 * (labels.float() * distances.pow(2) + (1 - labels).float() * F.relu(margin - distances).pow(2))
55
+ loss = losses.mean()
56
+
57
+ return EncoderOutput(
58
+ loss=loss,
59
+ )
60
+