Spaces:
Running
Running
Update main.py
Browse files
main.py
CHANGED
@@ -3,6 +3,13 @@ import urllib
|
|
3 |
import json
|
4 |
from fastapi import FastAPI, HTTPException, Query
|
5 |
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
app = FastAPI()
|
8 |
|
@@ -15,12 +22,36 @@ app.add_middleware(
|
|
15 |
allow_headers=["*"],
|
16 |
)
|
17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
|
20 |
@app.get("/e5_embeddings")
|
21 |
def e5_embeddings(query: str = Query(...)):
|
22 |
-
|
23 |
-
|
|
|
|
|
|
|
24 |
else:
|
25 |
-
|
26 |
-
|
|
|
3 |
import json
|
4 |
from fastapi import FastAPI, HTTPException, Query
|
5 |
from fastapi.middleware.cors import CORSMiddleware
|
6 |
+
from transformers import AutoTokenizer, AutoModel
|
7 |
+
import torch
|
8 |
+
from torch import Tensor
|
9 |
+
import torch.nn.functional as F
|
10 |
+
|
11 |
+
import os
|
12 |
+
os.environ['HF_HOME'] = '/'
|
13 |
|
14 |
app = FastAPI()
|
15 |
|
|
|
22 |
allow_headers=["*"],
|
23 |
)
|
24 |
|
25 |
+
model_name = "intfloat/multilingual-e5-large"
|
26 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
27 |
+
model = AutoModel.from_pretrained(model_name)
|
28 |
+
|
29 |
+
def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """Mean-pool token embeddings along dimension 1, ignoring masked positions.

    Args:
        last_hidden_states: Token-level hidden states; dimension 1 is the
            one that gets pooled away.
        attention_mask: Mask with one fewer dimension than
            ``last_hidden_states``; zero entries mark positions to ignore.

    Returns:
        The sum of the unmasked hidden states divided by the number of
        unmasked positions, per row. A row whose mask is entirely zero
        yields a zero vector instead of NaN.
    """
    keep = attention_mask[..., None].bool()
    # Zero out masked positions so they do not contribute to the sum.
    summed = last_hidden_states.masked_fill(~keep, 0.0).sum(dim=1)
    counts = attention_mask.sum(dim=1)[..., None]
    # Guard against 0/0 for fully-masked rows; other rows are unaffected.
    counts = counts.masked_fill(counts == 0, 1)
    return summed / counts
|
32 |
+
|
33 |
+
def embed_single_text(text: str) -> Tensor:
    """Embed one text with the module-level E5 tokenizer and model.

    The text is tokenized (truncated to 512 tokens), run through the model
    without gradient tracking, mean-pooled with ``average_pool`` and
    L2-normalized along dimension 1.

    Args:
        text: The text to embed.

    Returns:
        The normalized embedding tensor produced for ``text``.
    """
    # Reuse the tokenizer/model loaded once at module import instead of
    # calling ``from_pretrained`` on every request, which reloaded the full
    # model weights per call.
    # NOTE(review): the E5 model card recommends prefixing inputs with
    # "query: " or "passage: " -- confirm whether callers already do this.
    batch_dict = tokenizer(
        text,
        max_length=512,
        padding=True,
        truncation=True,
        return_tensors='pt',
    )

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**batch_dict)

    embedding = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
    return F.normalize(embedding, p=2, dim=1)
|
47 |
|
48 |
|
49 |
@app.get("/e5_embeddings")
def e5_embeddings(query: str = Query(...)):
    """Return the E5 embedding of ``query`` as a JSON-serializable list.

    Args:
        query: Required query-string parameter holding the text to embed.

    Returns:
        The first row of the embedding tensor as a plain list of floats.

    Raises:
        HTTPException: Status 500 when no embedding was produced.
    """
    result = embed_single_text(query)

    # A multi-element Tensor has no unambiguous truth value (``if result:``
    # raises RuntimeError), so check for a missing/empty result explicitly.
    if result is None or result.numel() == 0:
        raise HTTPException(status_code=500)

    # ``json.loads`` parses JSON text and cannot accept a Tensor; convert to
    # plain Python floats, which FastAPI serializes to JSON itself.
    # NOTE(review): returns a flat vector for the single query -- confirm
    # this matches the response shape clients expect.
    return result[0].tolist()
|
|