Embedding: Memory leak on MPS backend

#109 opened by aliostad

Hi,

I have run your model both through Hugging Face Transformers (loading with from_pretrained) and through FlagEmbedding. After embedding ~3,000 sentences, the process is using more than 20 GB of RAM, and it eventually dies with an out-of-memory (OOM) error. This is on the MPS backend on my Mac.

I have tried various workarounds (gc.collect() as well as torch.mps.empty_cache()); they help slightly, but memory usage keeps going up.
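A minimal sketch for putting numbers on that growth, assuming a hypothetical `embed_one(text)` callable that wraps whichever embedding code is under test (for example `get_embedding` from the code below). It samples the process's peak resident set size through the standard-library `resource` module and, if a recent PyTorch with MPS is importable, the MPS allocator counters `torch.mps.current_allocated_memory()` and `torch.mps.driver_allocated_memory()`. Note that `ru_maxrss` is a high-water mark, so it never decreases, and it is reported in bytes on macOS but in kilobytes on Linux.

```python
import gc
import resource
import sys
from typing import Callable, Dict, List, Sequence

try:
    import torch  # optional: only used for MPS allocator statistics
except ImportError:
    torch = None


def peak_rss_mb() -> float:
    """Peak resident set size of this process so far, in MB."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # macOS reports ru_maxrss in bytes; Linux reports it in kilobytes.
    divisor = 1024 * 1024 if sys.platform == "darwin" else 1024
    return peak / divisor


def mps_stats_mb() -> Dict[str, float]:
    """MPS allocator statistics in MB, or an empty dict if MPS is unavailable."""
    if torch is None or not torch.backends.mps.is_available():
        return {}
    return {
        "mps_current": torch.mps.current_allocated_memory() / 2**20,
        "mps_driver": torch.mps.driver_allocated_memory() / 2**20,
    }


def track_memory(embed_one: Callable[[str], object], texts: Sequence[str],
                 every: int = 100) -> List[Dict[str, float]]:
    """Call embed_one on each text, sampling memory every `every` calls."""
    samples = []
    for i, text in enumerate(texts, start=1):
        embed_one(text)
        if i % every == 0 or i == len(texts):
            gc.collect()  # collect first so the sample reflects live objects only
            sample = {"step": i, "peak_rss": peak_rss_mb()}
            sample.update(mps_stats_mb())
            samples.append(sample)
    return samples
```

Usage: `samples = track_memory(get_embedding, corpus_texts)`, where `corpus_texts` is a placeholder for the list of sentences being embedded; printing the samples shows how memory evolves over the run.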

I have not seen this with any other model I have run locally, and I see other people have reported it too. Do you have any insight into it, or any advice? I have heard that the MPS backend used to leak memory but that those leaks were fixed, and I have updated torch to the latest version.

Many thanks in advance
Ali

The code:

from transformers import AutoModel, AutoTokenizer
import torch
import os
import gc

os.environ['TOKENIZERS_PARALLELISM'] = 'false'  # disable the tokenizers library's internal parallelism

device = torch.device('mps')

# Load the BGE-M3 model and tokenizer
model_name = "BAAI/bge-m3"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).to(device)

def get_embedding(text):
    # Tokenize the input text and move the tensors to the MPS device
    encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=8192)
    inputs = encoded.to(device)

    # Forward pass without building an autograd graph
    with torch.no_grad():
        outputs = model(**inputs)

    # Mean pooling over the token dimension, then copy the result to the CPU
    embeddings = outputs.last_hidden_state.mean(dim=1)
    ret = embeddings.cpu().numpy()

    # Drop every reference to MPS tensors and try to release cached memory
    del embeddings, inputs, outputs, encoded
    gc.collect()  # Force garbage collection
    torch.mps.synchronize()  # wait for queued GPU work to finish first
    torch.mps.empty_cache()  # return cached MPS allocations
    return ret

# Example Farsi text ("This is a test sentence for the BGE-M3 model, which can process long texts.")
text = "این یک جمله تستی برای مدل BGE-M3 است که می‌تواند متون بلند را پردازش کند."
embedding = get_embedding(text)

print("Embedding shape:", embedding.shape)
print("Embedding vector:", embedding)

print("MPS driver-allocated memory (bytes):", torch.mps.driver_allocated_memory())
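The `get_embedding` function above tokenizes one text per call with `padding=True`, so each call's input shape follows that sentence's own token count, and a corpus of ~3,000 sentences produces many distinct shapes. Whether the number of distinct shapes plays a role in the MPS memory growth is an assumption worth testing, not something established in this thread. A minimal sketch of one way to limit it: sort texts by token count, batch them, and round each batch's padded length up to a fixed multiple. The helper `bucketed_batches` is hypothetical and works on precomputed token counts.

```python
from typing import List, Sequence, Tuple


def bucketed_batches(lengths: Sequence[int], batch_size: int,
                     multiple: int) -> List[Tuple[List[int], int]]:
    """Group item indices into batches of similar length.

    Returns (indices, padded_length) pairs, where padded_length is the
    longest item in the batch rounded up to a multiple of `multiple`.
    """
    # Sort indices by length so each batch holds similarly sized texts.
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    batches = []
    for start in range(0, len(order), batch_size):
        idx = order[start:start + batch_size]
        longest = max(lengths[i] for i in idx)
        padded = -(-longest // multiple) * multiple  # ceiling to a multiple
        batches.append((idx, padded))
    return batches
```

With a Hugging Face tokenizer, passing a batch's texts together with `padding=True` and `pad_to_multiple_of=64` applies the same rounding; combined with `max_length=8192`, that allows at most 128 distinct sequence lengths.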

And this one is right out of your docs:

from FlagEmbedding import BGEM3FlagModel

model = BGEM3FlagModel('BAAI/bge-m3',  
                       use_fp16=True) # Setting use_fp16 to True speeds up computation with a slight performance degradation

sentences_1 = ["What is BGE M3?", "Definition of BM25"]
sentences_2 = ["BGE M3 is an embedding model supporting dense retrieval, lexical matching and multi-vector interaction.", 
               "BM25 is a bag-of-words retrieval function that ranks a set of documents based on the query terms appearing in each document"]

embeddings_1 = model.encode(sentences_1, 
                            batch_size=12, 
                            max_length=8192, # If you don't need such a long length, you can set a smaller value to speed up the encoding process.
                            )['dense_vecs']
embeddings_2 = model.encode(sentences_2)['dense_vecs']

I then run the embedding code in a loop over my corpus, and memory usage grows steadily until the process dies.
