lfqa1 / lfqa_server /main.py
Achyut Tiwari
Add files via upload
8c7ad44 unverified
raw
history blame
4.63 kB
import torch
from fastapi import FastAPI, Depends, status
from fastapi.responses import PlainTextResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import time
from typing import Dict, List, Optional
import jwt
from decouple import config
from fastapi import Request, HTTPException
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
JWT_SECRET = config("secret")
JWT_ALGORITHM = config("algorithm")
app = FastAPI()
app.ready = False
device = ("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained('vblagoje/bart_lfqa')
model = AutoModelForSeq2SeqLM.from_pretrained('vblagoje/bart_lfqa').to(device)
_ = model.eval()
class JWTBearer(HTTPBearer):
def __init__(self, auto_error: bool = True):
super(JWTBearer, self).__init__(auto_error=auto_error)
async def __call__(self, request: Request):
credentials: HTTPAuthorizationCredentials = await super(JWTBearer, self).__call__(request)
if credentials:
if not credentials.scheme == "Bearer":
raise HTTPException(status_code=403, detail="Invalid authentication scheme.")
if not self.verify_jwt(credentials.credentials):
raise HTTPException(status_code=403, detail="Invalid token or expired token.")
return credentials.credentials
else:
raise HTTPException(status_code=403, detail="Invalid authorization code.")
def verify_jwt(self, jwtoken: str) -> bool:
isTokenValid: bool = False
try:
payload = decodeJWT(jwtoken)
except:
payload = None
if payload:
isTokenValid = True
return isTokenValid
def token_response(token: str):
return {
"access_token": token
}
def signJWT(user_id: str) -> Dict[str, str]:
payload = {
"user_id": user_id,
"expires": time.time() + 6000
}
token = jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
return token_response(token)
def decodeJWT(token: str) -> dict:
try:
decoded_token = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
return decoded_token if decoded_token["expires"] >= time.time() else None
except:
return {}
class LFQAParameters(BaseModel):
min_length: int = 50
max_length: int = 250
do_sample: bool = False
early_stopping: bool = True
num_beams: int = 8
temperature: float = 1.0
top_k: float = None
top_p: float = None
no_repeat_ngram_size: int = 3
num_return_sequences: int = 1
class InferencePayload(BaseModel):
model_input: str
parameters: Optional[LFQAParameters] = LFQAParameters()
@app.on_event("startup")
def startup():
app.ready = True
@app.get("/healthz")
def healthz():
if app.ready:
return PlainTextResponse("ok")
return PlainTextResponse("service unavailable", status_code=status.HTTP_503_SERVICE_UNAVAILABLE)
@app.post("/generate/", dependencies=[Depends(JWTBearer())])
def generate(context: InferencePayload):
model_input = tokenizer(context.model_input, truncation=True, padding=True, return_tensors="pt")
param = context.parameters
generated_answers_encoded = model.generate(input_ids=model_input["input_ids"].to(device),
attention_mask=model_input["attention_mask"].to(device),
min_length=param.min_length,
max_length=param.max_length,
do_sample=param.do_sample,
early_stopping=param.early_stopping,
num_beams=param.num_beams,
temperature=param.temperature,
top_k=param.top_k,
top_p=param.top_p,
no_repeat_ngram_size=param.no_repeat_ngram_size,
num_return_sequences=param.num_return_sequences)
answers = tokenizer.batch_decode(generated_answers_encoded, skip_special_tokens=True,
clean_up_tokenization_spaces=True)
results = []
for answer in answers:
results.append({"generated_text": answer})
return results