import json
import os
from typing import Any

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer


class ConcatCustomPooling(nn.Module):
    """Builds a sentence embedding by L2-normalizing the [CLS] vector of several
    transformer layers and concatenating them into one long vector."""

    def __init__(
        self,
        model_name_or_path: str = "BAAI/bge-large-en-v1.5",
        layers=(15, 16, 17, 18, 19, 20, 21, 22, 23),
        max_seq_len: int = 512,
        **kwargs,
    ):
        super().__init__()
        # layers=None means "use every transformer layer" (resolved in forward()).
        self.layers = list(layers) if layers is not None else None
        self.base_name = model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModel.from_pretrained(model_name_or_path)
        self.model.eval()
        self.max_seq_len = max_seq_len

    def tokenize(self, inputs: list[str]) -> dict[str, torch.Tensor]:
        return self.tokenizer(
            inputs,
            padding=True,
            truncation=True,
            max_length=self.max_seq_len,
            return_tensors="pt",
        )

    def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
        model_output = self.model(**features, output_hidden_states=True)
        # hidden_states[0] is the embedding-layer output; keep only the transformer layers.
        layers_embeddings = model_output.hidden_states[1:]
        if self.layers is None:
            self.layers = list(range(len(layers_embeddings)))
        # L2-normalize the [CLS] token of each selected layer and stack them:
        # shape (batch_size, num_selected_layers, hidden_dim).
        cls_embeddings = torch.stack(
            [
                torch.nn.functional.normalize(layer[:, 0, :], p=2, dim=1)
                for layer_idx, layer in enumerate(layers_embeddings)
                if layer_idx in self.layers
            ],
            dim=1,
        )
        # Flatten the layer and hidden dimensions into one concatenated embedding:
        # shape (batch_size, num_selected_layers * hidden_dim).
        batch_size = cls_embeddings.shape[0]
        cls_embeddings_concat = cls_embeddings.view(batch_size, -1)
        return {"sentence_embedding": cls_embeddings_concat}

    def get_config_dict(self) -> dict[str, Any]:
        # Keys mirror the __init__ arguments so load() can rebuild the module.
        return {
            "model_name_or_path": self.base_name,
            "layers": self.layers,
            "max_seq_len": self.max_seq_len,
        }

    def get_max_seq_length(self) -> int:
        return self.max_seq_len

    def save(self, save_dir: str, **kwargs) -> None:
        with open(os.path.join(save_dir, "config.json"), "w") as fOut:
            json.dump(self.get_config_dict(), fOut, indent=4)

    @classmethod
    def load(cls, load_dir: str, **kwargs) -> "ConcatCustomPooling":
        with open(os.path.join(load_dir, "config.json")) as fIn:
            config = json.load(fIn)
        return cls(**config)
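

# --- Usage sketch (illustrative, not part of the module) ---
# The class exposes the hooks sentence-transformers expects from a custom module
# (tokenize(), forward() returning "sentence_embedding", save()/load(),
# get_config_dict()), so it can plausibly be used as the sole module of a
# SentenceTransformer. The snippet below is a minimal sketch under that assumption;
# the example sentences and layer selection are made up.
if __name__ == "__main__":
    from sentence_transformers import SentenceTransformer

    pooling = ConcatCustomPooling(layers=[20, 21, 22, 23])
    model = SentenceTransformer(modules=[pooling])

    embeddings = model.encode(["a query about embeddings", "an unrelated sentence"])
    # One row per sentence; width = len(layers) * hidden_dim
    # (4 * 1024 = 4096 for BAAI/bge-large-en-v1.5).
    print(embeddings.shape)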