"""Utils for transformer embeddings."""

import functools
import os
from typing import TYPE_CHECKING, Optional

from ..env import data_path
from ..utils import log

if TYPE_CHECKING:
  from sentence_transformers import SentenceTransformer


def get_model(model_name: str,
              optimal_batch_sizes: dict[str, int] = {}) -> tuple[int, 'SentenceTransformer']:
  """Get a transformer model and the optimal batch size for it.

  `optimal_batch_sizes` maps a device name (e.g. 'mps') to its batch size; the '' key is the
  fallback when no preferred device is detected.
  """
  try:
    import torch.backends.mps
    from sentence_transformers import SentenceTransformer
  except ImportError:
    raise ImportError('Could not import the "sentence_transformers" python package. '
                      'Please install it with `pip install sentence-transformers`.')
  # Prefer Apple-silicon GPU acceleration (MPS) when available; otherwise use the default device.
  preferred_device: Optional[str] = None
  if torch.backends.mps.is_available():
    preferred_device = 'mps'
  elif not torch.backends.mps.is_built():
    log('MPS not available because the current PyTorch install was not built with MPS enabled.')

  @functools.cache
  def _get_model(model_name: str) -> 'SentenceTransformer':
    """Load the model, storing its downloaded weights under the local `.cache` folder."""
    return SentenceTransformer(
      model_name, device=preferred_device, cache_folder=os.path.join(data_path(), '.cache'))

  # Look up the batch size for the preferred device, falling back to the '' (default) key.
  batch_size = optimal_batch_sizes[preferred_device or '']
  return batch_size, _get_model(model_name)
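
# Example usage (a hedged sketch; the model name and batch-size mapping below are illustrative
# assumptions, not values defined by this module):
#
#   batch_size, model = get_model(
#     'sentence-transformers/all-MiniLM-L6-v2',
#     optimal_batch_sizes={'': 64, 'mps': 256})
#   embeddings = model.encode(['hello world'], batch_size=batch_size)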