Spaces:
Running
Running
import os
from enum import Enum
from typing import Optional

from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
# Load the embedding model once, at import time, pinned to the CPU.
# trust_remote_code=True permits running code shipped in the model repository
# (presumably required by this checkpoint — confirm against the model card).
# The config overrides presumably disable optimized attention / unpadding code
# paths that are not usable on CPU — TODO confirm against the model card.
model = SentenceTransformer(
    "dunzhang/stella_en_400M_v5",
    device="cpu",
    trust_remote_code=True,
    config_kwargs=dict(
        use_memory_efficient_attention=False,
        unpad_inputs=False,
    ),
)
class Enum(str, Enum): | |
s2p_query = "s2p_query" # sentence-to-sentence | |
s2s_query = "s2s_query" # sentence-to-passage, Q&A | |
class Embedding(BaseModel):
    """Request body for the embedding handler.

    Attributes:
        input: Texts to encode.
        embedding_type: Optional prompt name. When omitted or ``null``, the
            texts are encoded without a prompt.
    """

    input: list[str]
    # Optional[...] so an explicit JSON null validates. With a bare ``Enum``
    # annotation, pydantic v2 only tolerates None as the unvalidated default.
    embedding_type: Optional[Enum] = None
app = FastAPI()

# CORS policy: any origin may call the API, only POST is allowed as a method,
# and the Authorization request header is explicitly permitted.
# NOTE(review): a wildcard origin combined with allow_credentials=True lets any
# website issue credentialed requests to this service — confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["POST"],
    allow_headers=["Authorization"],
    allow_credentials=True,
)
def parse(data):
    """Round every number in a two-level nested iterable to 8 decimal places.

    Args:
        data: An iterable of rows, each row an iterable of numbers.

    Returns:
        A new list of lists, one inner list per row, holding ``round(x, 8)``
        for each number ``x`` of that row.
    """
    return [[round(value, 8) for value in row] for row in data]
async def get_embedding(embedding: Embedding, req: Request):
    """Encode ``embedding.input`` with the module-level ``model``.

    NOTE(review): no route decorator is visible on this function, so as shown
    it is not registered with ``app`` — confirm how it is exposed.

    Args:
        embedding: Validated request body. When ``embedding_type`` is set, its
            string value is passed to ``model.encode`` as ``prompt_name``.
        req: The incoming request. It is unused and kept for interface
            compatibility.

    Returns:
        The nested list produced by ``model.encode(...).tolist()`` with every
        value rounded to 8 decimal places by ``parse``; ``[]`` for empty input.

    Raises:
        HTTPException: 503 if the module-level ``model`` is ``None``.
    """
    import asyncio

    if model is None:
        # A missing model is a server-side condition, not a client error.
        raise HTTPException(status_code=503, detail="Model load failed.")
    if not embedding.input:
        # Nothing to encode; avoid calling the model with an empty batch.
        return []

    encode_kwargs = {}
    if embedding.embedding_type is not None:
        # Pass the plain string value rather than the Enum member itself.
        encode_kwargs["prompt_name"] = embedding.embedding_type.value

    # model.encode is synchronous and CPU-bound. Calling it directly inside an
    # ``async def`` blocks the event loop for every other request, so run it in
    # a worker thread instead.
    vectors = await asyncio.to_thread(model.encode, embedding.input, **encode_kwargs)
    return parse(vectors.tolist())