|
import json |
|
import os |
|
from typing import Any |
|
import torch |
|
import torch.nn as nn |
|
from transformers import AutoModel, AutoTokenizer |
|
|
|
class ConcatCustomPooling(nn.Module):
    """Embed text as the concatenation of per-layer, L2-normalised first-token vectors.

    A Hugging Face encoder is run with ``output_hidden_states=True``. For every
    selected transformer layer, the hidden state at sequence position 0 (the
    [CLS] slot for BERT-style tokenizers) is L2-normalised. The selected vectors
    are then concatenated, in model layer order, into one ``sentence_embedding``
    per input.

    NOTE(review): the method set (tokenize / forward / get_config_dict /
    get_max_seq_length / save / load) looks designed for use as a
    sentence-transformers module; confirm against the caller.
    """

    def __init__(self, model_name_or_path="BAAI/bge-large-en-v1.5", layers=None, max_seq_len=512, **kwargs):
        """Load the tokenizer and encoder.

        Args:
            model_name_or_path: Model id or local directory passed to
                ``AutoTokenizer.from_pretrained`` and ``AutoModel.from_pretrained``.
            layers: Indices of the transformer layers whose first-token vectors
                are concatenated (0 = first layer after the embeddings).
                ``None`` selects every layer. Indices outside the model's range
                are ignored.
            max_seq_len: Maximum number of tokens per input. ``tokenize``
                truncates longer inputs to this length.
            **kwargs: Forwarded to ``nn.Module.__init__``.
        """
        super().__init__(**kwargs)
        self.layers = layers
        self.base_name = model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModel.from_pretrained(model_name_or_path)
        # Start in eval mode (disables dropout).
        self.model.eval()
        self.max_seq_len = max_seq_len

    def tokenize(self, inputs: list[str]):
        """Tokenise ``inputs`` into padded PyTorch tensors.

        Inputs are padded to the longest item in the batch and truncated to
        ``self.max_seq_len`` tokens.
        """
        # Pass max_length explicitly. With truncation=True alone, the tokenizer
        # truncates to its own model_max_length and ignores max_seq_len.
        return self.tokenizer(
            inputs,
            padding=True,
            truncation=True,
            max_length=self.max_seq_len,
            return_tensors="pt",
        )

    def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
        """Compute the concatenated per-layer first-token embedding.

        Args:
            features: Encoder inputs, e.g. the output of ``tokenize``. They are
                passed unchanged as keyword arguments to the encoder.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            ``{"sentence_embedding": tensor}`` with shape
            ``(batch, n_selected_layers * hidden_dim)``.

        Raises:
            ValueError: If ``self.layers`` selects none of the model's layers.
        """
        model_output = self.model(**features, output_hidden_states=True)
        # hidden_states[0] is the embedding-layer output; keep only transformer layers.
        layer_states = model_output.hidden_states[1:]

        # Resolve the selection locally rather than overwriting self.layers,
        # so forward() does not change the configuration that save() writes.
        wanted = None if self.layers is None else set(self.layers)
        cls_vectors = [
            torch.nn.functional.normalize(state[:, 0, :], p=2, dim=1)
            for idx, state in enumerate(layer_states)
            if wanted is None or idx in wanted
        ]
        if not cls_vectors:
            raise ValueError(
                f"layers={self.layers!r} selects none of the model's "
                f"{len(layer_states)} layers (valid indices: 0..{len(layer_states) - 1})"
            )

        # (batch, n_selected, hidden) -> (batch, n_selected * hidden), layer-major.
        stacked = torch.stack(cls_vectors, dim=1)
        return {"sentence_embedding": stacked.flatten(start_dim=1)}

    def get_config_dict(self) -> dict[str, Any]:
        """Return the constructor settings that ``save`` writes to disk.

        The checkpoint is stored under ``"model_name"``. ``load`` maps it back
        to the ``model_name_or_path`` constructor argument.
        """
        return {"model_name": self.base_name, "layers": self.layers, "max_seq_len": self.max_seq_len}

    def get_max_seq_length(self) -> int:
        """Return the token limit applied by ``tokenize``."""
        return self.max_seq_len

    def save(self, save_dir: str, **kwargs) -> None:
        """Write ``get_config_dict()`` to ``<save_dir>/config.json``.

        Only the settings are written, not the encoder weights or the
        tokenizer. ``load`` re-creates those from the stored model name.
        """
        os.makedirs(save_dir, exist_ok=True)
        with open(os.path.join(save_dir, "config.json"), "w") as fOut:
            json.dump(self.get_config_dict(), fOut, indent=4)

    @staticmethod
    def load(load_dir: str, **kwargs) -> "ConcatCustomPooling":
        """Re-create a module from a directory written by ``save``.

        Static so that it works when called on the class itself
        (``ConcatCustomPooling.load(path)``) as well as on an instance.
        """
        with open(os.path.join(load_dir, "config.json")) as fIn:
            config = json.load(fIn)
        # get_config_dict() stores "model_name", but __init__ takes
        # "model_name_or_path". Passing "model_name" through **kwargs would
        # reach nn.Module.__init__ and raise TypeError.
        if "model_name" in config:
            config.setdefault("model_name_or_path", config.pop("model_name"))
        return ConcatCustomPooling(**config)