import json
import os
from typing import Any

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

class ConcatCustomPooling(nn.Module):
    """Builds sentence embeddings by concatenating the L2-normalized [CLS]
    vectors from a selection of hidden layers of a BGE-style encoder."""

    def __init__(
        self,
        model_name_or_path="BAAI/bge-large-en-v1.5",
        layers=(15, 16, 17, 18, 19, 20, 21, 22, 23),
        max_seq_len=512,
        **kwargs,
    ):
        super().__init__()
        # Indices into the per-layer hidden states (after the embedding-layer
        # output is dropped); None means "use every layer".
        self.layers = list(layers) if layers is not None else None
        self.base_name = model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModel.from_pretrained(model_name_or_path)
        self.model.eval()
        self.max_seq_len = max_seq_len

    def tokenize(self, inputs: list[str]) -> dict[str, torch.Tensor]:
        # Pad to the longest sequence in the batch and truncate to the
        # configured maximum sequence length.
        return self.tokenizer(
            inputs,
            padding=True,
            truncation=True,
            max_length=self.max_seq_len,
            return_tensors="pt",
        )

    def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
        model_output = self.model(**features, output_hidden_states=True)
        # hidden_states[0] is the embedding-layer output; drop it so that
        # index i refers to the output of transformer layer i + 1.
        layers_embeddings = model_output.hidden_states[1:]
        if self.layers is None:
            self.layers = list(range(len(layers_embeddings)))

        # L2-normalize the [CLS] token of each selected layer, then stack the
        # results along a new dimension: (batch_size, num_selected_layers, hidden_dim).
        cls_embeddings = torch.stack(
            [
                nn.functional.normalize(layer[:, 0, :], p=2, dim=1)
                for layer_idx, layer in enumerate(layers_embeddings)
                if layer_idx in self.layers
            ],
            dim=1,
        )

        # Flatten the layer and hidden dimensions into one concatenated vector.
        batch_size = cls_embeddings.shape[0]
        cls_embeddings_concat = cls_embeddings.view(batch_size, -1)
        return {"sentence_embedding": cls_embeddings_concat}

    def get_config_dict(self) -> dict[str, Any]:
        # The keys mirror the __init__ arguments so that load() can rebuild
        # the module with cls(**config).
        return {
            "model_name_or_path": self.base_name,
            "layers": self.layers,
            "max_seq_len": self.max_seq_len,
        }

    def get_max_seq_length(self) -> int:
        return self.max_seq_len

    def save(self, save_dir: str, **kwargs) -> None:
        # Only the configuration is persisted; the base model weights are
        # reloaded from `model_name_or_path` when the module is loaded.
        with open(os.path.join(save_dir, "config.json"), "w") as fOut:
            json.dump(self.get_config_dict(), fOut, indent=4)

    @classmethod
    def load(cls, load_dir: str, **kwargs) -> "ConcatCustomPooling":
        with open(os.path.join(load_dir, "config.json")) as fIn:
            config = json.load(fIn)

        return cls(**config)
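
# Minimal usage sketch (an illustration added for clarity, not part of the
# original module): tokenize a couple of sentences, run them through the
# pooling module, and check that the output width is len(layers) * hidden_dim.
# The example sentences and the `__main__` guard are assumptions for
# demonstration only.
if __name__ == "__main__":
    module = ConcatCustomPooling()
    features = module.tokenize(["What is BGE?", "A sentence embedding model."])
    with torch.no_grad():
        output = module.forward(features)
    print(output["sentence_embedding"].shape)
    # Expected: (2, len(module.layers) * hidden_dim), e.g. (2, 9 * 1024) with
    # the default 9 layers of bge-large-en-v1.5.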