Hansimov commited on
Commit
26cd288
1 Parent(s): 573e11b

:gem: [Feature] Enable onnx model for embedder

Browse files
Files changed (1) hide show
  1. transforms/embed.py +70 -7
transforms/embed.py CHANGED
@@ -1,10 +1,16 @@
1
  import os
2
 
 
 
 
 
3
  from typing import Union
4
 
5
- from tclogger import logger
6
- from transformers import AutoModel
7
  from numpy.linalg import norm
 
 
 
8
 
9
  from configs.envs import ENVS
10
  from configs.constants import AVAILABLE_MODELS
@@ -18,6 +24,59 @@ def cosine_similarity(a, b):
18
  return (a @ b.T) / (norm(a) * norm(b))
19
 
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  class JinaAIEmbedder:
22
  def __init__(self, model_name: str = AVAILABLE_MODELS[0]):
23
  self.model_name = model_name
@@ -44,9 +103,13 @@ class JinaAIEmbedder:
44
 
45
 
46
  if __name__ == "__main__":
47
- embedder = JinaAIEmbedder()
48
- text = ["How is the weather today?", "今天天气怎么样?"]
49
- # text = "How is the weather today?"
50
- embeddings = embedder.encode(text)
 
 
51
  logger.success(embeddings)
52
- # print(cosine_similarity(embeddings[0], embeddings[1]))
 
 
 
1
  import os
2
 
3
+ import numpy as np
4
+ import torch
5
+
6
+ from pathlib import Path
7
  from typing import Union
8
 
9
+ from huggingface_hub import hf_hub_download
 
10
  from numpy.linalg import norm
11
+ from onnxruntime import InferenceSession
12
+ from tclogger import logger
13
+ from transformers import AutoTokenizer, AutoModel
14
 
15
  from configs.envs import ENVS
16
  from configs.constants import AVAILABLE_MODELS
 
24
  return (a @ b.T) / (norm(a) * norm(b))
25
 
26
 
27
+ class JinaAIOnnxEmbedder:
28
+ """https://huggingface.co/jinaai/jina-embeddings-v2-base-zh/discussions/6#65bc55a854ab5eb7b6300893"""
29
+
30
+ def __init__(self):
31
+ self.repo_name = "jinaai/jina-embeddings-v2-base-zh"
32
+ self.download_model()
33
+ self.load_model()
34
+
35
+ def download_model(self):
36
+ self.onnx_folder = Path(__file__).parent
37
+ self.onnx_filename = "onnx/model_quantized.onnx"
38
+ self.onnx_path = self.onnx_folder / self.onnx_filename
39
+ if not self.onnx_path.exists():
40
+ logger.note("> Downloading ONNX model")
41
+ hf_hub_download(
42
+ repo_id=self.repo_name,
43
+ filename=self.onnx_filename,
44
+ local_dir=self.onnx_folder,
45
+ local_dir_use_symlinks=False,
46
+ )
47
+ logger.success(f"+ ONNX model downloaded: {self.onnx_path}")
48
+ else:
49
+ logger.success(f"+ ONNX model loaded: {self.onnx_path}")
50
+
51
+ def load_model(self):
52
+ self.tokenizer = AutoTokenizer.from_pretrained(
53
+ self.repo_name, trust_remote_code=True
54
+ )
55
+ self.session = InferenceSession(self.onnx_path)
56
+
57
+ def mean_pooling(self, model_output, attention_mask):
58
+ token_embeddings = model_output
59
+ input_mask_expanded = (
60
+ attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
61
+ )
62
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
63
+ input_mask_expanded.sum(1), min=1e-9
64
+ )
65
+
66
+ def encode(self, text: str):
67
+ inputs = self.tokenizer(text, return_tensors="np")
68
+ inputs = {
69
+ name: np.array(tensor, dtype=np.int64) for name, tensor in inputs.items()
70
+ }
71
+ outputs = self.session.run(
72
+ output_names=["last_hidden_state"], input_feed=dict(inputs)
73
+ )
74
+ embeddings = self.mean_pooling(
75
+ torch.from_numpy(outputs[0]), torch.from_numpy(inputs["attention_mask"])
76
+ )
77
+ return embeddings
78
+
79
+
80
  class JinaAIEmbedder:
81
  def __init__(self, model_name: str = AVAILABLE_MODELS[0]):
82
  self.model_name = model_name
 
103
 
104
 
105
  if __name__ == "__main__":
106
+ # embedder = JinaAIEmbedder()
107
+ embedder = JinaAIOnnxEmbedder()
108
+ texts = ["How is the weather today?", "今天天气怎么样?"]
109
+ embeddings = []
110
+ for text in texts:
111
+ embeddings.append(embedder.encode(text))
112
  logger.success(embeddings)
113
+ print(cosine_similarity(embeddings[0], embeddings[1]))
114
+
115
+ # python -m transforms.embed