# Spaces:
# Runtime error
# Runtime error
import os
from typing import Union

import uvicorn
from fastapi import APIRouter, FastAPI, Query
from fastapi.middleware.cors import CORSMiddleware
# from fastapi.middleware.trustedhost import TrustedHostMiddleware
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer

import utils
from kpe import KPE
from ranker import get_sorted_keywords
# FastAPI application object. The interactive API docs are served at the site
# root ("/") and the OpenAPI schema at "/openapi.json".
# TODO: restore ``version=`` once the settings module it used to come from
# (``config.settings.VERSION``) is available again.
app = FastAPI(
    title="AHD Persian KPE",
    description="Keyphrase Extraction",
    docs_url="/",
    openapi_url="/openapi.json",
)
# Model artifacts live in ``trained_model/`` next to this file.
_MODEL_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "trained_model")
TRAINED_MODEL_ADDR = os.path.join(_MODEL_DIR, "trained_model_10000.pt")

# Both models are constructed once, when this module is imported, on the CPU.
# NOTE(review): the NER model is an English OntoNotes tagger while the service
# title says "Persian" -- confirm this is intended.
kpe = KPE(
    trained_kpe_model=TRAINED_MODEL_ADDR,
    flair_ner_model="flair/ner-english-ontonotes-large",
    device="cpu",
)
# Multilingual sentence encoder used by ``get_sorted_keywords`` to rank results.
ranker_transformer = SentenceTransformer(
    "paraphrase-multilingual-mpnet-base-v2",
    device="cpu",
)
# Enable CORS for every origin, HTTP method and request header.
# NOTE(review): a wildcard origin combined with allow_credentials=True is not
# allowed by the CORS spec; depending on the Starlette version this is either
# ineffective for credentialed requests or causes the caller's Origin to be
# reflected back (letting any site send cookies/auth). No authentication is
# visible in this module, but switch to an explicit origin list (previously
# config.settings.BACKEND_CORS_ORIGINS) before adding any.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
class KpeParams(BaseModel):
    """Request body for the keyphrase-extraction endpoint."""

    # Raw input text; it is normalized with ``utils.normalize`` before extraction.
    text: str
    # Maximum number of keyphrases to return. Rejected (422) when negative:
    # a negative value would otherwise be used as a from-the-end slice bound
    # and silently drop the *last* results instead of limiting them.
    count: int = Field(10000, ge=0)
    # Forwarded as ``using_ner=`` to ``KPE.extract``.
    using_ner: bool = True
    # When true, results are re-ranked with ``get_sorted_keywords``; otherwise
    # every keyphrase is paired with a constant 1.
    return_sorted: bool = False
# Router for the service's endpoints. It is mounted on ``app`` by the
# ``app.include_router(router)`` call at the bottom of the module, and FastAPI
# copies routes at that moment, so any route must be added before it.
router = APIRouter()
def home():
    """Return the service's welcome message.

    NOTE(review): no ``@router.get(...)`` decorator is visible on this
    function, so as shown it is never registered on ``router`` -- confirm
    against the deployed source.
    """
    greeting = "Welcome to AHD Keyphrase Extraction Service"
    return greeting
async def extract(kpe_params: KpeParams):
    """Extract keyphrases from ``kpe_params.text``.

    The text is normalized with ``utils.normalize`` and passed to the
    module-level ``kpe`` extractor. If ``return_sorted`` is set, the keyphrases
    are re-ranked by ``get_sorted_keywords`` using ``ranker_transformer``;
    otherwise each keyphrase is paired with a constant 1.

    Args:
        kpe_params: Request body (text, result limit, NER flag, sort flag).

    Returns:
        At most ``max(kpe_params.count, 0)`` items. In the unsorted case each
        item is a ``(keyphrase, 1)`` tuple; in the sorted case the item shape is
        whatever ``get_sorted_keywords`` returns (presumably
        ``(keyphrase, score)`` -- confirm in ranker.py).

    NOTE(review): no ``@router.post(...)`` decorator is visible here, so as
    shown this handler is not registered -- confirm against deployed source.
    NOTE(review): model inference below is synchronous inside an ``async def``,
    so it runs on the event loop and stalls other requests while it executes.
    Declaring the handler with plain ``def`` would let FastAPI run it in a
    threadpool, but only once the models are confirmed safe for concurrent use.
    """
    # ``kpe`` is only read here, so no ``global`` declaration is needed.
    text = utils.normalize(kpe_params.text)
    kps = kpe.extract(text, using_ner=kpe_params.using_ner)
    if kpe_params.return_sorted:
        kps = get_sorted_keywords(ranker_transformer, text, kps)
    else:
        kps = [(kp, 1) for kp in kps]

    # Clamp at 0: a negative count would otherwise act as a from-the-end slice
    # bound and drop the *last* items instead of limiting the result.
    limit = max(kpe_params.count, 0)
    if len(kps) > limit:
        kps = kps[:limit]
    return kps
# Mount the router; FastAPI copies the router's routes at this call, so it must
# come after every route has been added to ``router``.
app.include_router(router)

if __name__ == "__main__":
    # Pass the app object rather than the "main:app" import string. When this
    # file is run as a script it executes as ``__main__``, so the string form
    # makes uvicorn import it a second time as module ``main``. That re-runs the
    # module top level (loading both models again), and it only works if the
    # file is literally named main.py. The import string is only required for
    # ``reload=True`` or ``workers > 1``, neither of which is used here.
    uvicorn.run(app, host="0.0.0.0", port=7201)