# embeding_api / main.py
# Author: Arafath10 — commit aaf0100 ("Update main.py"), 1.25 kB
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoModel, AutoTokenizer
import torch
# CPU device handle (not referenced elsewhere in the code visible here).
device = torch.device("cpu")

# The embedding model and its tokenizer are loaded once, at import time, from
# the same Hugging Face checkpoint; the endpoint below reuses these module-level
# objects for every request.
# trust_remote_code=True lets transformers execute modelling code shipped in the
# hub repository. NOTE(review): consider pinning `revision=` so that remote code
# cannot change underneath the service.
_EMBED_CHECKPOINT = "nomic-ai/nomic-embed-text-v1.5"
_load_options = {"trust_remote_code": True}

model = AutoModel.from_pretrained(_EMBED_CHECKPOINT, **_load_options)
tokenizer = AutoTokenizer.from_pretrained(_EMBED_CHECKPOINT, **_load_options)
app = FastAPI()

# Fully permissive CORS: any origin, method and header.
# NOTE(review): the CORS spec does not allow a wildcard origin together with
# credentials. Depending on the middleware, this either breaks credentialed
# browser requests or reflects every requesting Origin back, which would let
# any site make credentialed calls. That is harmless while the API uses no
# cookies/auth, but list explicit origins (or drop allow_credentials) before
# adding any.
_cors_settings = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_settings)
@app.post("/get_embeding")
def get_embeding(chunk: str):
    """Return a mean-pooled embedding for ``chunk``.

    ``chunk`` is read from the query string, just as with the original
    un-annotated parameter. The ``str`` annotation only adds validation and an
    accurate OpenAPI schema.

    The text is tokenized and passed through the module-level ``model`` with
    autograd disabled. The token vectors in ``last_hidden_state`` are then
    averaged over the token axis. The result is returned as a nested list with
    one row per input, so a single string yields ``[[...]]``; the response
    format is unchanged.

    This is deliberately a plain ``def`` rather than ``async def``. The forward
    pass is blocking, CPU-bound work. FastAPI runs sync endpoints in its worker
    threadpool, whereas inside ``async def`` the work would stall the event
    loop, and every other request, for the duration of each inference.

    NOTE(review): the nomic-embed-text-v1.5 model card recommends a task
    prefix (e.g. ``"search_document: "``) and normalising the pooled vector.
    This endpoint embeds the raw text with plain mean pooling; confirm that is
    intended.
    """
    inputs = tokenizer(chunk, return_tensors="pt")

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    # Per-token hidden states, (batch, seq_len, hidden) per the transformers
    # output convention; averaging over dim=1 collapses the tokens into a
    # single vector per input.
    token_embeddings = outputs.last_hidden_state
    sentence_embedding = torch.mean(token_embeddings, dim=1)

    return sentence_embedding.tolist()