# modified from https://github.com/FlagOpen/FlagEmbedding/blob/master/FlagEmbedding/BGE_M3/modeling.py
import torch
from torch import nn, Tensor
from transformers import AutoModel, AutoConfig
from huggingface_hub import snapshot_download
from typing import Dict


class BGEM3InferenceModel(nn.Module):
    def __init__(
        self,
        model_name: str = "BAAI/bge-m3",
        colbert_dim: int = -1,  # unused in this dense-only variant; kept for API compatibility
    ) -> None:
        super().__init__()

        # Download only the model weights and config; tokenizer files are not
        # needed by the model itself and must be loaded separately.
        model_dir = snapshot_download(
            repo_id=model_name,
            allow_patterns=[
                "pytorch_model.bin",
                "config.json",
            ],
        )

        self.config = AutoConfig.from_pretrained(model_dir)
        self.model = AutoModel.from_pretrained(model_dir)
        self.model.eval()  # inference-only: disable dropout etc.

    def dense_embedding(self, last_hidden_state: Tensor) -> Tensor:
        # The dense representation is the hidden state of the first ([CLS]) token.
        return last_hidden_state[:, 0]

    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Dict[str, Tensor]:
        with torch.no_grad():
            last_hidden_state = self.model(
                input_ids=input_ids, attention_mask=attention_mask, return_dict=True
            ).last_hidden_state

        output = {}
        dense_vecs = self.dense_embedding(last_hidden_state)
        # Returned unnormalized; apply torch.nn.functional.normalize(dense_vecs, dim=-1)
        # if cosine similarity is wanted.
        output["dense_vecs"] = dense_vecs
        return output
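

# A minimal usage sketch (assumptions: the tokenizer is loaded separately via
# AutoTokenizer, since the snapshot above only pulls the weights and config;
# the example sentences are illustrative).
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")
    model = BGEM3InferenceModel()

    batch = tokenizer(
        ["What is BGE-M3?", "BGE-M3 is a multilingual embedding model."],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    out = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
    print(out["dense_vecs"].shape)  # (batch_size, hidden_size), e.g. torch.Size([2, 1024])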