# bge-m3-onnx / bgem3_model.py
# Uploaded by skshreyas714 (commit 3975d0a, "Upload 2 files").
# modified from https://github.com/FlagOpen/FlagEmbedding/blob/master/FlagEmbedding/BGE_M3/modeling.py
import os
import torch
from torch import nn, Tensor
from transformers import AutoModel, AutoConfig
from huggingface_hub import snapshot_download
from typing import Dict
class BGEM3InferenceModel(nn.Module):
    """Dense-embedding inference wrapper around the BGE-M3 transformer.

    Only the dense head is implemented: ``forward`` returns the hidden state
    at sequence position 0 of each input under the ``"dense_vecs"`` key.

    Args:
        model_name: Hugging Face Hub repo id passed to ``snapshot_download``.
            Only ``pytorch_model.bin`` and ``config.json`` are fetched.
        colbert_dim: Accepted but not used by this class (no ColBERT head is
            built); presumably kept for signature compatibility with the
            upstream FlagEmbedding model this file was adapted from.
        normalize_embeddings: If True, L2-normalize the dense vectors along
            the last dimension inside ``forward``. Defaults to False, which
            returns the raw position-0 hidden states.
    """

    def __init__(
        self,
        model_name: str = "BAAI/bge-m3",
        colbert_dim: int = -1,
        *,
        normalize_embeddings: bool = False,
    ) -> None:
        super().__init__()
        # snapshot_download returns the local snapshot directory; restricting
        # allow_patterns limits the download to the weights and the config.
        local_dir = snapshot_download(
            repo_id=model_name,
            allow_patterns=[
                "pytorch_model.bin",
                "config.json",
            ],
        )
        self.config = AutoConfig.from_pretrained(local_dir)
        self.model = AutoModel.from_pretrained(local_dir)
        self.normalize_embeddings = normalize_embeddings

    def dense_embedding(self, last_hidden_state: Tensor) -> Tensor:
        """Return the hidden state at sequence position 0 for every batch item.

        Assumes ``last_hidden_state`` is (batch, seq_len, hidden) — TODO
        confirm for external callers; the result is then (batch, hidden).
        """
        return last_hidden_state[:, 0]

    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Dict[str, Tensor]:
        """Encode a batch and return its dense embeddings.

        Args:
            input_ids: Token ids for the wrapped transformer (presumably
                shaped (batch, seq_len) — confirm against callers).
            attention_mask: Attention mask matching ``input_ids``.

        Returns:
            A dict with the single key ``"dense_vecs"`` holding the
            position-0 hidden states, L2-normalized along the last dimension
            only when ``normalize_embeddings`` was set.
        """
        # Inference only: no autograd graph is needed anywhere in this pass.
        with torch.no_grad():
            last_hidden_state = self.model(
                input_ids=input_ids, attention_mask=attention_mask, return_dict=True
            ).last_hidden_state
            dense_vecs = self.dense_embedding(last_hidden_state)
            if self.normalize_embeddings:
                dense_vecs = nn.functional.normalize(dense_vecs, dim=-1)
        return {"dense_vecs": dense_vecs}