# Hugging Face Spaces status: Sleeping
import re | |
import urllib | |
import json | |
from fastapi import FastAPI, HTTPException, Query | |
from fastapi.middleware.cors import CORSMiddleware | |
from transformers import AutoTokenizer, AutoModel | |
import torch | |
from torch import Tensor | |
import torch.nn.functional as F | |
import os | |
# NOTE(review): huggingface_hub resolves HF_HOME when it is first imported, and
# the `transformers` import earlier in this file has most likely already
# triggered that import, so this assignment probably does not change where
# model files are cached. Confirm before relying on it; it still affects
# anything that reads the variable later (e.g. subprocesses).
os.environ.update(HF_HOME='/')

app = FastAPI()

# Allow browsers on any origin to call this API, with any method and header.
# NOTE(review): a wildcard origin combined with allow_credentials=True causes
# Starlette to reflect the caller's Origin on credentialed requests; confirm
# that no endpoint depends on cookies or other ambient credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# Hugging Face checkpoint identifier for the E5 multilingual encoder.
model_name = "intfloat/multilingual-e5-large"

# Tokenizer first, then encoder weights; both are loaded once, when this
# module is imported.
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModel.from_pretrained(model_name),
)
def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """Mean-pool per-token vectors, ignoring masked-out (padding) positions.

    Args:
        last_hidden_states: Per-token hidden states, assumed to be shaped
            ``(batch, seq_len, hidden)``; the mask is broadcast over the last axis.
        attention_mask: ``(batch, seq_len)`` mask, expected to be 1 for real
            tokens and 0 for padding.

    Returns:
        A ``(batch, hidden)`` tensor: for each row, the sum of the unmasked
        token vectors divided by the number of unmasked tokens.
    """
    token_mask = attention_mask.unsqueeze(-1)  # (batch, seq_len, 1)
    # Zero out padding positions so they contribute nothing to the sum.
    masked_states = last_hidden_states.masked_fill(token_mask.bool().logical_not(), 0.0)
    token_counts = attention_mask.sum(dim=1).unsqueeze(-1)  # (batch, 1)
    return masked_states.sum(dim=1) / token_counts
def embed_single_text(text: "str | list[str]") -> Tensor:
    """Return L2-normalised E5 embeddings for ``text``.

    Args:
        text: A single string or a list of strings (the endpoint passes a
            one-element list). Inputs longer than 512 tokens are truncated;
            a list is padded to its longest item.

    Returns:
        A tensor with one unit-length (L2) embedding row per input string.

    NOTE(review): the E5 model card recommends prefixing inputs with
    "query: " or "passage: "; this function passes ``text`` through unchanged.
    """
    # Reuse the module-level tokenizer/model. Previously both were re-loaded
    # from disk and the model re-built on every call, which made every request
    # pay the full model-load cost.
    batch_dict = tokenizer(text, max_length=512, padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        outputs = model(**batch_dict)
    embedding = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
    # Normalise each row so that dot products equal cosine similarity.
    return F.normalize(embedding, p=2, dim=1)
@app.get("/e5_embeddings")
def e5_embeddings(query: str = Query(...)):
    """HTTP GET endpoint: embed the required ``query`` parameter with E5.

    Returns:
        ``result.tolist()`` for a one-element batch, i.e. a JSON list holding
        a single embedding row (list of floats).

    Raises:
        HTTPException: status 500 if computing the embedding fails.
    """
    # Registering the route is the fix here: without the decorator, FastAPI
    # never exposed this function, so the app served no endpoints.
    try:
        result = embed_single_text([query])
    except Exception as err:  # request boundary: log, then return a clean 500
        import logging

        logging.getLogger(__name__).exception("E5 embedding failed")
        raise HTTPException(status_code=500, detail="Failed to compute embedding") from err
    # embed_single_text always returns a tensor, so the old `is None` check
    # was dead code and has been removed.
    return result.tolist()