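# NOTE: the six lines below appear to be file-viewer page text captured along
# with the source (uploader, commit message, commit hash, viewer links, size);
# they are kept as comments so the module can be parsed.
# thecuong's picture
# First commit
# 4bb4208
# raw
# history blame
# 1.23 kB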
from typing import List, Literal
from pydantic import BaseModel, Field
from fastapi import FastAPI, APIRouter, Request
from fastapi.middleware.cors import CORSMiddleware
from sentence_transformers import SentenceTransformer
import uvicorn
# Create the application object that the middleware and router attach to.
app = FastAPI()

# Register CORS handling: any origin, any method and any header is accepted,
# and credentialed requests are enabled.
# NOTE(review): FastAPI's CORS documentation states that wildcard
# origins/methods/headers should not be combined with allow_credentials=True.
# Confirm whether credentialed cross-origin requests are actually needed;
# if so, list the allowed origins explicitly, otherwise disable credentials.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# Load the sentence-embedding model once, at import time; the embeddings
# endpoint in this module reuses this single instance for every request.
# NOTE(review): trust_remote_code=True permits code shipped with the model
# repository to run locally -- confirm the source is trusted and consider
# pinning a specific revision.
model = SentenceTransformer(
    'Alibaba-NLP/gte-multilingual-base',
    trust_remote_code=True,
)
# Request body schema for the embeddings endpoint.
#   type      -- one of 'default', 'disease' or 'gte'; 'default' is used when
#                the client omits the field.
#   sentences -- list of strings submitted for encoding.
# Documented with comments rather than a docstring so the schema description
# generated from this model is left unchanged.
class PostEmbeddings(BaseModel):
    type: Literal['default', 'disease', 'gte'] = 'default'
    sentences: List[str]
# Router grouping the retrieval endpoints: routes declared on it are served
# under the "/retrieval" prefix and carry the "retrieval" tag.
router = APIRouter(
    tags=["retrieval"],
    prefix="/retrieval",
)
# POST handler for '/embeddings' on the retrieval router.
# Encodes data.sentences with the module-level model and returns the encoder
# output, converted with .tolist(), as {"data": {"embeddings": ...}}.
# `request` and `data.type` are accepted but not read by this handler.
# Comments are used instead of a docstring, and no return annotation is added,
# because FastAPI would publish the former as the operation description and
# treat the latter as a response model.
@router.post('/embeddings')
def post_embeddings(request: Request, data: PostEmbeddings):
    vectors = model.encode(data.sentences)
    payload = {"embeddings": vectors.tolist()}
    return {"data": payload}
# Attach the retrieval router to the application so its routes are served.
# Kept after the route declarations above so those routes already exist on
# the router when it is included.
app.include_router(router)
# Entry point used when this file is run as a script.
def main(*, host: str = "0.0.0.0", port: int = 8000, reload: bool = True) -> None:
    """Start the uvicorn server for this application.

    All parameters are keyword-only and default to the values that were
    previously hard-coded, so a bare ``main()`` call behaves as before.

    Args:
        host: Address to bind; the default "0.0.0.0" listens on all
            interfaces.
        port: TCP port to listen on.
        reload: Forwarded to uvicorn's ``reload`` option.
    """
    # The application is passed as an import string rather than as the `app`
    # object. NOTE(review): "api:app" assumes this file is importable as the
    # module `api` -- confirm the filename matches.
    uvicorn.run("api:app", host=host, port=port, reload=reload)
# Start the server only when this file is executed directly; importing the
# module (for example to resolve the "api:app" import string) must not
# launch a second server.
if __name__ == "__main__":
    main()