from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer, AutoModel
import torch
from torch import Tensor
import torch.nn.functional as F
import os
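# Use the filesystem root as the Hugging Face home directory (where model weights are cached)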
os.environ['HF_HOME'] = '/'
app = FastAPI()
# Enable CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
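# Load the multilingual E5 tokenizer and model once at startup so every request reuses them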
model_name = "intfloat/multilingual-e5-large"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    # Mean-pool token embeddings, zeroing out padding positions via the attention mask
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
def embed_single_text(text: str) -> Tensor:
    # Reuse the module-level tokenizer and model instead of reloading them on every call.
    # Note: the E5 model card recommends prefixing inputs with "query: " or "passage: ".
    batch_dict = tokenizer(text, max_length=512, padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():
        outputs = model(**batch_dict)
    embedding = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
    embedding = F.normalize(embedding, p=2, dim=1)
    return embedding
@app.get("/e5_embeddings")
def e5_embeddings(query: str = Query(...)):
    # Embed the query text and return the normalized vector as a JSON-serializable list
    result = embed_single_text(query)
    if result is not None:
        return result.tolist()
    else:
        raise HTTPException(status_code=500, detail="Failed to compute embedding")
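# Example request (a sketch; assumes the app is served on port 7860, the Hugging Face Spaces default):
#   GET /e5_embeddings?query=hello%20world
# The response is a nested list of shape [1, 1024] containing the normalized embedding.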