"""FastAPI service that classifies news articles into categories from their URLs."""

import json
import logging
import multiprocessing
import os

# TF_USE_LEGACY_KERAS must be set before TensorFlow is imported to take
# effect, so both environment variables are set ahead of the imports below.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["TF_USE_LEGACY_KERAS"] = "1"

import cloudpickle
import tensorflow as tf
import uvicorn
from fastapi import FastAPI, Request, Response
from fastapi.exceptions import RequestValidationError
from pydantic import BaseModel, Field
from typing import List

from news_classifier import predict_news_classes
from scraper import scrape_text


def load_model():
    """Load the calibrated classifier and label encoder from the
    classification_models directory and point the underlying estimator
    at its TFLite model file."""
    logging.warning('Entering load_model()')
    with open("classification_models/label_encoder.bin", "rb") as model_file_obj:
        label_encoder = cloudpickle.load(model_file_obj)

    with open("classification_models/calibrated_model.bin", "rb") as model_file_obj:
        calibrated_model = cloudpickle.load(model_file_obj)

    tflite_model_path = os.path.join("classification_models", "model.tflite")
    calibrated_model.estimator.tflite_model_path = tflite_model_path
    logging.warning('Exiting load_model()')
    return calibrated_model, label_encoder


async def scrape_urls(urls):
    """Scrape the text of each URL in parallel, one process per CPU core.

    Returns two lists parallel to ``urls``: the scraped text and the
    scraping error (if any) for each URL.
    """
    logging.warning('Entering scrape_urls()')
    with multiprocessing.Pool(processes=multiprocessing.cpu_count()) as pool:
        results = [pool.apply_async(scrape_text, [url]) for url in urls]

        scraped_texts = []
        scrape_errors = []
        for f in results:
            # Give each scrape up to 120 seconds before failing the request.
            t, e = f.get(timeout=120)
            scraped_texts.append(t)
            scrape_errors.append(e)
    logging.warning('Exiting scrape_urls()')
    return scraped_texts, scrape_errors


description = '''API to classify news articles into categories from their URLs.\n
Categories = ASTROLOGY, BUSINESS, EDUCATION, ENTERTAINMENT, HEALTH, NATION, SCIENCE, SPORTS, TECHNOLOGY, WEATHER, WORLD'''

app = FastAPI(title='News Classifier API',
              description=description,
              version="0.0.1",
              contact={
                  "name": "Author: KSV Muralidhar",
                  "url": "https://ksvmuralidhar.in"
              },
              license_info={
                  "name": "License: MIT",
                  "identifier": "MIT"
              },
              swagger_ui_parameters={"defaultModelsExpandDepth": -1})


class URLList(BaseModel):
    urls: List[str] = Field(..., description="List of URLs of news articles to classify")
    key: str = Field(..., description="Authentication Key")


class Categories(BaseModel):
    label: str = Field(..., description="category label")
    calibrated_prediction_proba: float = Field(...,
                                               description="calibrated prediction probability (confidence)")


class SuccessfulResponse(BaseModel):
    urls: List[str] = Field(..., description="List of URLs of news articles input by the user")
    scraped_texts: List[str] = Field(..., description="List of texts scraped from the input URLs")
    scrape_errors: List[str] = Field(..., description="List of errors raised during scraping, one item per URL")
    categories: List[Categories] = Field(..., description="Category label and calibrated prediction probability of each news article")
    classifier_error: str = Field("", description="Empty string as the response code is 200")


class AuthenticationError(BaseModel):
    urls: List[str] = Field(..., description="List of URLs of news articles input by the user")
    scraped_texts: str = Field("", description="Empty string as authentication failed")
    scrape_errors: str = Field("", description="Empty string as authentication failed")
    categories: str = Field("", description="Empty string as authentication failed")
    classifier_error: str = Field("Error: Authentication error: Invalid API key.")


class ClassifierError(BaseModel):
    urls: List[str] = Field(..., description="List of URLs of news articles input by the user")
    scraped_texts: List[str] = Field(..., description="List of texts scraped from the input URLs")
    scrape_errors: List[str] = Field(..., description="List of errors raised during scraping, one item per URL")
    categories: str = Field("", description="Empty string as the classifier encountered an error")
    classifier_error: str = Field("Error: Classifier Error with a message describing the error")


class InputValidationError(BaseModel):
    urls: List[str] = Field(..., description="List of URLs of news articles input by the user")
    scraped_texts: str = Field("", description="Empty string as validation failed")
    scrape_errors: str = Field("", description="Empty string as validation failed")
    categories: str = Field("", description="Empty string as validation failed")
    classifier_error: str = Field("Validation Error with a message describing the error")


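# A successful response has the following shape (values are illustrative):
# {
#      "urls": ["https://example.com/article"],
#      "scraped_texts": ["..."],
#      "scrape_errors": [""],
#      "categories": [{"label": "SPORTS", "calibrated_prediction_proba": 0.93}],
#      "classifier_error": ""
# }

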
class NewsClassifierAPIAuthenticationError(Exception):
    pass


class NewsClassifierAPIScrapingError(Exception):
    pass


def authenticate_key(api_key: str):
    """Raise an authentication error unless api_key matches the API_KEY environment variable."""
    if api_key != os.getenv('API_KEY'):
        raise NewsClassifierAPIAuthenticationError("Authentication error: Invalid API key.")


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Return request-validation failures in the same JSON schema as the other responses."""
    urls = request.query_params.getlist("urls")
    error_messages = []
    for error in exc.errors():
        loc = str(error['loc'][-1])  # name of the offending field
        msg = error['msg']
        error_messages.append(f"{loc}: {msg}")
    error_message = "; ".join(error_messages) if error_messages else ""
    response_json = {'urls': urls, 'scraped_texts': '', 'scrape_errors': '',
                     'categories': "", 'classifier_error': f'Validation Error: {error_message}'}
    json_str = json.dumps(response_json, indent=5)
    return Response(content=json_str, media_type='application/json', status_code=422)


# Load the models once at startup so every request reuses them.
calibrated_model, label_encoder = load_model()


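# Request flow for /classify/: FastAPI validates the payload against URLList,
# the key is authenticated, the URLs are scraped in parallel, and the scraped
# texts are classified into category labels with calibrated probabilities.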
@app.post("/classify/", tags=["Classify"], response_model=SuccessfulResponse,
          responses={
              401: {"model": AuthenticationError, "description": "Authentication Error: Returned when the entered API key is incorrect"},
              500: {"model": ClassifierError, "description": "Classifier Error: Returned when the API couldn't classify even a single article"},
              422: {"model": InputValidationError, "description": "Validation Error: Returned when the payload data doesn't satisfy the data type requirements"}
          })
async def classify(q: URLList):
    """
    Get categories of news articles by passing the list of URLs as input.

    - **urls**: List of URLs (required)
    - **key**: Authentication key (required)
    """
    try:
        logging.warning("Entering classify()")
        urls = q.urls
        api_key = q.key
        # Defaults returned if an error occurs before scraping completes.
        scraped_texts = ""
        scrape_errors = ""
        authenticate_key(api_key)
        scraped_texts, scrape_errors = await scrape_urls(urls)

        # If every scraped text is empty, there is nothing to classify.
        if set(scraped_texts) == {""}:
            raise NewsClassifierAPIScrapingError("Scrape Error: Couldn't scrape text from any of the URLs")

        labels, probs = await predict_news_classes(urls, scraped_texts, calibrated_model, label_encoder)
        # URLs whose text couldn't be scraped get an empty label and zero probability.
        label_prob = [{"label": "", "calibrated_prediction_proba": 0}
                      if t == "" else {"label": l, "calibrated_prediction_proba": p}
                      for l, p, t in zip(labels, probs, scraped_texts)]
        status_code = 200
        response_json = {'urls': urls, 'scraped_texts': scraped_texts, 'scrape_errors': scrape_errors,
                         'categories': label_prob, 'classifier_error': ''}
    except Exception as e:
        status_code = 401 if isinstance(e, NewsClassifierAPIAuthenticationError) else 500
        response_json = {'urls': urls, 'scraped_texts': scraped_texts, 'scrape_errors': scrape_errors,
                         'categories': "", 'classifier_error': f'Error: {e}'}

    json_str = json.dumps(response_json, indent=5)
    return Response(content=json_str, media_type='application/json', status_code=status_code)


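# Example request (illustrative; assumes the server is running locally on
# port 7860 and <API_KEY> matches the server's API_KEY environment variable):
#
#   curl -X POST http://localhost:7860/classify/ \
#        -H "Content-Type: application/json" \
#        -d '{"urls": ["https://example.com/news-article"], "key": "<API_KEY>"}'
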
if __name__ == '__main__':
    # uvicorn needs an import string (e.g. "module:app") rather than an app
    # object to run multiple workers, so a single worker is used here.
    uvicorn.run(app=app, host='0.0.0.0', port=7860)