import os
from pathlib import Path
from typing import Union

from configs.envs import ENVS

# huggingface_hub copies HF_ENDPOINT into huggingface_hub.constants.ENDPOINT
# when it is first imported, so the override must be exported *before*
# huggingface_hub (or transformers, which imports it) is imported below;
# setting it after those imports has no effect on the download endpoint.
if ENVS["HF_ENDPOINT"]:
    os.environ["HF_ENDPOINT"] = ENVS["HF_ENDPOINT"]
    # os.environ only accepts str values; skip an unset/empty token instead
    # of raising TypeError at import time.
    if ENVS["HF_TOKEN"]:
        os.environ["HF_TOKEN"] = ENVS["HF_TOKEN"]

import numpy as np  # noqa: E402
import torch  # noqa: E402
from huggingface_hub import hf_hub_download  # noqa: E402
from numpy.linalg import norm  # noqa: E402
from onnxruntime import InferenceSession  # noqa: E402
from tclogger import logger  # noqa: E402
from transformers import AutoTokenizer, AutoModel  # noqa: E402

from configs.constants import AVAILABLE_MODELS  # noqa: E402


def cosine_similarity(a, b):
    """Return the cosine similarity ``(a @ b.T) / (|a| * |b|)``.

    Intended for two 1-D vectors or two single-row matrices (e.g. the
    (1, dim) embeddings produced by the embedders below). ``norm`` here is
    the norm of the whole input (Frobenius for 2-D), so multi-row inputs do
    NOT yield a row-wise pairwise similarity matrix.
    """
    return (a @ b.T) / (norm(a) * norm(b))


class JinaAIOnnxEmbedder:
    """Embed text with the quantized ONNX export of jina-embeddings-v2-base-zh.

    The ONNX file is downloaded once into ``<repo root>/.cache`` and reused
    on later runs; tokenization uses the Hugging Face tokenizer of the same
    repo.

    https://huggingface.co/jinaai/jina-embeddings-v2-base-zh/discussions/6#65bc55a854ab5eb7b6300893
    """

    def __init__(self):
        self.repo_name = "jinaai/jina-embeddings-v2-base-zh"
        self.download_model()
        self.load_model()

    def download_model(self):
        """Ensure the quantized ONNX model exists locally; download it if missing.

        Sets ``self.onnx_folder``, ``self.onnx_filename`` and ``self.onnx_path``.
        """
        self.onnx_folder = Path(__file__).parents[2] / ".cache"
        self.onnx_folder.mkdir(parents=True, exist_ok=True)
        self.onnx_filename = "onnx/model_quantized.onnx"
        # hf_hub_download keeps the repo-relative sub-path under local_dir,
        # so the file lands at <onnx_folder>/onnx/model_quantized.onnx.
        self.onnx_path = self.onnx_folder / self.onnx_filename
        if not self.onnx_path.exists():
            logger.note("> Downloading ONNX model")
            hf_hub_download(
                repo_id=self.repo_name,
                filename=self.onnx_filename,
                local_dir=self.onnx_folder,
                # Deprecated/ignored by recent huggingface_hub releases; kept so
                # older releases still write a real file instead of a symlink.
                local_dir_use_symlinks=False,
            )
            logger.success(f"+ ONNX model downloaded: {self.onnx_path}")
        else:
            logger.success(f"+ ONNX model loaded: {self.onnx_path}")

    def load_model(self):
        """Load the tokenizer and create the ONNX Runtime inference session."""
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.repo_name, trust_remote_code=True
        )
        # Older onnxruntime releases accept only str/bytes, not pathlib.Path.
        self.session = InferenceSession(str(self.onnx_path))

    def mean_pooling(self, model_output, attention_mask):
        """Average token embeddings over positions where the mask is non-zero.

        Args:
            model_output: token embeddings; the mask expansion below implies a
                (batch, seq, hidden) tensor when the mask is (batch, seq).
            attention_mask: 0/1 mask over tokens.

        Returns:
            A tensor of per-sequence mean embeddings (sum over dim 1 / count).
        """
        token_embeddings = model_output
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        # clamp guards against division by zero for fully-masked sequences
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )

    def encode(self, text: str):
        """Embed ``text`` and mean-pool the ONNX ``last_hidden_state`` output.

        Returns:
            A torch tensor of pooled embeddings, presumably shaped
            (1, hidden_size) for a single input string.
        """
        inputs = self.tokenizer(text, return_tensors="np")
        # ONNX Runtime rejects feed names the graph does not declare, and the
        # tokenizer may emit extra tensors; feed only what the graph expects.
        graph_inputs = {node.name for node in self.session.get_inputs()}
        feed = {
            name: np.asarray(array, dtype=np.int64)
            for name, array in inputs.items()
            if name in graph_inputs
        }
        outputs = self.session.run(
            output_names=["last_hidden_state"], input_feed=feed
        )
        attention_mask = np.asarray(inputs["attention_mask"], dtype=np.int64)
        return self.mean_pooling(
            torch.from_numpy(outputs[0]), torch.from_numpy(attention_mask)
        )


class JinaAIEmbedder:
    """Embed text with a transformers ``AutoModel`` chosen from AVAILABLE_MODELS."""

    def __init__(self, model_name: str = AVAILABLE_MODELS[0]):
        self.model_name = model_name
        self.load_model()

    def check_model_name(self):
        """Replace an unknown ``self.model_name`` with AVAILABLE_MODELS[0].

        Always returns True (kept for backward compatibility).
        """
        if self.model_name not in AVAILABLE_MODELS:
            self.model_name = AVAILABLE_MODELS[0]
        return True

    def load_model(self):
        """Validate ``self.model_name`` and load that model (remote code allowed)."""
        self.check_model_name()
        self.model = AutoModel.from_pretrained(self.model_name, trust_remote_code=True)

    def switch_model(self, model_name: str):
        """Load ``model_name`` unless it is already the active model.

        Unknown names fall back to AVAILABLE_MODELS[0], exactly as
        ``load_model`` would. The name is normalized *before* the comparison,
        so asking for an unknown model while the default is already loaded
        does not trigger a redundant reload.
        """
        if model_name not in AVAILABLE_MODELS:
            model_name = AVAILABLE_MODELS[0]
        if model_name != self.model_name:
            self.model_name = model_name
            self.load_model()

    def encode(self, text: Union[str, list[str]]):
        """Embed one string or a list of strings via the model's own ``encode``.

        A single string is wrapped in a list, so the result always has a
        leading batch dimension (type is whatever the remote-code model's
        ``encode`` returns).
        """
        if isinstance(text, str):
            text = [text]
        return self.model.encode(text)


if __name__ == "__main__":
    # embedder = JinaAIEmbedder()
    embedder = JinaAIOnnxEmbedder()
    texts = ["How is the weather today?", "今天天气怎么样?"]
    embeddings = []
    for text in texts:
        embeddings.append(embedder.encode(text))
    logger.success(embeddings)
    print(cosine_similarity(embeddings[0], embeddings[1]))

    # python -m transforms.embed