E-slam commited on
Commit
ded6a94
·
verified ·
1 Parent(s): 421ec20

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +35 -4
main.py CHANGED
@@ -3,6 +3,13 @@ import urllib
3
  import json
4
  from fastapi import FastAPI, HTTPException, Query
5
  from fastapi.middleware.cors import CORSMiddleware
 
 
 
 
 
 
 
6
 
7
  app = FastAPI()
8
 
@@ -15,12 +22,36 @@ app.add_middleware(
15
  allow_headers=["*"],
16
  )
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
 
20
  @app.get("/e5_embeddings")
21
  def e5_embeddings(query: str = Query(...)):
22
- if query:
23
- return query
 
 
 
24
  else:
25
- return "Hello From HF"
26
-
 
3
  import json
4
  from fastapi import FastAPI, HTTPException, Query
5
  from fastapi.middleware.cors import CORSMiddleware
6
+ from transformers import AutoTokenizer, AutoModel
7
+ import torch
8
+ from torch import Tensor
9
+ import torch.nn.functional as F
10
+
11
+ import os
12
+ os.environ['HF_HOME'] = '/'
13
 
14
  app = FastAPI()
15
 
 
22
  allow_headers=["*"],
23
  )
24
 
25
+ model_name = "intfloat/multilingual-e5-large"
26
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
27
+ model = AutoModel.from_pretrained(model_name)
28
+
29
+ def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
30
+ last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
31
+ return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
32
+
33
+ def embed_single_text(text: str) -> Tensor:
34
+ tokenizer = AutoTokenizer.from_pretrained('intfloat/multilingual-e5-large')
35
+ model = AutoModel.from_pretrained('intfloat/multilingual-e5-large').cpu()
36
+
37
+ batch_dict = tokenizer(text, max_length=512, padding=True, truncation=True, return_tensors='pt')
38
+
39
+ with torch.no_grad():
40
+ outputs = model(**batch_dict)
41
+
42
+ embedding = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
43
+
44
+ embedding = F.normalize(embedding, p=2, dim=1)
45
+
46
+ return embedding
47
 
48
 
49
  @app.get("/e5_embeddings")
50
  def e5_embeddings(query: str = Query(...)):
51
+
52
+ result = embed_single_text(query)
53
+
54
+ if result:
55
+ return json.loads(result)
56
  else:
57
+ raise HTTPException(status_code=500)