# bge_large_en_v1.5_custom_pooling / bge_custom_impl.py
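"""Concatenated multi-layer [CLS] pooling for BAAI/bge-large-en-v1.5.

ConcatCustomPooling exposes tokenize(), a forward() that returns a
'sentence_embedding' feature, and save()/load()/get_config_dict() helpers,
matching the conventions used by custom sentence-transformers modules.
"""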
import json
import os
from typing import Any
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
class ConcatCustomPooling(nn.Module):
    """Pooling module that builds a sentence embedding by concatenating the
    L2-normalized [CLS] vectors from a chosen set of transformer layers."""

    def __init__(
        self,
        model_name_or_path: str = "BAAI/bge-large-en-v1.5",
        layers=(15, 16, 17, 18, 19, 20, 21, 22, 23),
        max_seq_len: int = 512,
        **kwargs,
    ):
super().__init__()
        # Copy the layer selection; None means "use every hidden layer"
        # (resolved lazily in forward(), once the layer count is known).
        self.layers = list(layers) if layers is not None else None
        self.base_name = model_name_or_path
        # Load from the configured checkpoint so model_name_or_path is honored.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModel.from_pretrained(model_name_or_path)
self.model.eval()
self.max_seq_len = max_seq_len
    def tokenize(self, inputs: list[str]) -> dict[str, torch.Tensor]:
        # Truncate to the configured maximum sequence length.
        return self.tokenizer(
            inputs,
            padding=True,
            truncation=True,
            max_length=self.max_seq_len,
            return_tensors="pt",
        )
    def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
        model_output = self.model(**features, output_hidden_states=True)
        # hidden_states[0] is the embedding-layer output; keep only the
        # transformer layers.
        layers_embeddings = model_output.hidden_states[1:]
        number_of_layers = len(layers_embeddings)
        if self.layers is None:
            self.layers = list(range(number_of_layers))
        # L2-normalize the [CLS] token of each selected layer and stack them:
        # shape (batch_size, num_selected_layers, hidden_dim).
        cls_embeddings = torch.stack(
            [
                torch.nn.functional.normalize(layer[:, 0, :], p=2, dim=1)
                for layer_idx, layer in enumerate(layers_embeddings)
                if layer_idx in self.layers
            ],
            dim=1,
        )
        # Concatenate the per-layer vectors into a single embedding:
        # shape (batch_size, num_selected_layers * hidden_dim).
        batch_size = cls_embeddings.shape[0]
        cls_embeddings_concat = cls_embeddings.view(batch_size, -1)
        return {"sentence_embedding": cls_embeddings_concat}
    def get_config_dict(self) -> dict[str, Any]:
        # Keys mirror the __init__ parameters so load() can round-trip them.
        return {
            "model_name_or_path": self.base_name,
            "layers": self.layers,
            "max_seq_len": self.max_seq_len,
        }
def get_max_seq_length(self) -> int:
return self.max_seq_len
def save(self, save_dir: str, **kwargs) -> None:
with open(os.path.join(save_dir, "config.json"), "w") as fOut:
json.dump(self.get_config_dict(), fOut, indent=4)
    @classmethod
    def load(cls, load_dir: str, **kwargs) -> "ConcatCustomPooling":
        with open(os.path.join(load_dir, "config.json")) as fIn:
            config = json.load(fIn)
        return cls(**config)
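
# Minimal usage sketch. This block is illustrative only: it assumes network
# access to download BAAI/bge-large-en-v1.5, and the sentences and save_dir
# below are made-up examples. With the default 9 selected layers and
# bge-large's 1024-dim hidden states, the concatenated embedding should be
# 9 * 1024 = 9216-dimensional.
if __name__ == "__main__":
    pooler = ConcatCustomPooling()
    features = pooler.tokenize(["BGE is an embedding model.", "Custom pooling demo."])
    with torch.no_grad():
        output = pooler(features)
    print(output["sentence_embedding"].shape)  # expected: torch.Size([2, 9216])

    # Round-trip the pooling config through save()/load().
    save_dir = "./bge_custom_pooling"
    os.makedirs(save_dir, exist_ok=True)
    pooler.save(save_dir)
    restored = ConcatCustomPooling.load(save_dir)
    assert restored.get_config_dict() == pooler.get_config_dict()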