"""FastAPI service exposing keyphrase extraction over HTTP.

The models are loaded once at import time and shared by all requests.
"""

import uvicorn
import os
from typing import Union
from fastapi import FastAPI
from kpe import KPE
from fastapi.middleware.cors import CORSMiddleware
# from fastapi.middleware.trustedhost import TrustedHostMiddleware
from fastapi import APIRouter, Query
from sentence_transformers import SentenceTransformer
import utils
from ranker import get_sorted_keywords
from pydantic import BaseModel

app = FastAPI(
    title="AHD Persian KPE",
    # version=config.settings.VERSION,
    description="Keyphrase Extraction",
    openapi_url="/openapi.json",
    docs_url="/",
)

# The trained checkpoint is resolved relative to this file, so the service
# does not depend on the current working directory.
TRAINED_MODEL_ADDR = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    'trained_model',
    'trained_model_10000.pt',
)

# Both models are created once at import time, on CPU.
kpe = KPE(
    trained_kpe_model=TRAINED_MODEL_ADDR,
    flair_ner_model='flair/ner-english-ontonotes-large',
    device='cpu',
)
ranker_transformer = SentenceTransformer(
    'paraphrase-multilingual-mpnet-base-v2',
    device='cpu',
)

# Sets all CORS enabled origins
# NOTE(review): a wildcard origin combined with allow_credentials=True is
# discouraged by the FastAPI CORS documentation; consider listing explicit
# origins (e.g. str(origin) for origin in config.settings.BACKEND_CORS_ORIGINS).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class KpeParams(BaseModel):
    """Request body accepted by the ``/extract`` endpoint."""

    # Document to extract keyphrases from.
    text: str
    # Maximum number of keyphrases returned.
    count: int = 10000
    # Forwarded to ``kpe.extract`` as its ``using_ner`` argument.
    using_ner: bool = True
    # When true, keyphrases are passed through ``get_sorted_keywords``.
    return_sorted: bool = False


router = APIRouter()


@router.get("/")
def home():
    """Return a static welcome message.

    NOTE(review): the app is created with ``docs_url="/"``, which uses the
    same path as this route, so this handler may be shadowed by the docs
    page -- confirm which one is actually served.
    """
    return "Welcome to AHD Keyphrase Extraction Service"


@router.post("/extract", description="extract keyphrase from persian documents")
async def extract(kpe_params: KpeParams):
    """Extract keyphrases from the submitted text.

    The text is normalized, keyphrases are extracted, and the result is
    either ranked with ``get_sorted_keywords`` or paired with a constant
    score of ``1``. At most ``kpe_params.count`` entries are returned.

    Args:
        kpe_params: Request body; see ``KpeParams``.

    Returns:
        The (possibly truncated) keyphrase list. When ``return_sorted`` is
        false each entry is a ``(keyphrase, 1)`` tuple; otherwise the entries
        are whatever ``get_sorted_keywords`` returns.
    """
    # NOTE(review): the model calls below are plain synchronous calls inside
    # an ``async def`` handler, so they run on the event loop. Declaring the
    # handler with ``def`` would move it to FastAPI's threadpool, but only do
    # that after confirming the models are safe to call concurrently.
    text = utils.normalize(kpe_params.text)
    kps = kpe.extract(text, using_ner=kpe_params.using_ner)

    if kpe_params.return_sorted:
        kps = get_sorted_keywords(ranker_transformer, text, kps)
    else:
        kps = [(kp, 1) for kp in kps]

    # Clamp the limit at zero: a negative count would otherwise be treated
    # as a negative slice index and drop items from the end of the list
    # instead of limiting the result size.
    limit = max(kpe_params.count, 0)
    return kps[:limit]


app.include_router(router)


if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=7201)