|
from typing import List

import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from sentence_transformers import CrossEncoder, SentenceTransformer
from sentence_transformers.util import cos_sim
|
|
|
# ASGI application object; the routes in this module are registered on it
# through the ``@app.get`` / ``@app.post`` decorators.
app: FastAPI = FastAPI()
|
|
|
# Request body for POST /predict_list: two parallel lists of strings that are
# embedded and compared element by element.
class InputListModel(BaseModel):
    # Short query strings; encoded with the shared model.
    keywords: List[str]
    # Texts to score against ``keywords``; encoded with the shared model.
    contents: List[str]
|
|
|
# Request body for POST /predict: a single keyword/content pair to compare.
class InputModel(BaseModel):
    # Short query string; encoded with the shared model.
    keyword: str
    # Text to score against ``keyword``; encoded with the shared model.
    content: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Hugging Face Hub checkpoint used for all embeddings in this service.
MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"

# Shared embedding model used by the /predict and /predict_list handlers.
# It is created once, at module import time.
# ``trust_remote_code=True`` is intentionally not passed: it allows arbitrary
# Python code from the Hub repository to execute on load, and this standard
# sentence-transformers checkpoint does not ship or need custom model code.
model = SentenceTransformer(MODEL_NAME)
|
|
|
@app.get("/")
def greet_json():
    """Return a constant JSON greeting for ``GET /``."""
    payload = dict(Hello="World!")
    return payload
|
|
|
@app.post("/predict")
async def predict(inp: InputModel):
    """Score how semantically close ``inp.keyword`` is to ``inp.content``.

    Both strings are embedded with the module-level ``model``. Their cosine
    similarity is mapped from [-1, 1] to [0, 1] via ``(cos + 1) / 2``.

    Returns:
        ``{"results": score}``. Both inputs are single strings, so the
        similarity is a scalar and ``score`` is a single float.
    """
    # ``model.encode`` is synchronous, compute-heavy work. Calling it directly
    # inside an ``async def`` handler would block the event loop and stall all
    # other in-flight requests, so it is run in the worker thread pool.
    text_emb = await run_in_threadpool(
        model.encode, inp.content, convert_to_tensor=True
    )
    keyword_emb = await run_in_threadpool(
        model.encode, inp.keyword, convert_to_tensor=True
    )

    cos = torch.nn.functional.cosine_similarity(text_emb, keyword_emb, dim=-1)
    # Rescale cosine similarity from [-1, 1] to [0, 1].
    out = (cos + 1) / 2
    return {"results": out.tolist()}
|
|
|
|
|
@app.post("/predict_list")
async def predict_list(inp: InputListModel):
    """Score ``inp.keywords[i]`` against ``inp.contents[i]`` pairwise.

    Both lists are embedded with the module-level ``model`` and compared row
    by row with cosine similarity, rescaled from [-1, 1] to [0, 1] via
    ``(cos + 1) / 2``. A list with exactly one element is broadcast against
    the other list, so one keyword can be scored against many contents (and
    vice versa).

    Returns:
        ``{"results": scores}`` with one float per compared pair, or an empty
        list when either input list is empty.

    Raises:
        HTTPException: 422 when the two lists have different lengths and
            neither has exactly one element, since no pairing is defined.
    """
    n_contents, n_keywords = len(inp.contents), len(inp.keywords)
    if n_contents != n_keywords and 1 not in (n_contents, n_keywords):
        # Without this check torch raises a broadcasting error that reaches
        # the client as an opaque HTTP 500.
        raise HTTPException(
            status_code=422,
            detail=(
                "keywords and contents must have the same length "
                f"(got {n_keywords} keywords and {n_contents} contents)"
            ),
        )
    if n_contents == 0 or n_keywords == 0:
        # Nothing to compare. An empty batch does not produce a 2-D embedding
        # matrix, so the similarity would not come back as a list of scores.
        return {"results": []}

    # ``model.encode`` is blocking; run it in the thread pool so the event
    # loop keeps serving other requests while the batch is embedded.
    text_emb = await run_in_threadpool(
        model.encode, inp.contents, convert_to_tensor=True
    )
    keyword_emb = await run_in_threadpool(
        model.encode, inp.keywords, convert_to_tensor=True
    )

    cos = torch.nn.functional.cosine_similarity(text_emb, keyword_emb, dim=-1)
    # Rescale cosine similarity from [-1, 1] to [0, 1].
    out = (cos + 1) / 2
    return {"results": out.tolist()}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|