Embedding: Memory leak on MPS backend
Hi,
I have run your model both via Hugging Face (loading it with from_pretrained) and via FlagEmbedding. After embedding roughly 3000 sentences it uses more than 20 GB of my RAM and the process eventually dies with an OOM error. This is on the MPS backend on my Mac.
I have tried various tricks (gc.collect() as well as torch.mps.empty_cache()), which help slightly, but memory usage keeps climbing.
I have not seen this with any other model I run locally, and I can see that other people have complained about it too. Do you have any insight into it, or any advice? I know the MPS backend used to leak memory, but those leaks were reportedly fixed, and I have already updated torch to the latest release.
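For reference, this is roughly how I watch the memory between calls. It is just a sketch: report_mps_memory is a helper name I made up, and it assumes a torch build that exposes the torch.mps module.

import gc
import torch

def report_mps_memory(tag=""):
    # Bytes currently held by tensors on the MPS allocator
    allocated = torch.mps.current_allocated_memory()
    # Total bytes the Metal driver has allocated for this process
    driver = torch.mps.driver_allocated_memory()
    print(f"{tag} allocated={allocated / 2**20:.1f} MiB, driver={driver / 2**20:.1f} MiB")

# What I call between batches; the driver number keeps growing anyway
gc.collect()
torch.mps.empty_cache()
report_mps_memory("after cleanup:")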
Many thanks in advance
Ali
The code:
from transformers import AutoModel, AutoTokenizer
import torch
import os
import gc

os.environ['TOKENIZERS_PARALLELISM'] = '0'
device = torch.device('mps')

# Load the BGE-M3 model and tokenizer
model_name = "BAAI/bge-m3"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).to(device)

def get_embedding(text):
    # Tokenize input text
    inpu_ = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=8192)
    inputs = inpu_.to(device)
    # Get model output
    with torch.no_grad():
        outputs = model(**inputs)
    # Use mean pooling for embedding
    embeddings = outputs.last_hidden_state.mean(dim=1)
    cpu_embeds = embeddings.cpu()
    ret = cpu_embeds.numpy()
    # Drop every reference to the device tensors before cleanup
    del embeddings
    del inputs
    del outputs
    del cpu_embeds
    del inpu_
    gc.collect()  # Force garbage collection
    torch.mps.empty_cache()
    torch.mps.synchronize()
    return ret

# Example Farsi text ("This is a test sentence for the BGE-M3 model, which can process long texts.")
text = "این یک جمله تستی برای مدل BGE-M3 است که میتواند متون بلند را پردازش کند."
embedding = get_embedding(text)
print("Embedding shape:", embedding.shape)
print("Embedding vector:", embedding)
print("MPS driver allocated memory (bytes):", torch.mps.driver_allocated_memory())
And this one is right out of your docs:
from FlagEmbedding import BGEM3FlagModel

model = BGEM3FlagModel('BAAI/bge-m3',
                       use_fp16=True)  # Setting use_fp16 to True speeds up computation with a slight performance degradation

sentences_1 = ["What is BGE M3?", "Defination of BM25"]
sentences_2 = ["BGE M3 is an embedding model supporting dense retrieval, lexical matching and multi-vector interaction.",
               "BM25 is a bag-of-words retrieval function that ranks a set of documents based on the query terms appearing in each document"]

embeddings_1 = model.encode(sentences_1,
                            batch_size=12,
                            max_length=8192,  # If you don't need such a long length, you can set a smaller value to speed up the encoding process.
                            )['dense_vecs']
embeddings_2 = model.encode(sentences_2)['dense_vecs']
Then I run the embedding code in a loop over my corpus, and it gradually consumes more and more memory until the process dies.
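The driving loop itself is nothing special, roughly the sketch below; corpus here is just a placeholder for my real list of ~3000 Farsi sentences, and it reuses get_embedding from the first snippet.

corpus = ["..."] * 3000  # placeholder for my real Farsi sentences

all_embeddings = []
for i, sentence in enumerate(corpus):
    all_embeddings.append(get_embedding(sentence))
    if i % 100 == 0:
        # Driver-allocated memory keeps climbing despite the per-call cleanup
        print(i, "driver MiB:", round(torch.mps.driver_allocated_memory() / 2**20, 1))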