Spaces:
Sleeping
Sleeping
File size: 8,496 Bytes
32ded02 d4a72b0 32ded02 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
import asyncio
import hmac
import json
import logging
import multiprocessing
import os
from typing import Dict, List

import cloudpickle
import tensorflow as tf
import uvicorn
from fastapi import FastAPI, Request, Response
from fastapi.exceptions import RequestValidationError
from pydantic import BaseModel, Field

from news_classifier import predict_news_classes
from scraper import scrape_text
# Process-wide settings for the ML libraries: switch off the tokenizers
# library's parallelism and request the legacy Keras implementation.
# NOTE(review): these are assigned only after `import tensorflow` and
# `from news_classifier import ...` at the top of the file. TF_USE_LEGACY_KERAS
# is normally consulted when Keras is first loaded, so it may have no effect
# here. Consider moving these settings above the imports — TODO confirm.
os.environ.update(TOKENIZERS_PARALLELISM="false", TF_USE_LEGACY_KERAS="1")
def load_model(model_dir="classification_models"):
    """Load the calibrated classifier and its label encoder from disk.

    Args:
        model_dir: Directory holding ``label_encoder.bin``,
            ``calibrated_model.bin`` and ``model.tflite``. Defaults to
            ``"classification_models"``, resolved relative to the current
            working directory, as before.

    Returns:
        A ``(calibrated_model, label_encoder)`` tuple.
    """
    logging.warning('Entering load transformer')
    # cloudpickle.load can execute arbitrary code while unpickling; only point
    # model_dir at trusted, locally produced artifacts.
    with open(os.path.join(model_dir, "label_encoder.bin"), "rb") as model_file_obj:
        label_encoder = cloudpickle.load(model_file_obj)
    with open(os.path.join(model_dir, "calibrated_model.bin"), "rb") as model_file_obj:
        calibrated_model = cloudpickle.load(model_file_obj)
    # Point the unpickled wrapped estimator at the TFLite file in the same
    # directory. The attribute is overwritten after loading, so whatever path
    # was pickled is not used.
    calibrated_model.estimator.tflite_model_path = os.path.join(model_dir, "model.tflite")
    logging.warning('Exiting load transformer')
    return calibrated_model, label_encoder
async def scrape_urls(urls):
    """Scrape the text of every URL in parallel worker processes.

    Args:
        urls: List of article URLs.

    Returns:
        ``(scraped_texts, scrape_errors)``: two lists aligned with ``urls``,
        built from the ``(text, error)`` pair ``scrape_text`` returns per URL.

    Raises:
        multiprocessing.TimeoutError: if one URL's result is not ready within
            120 seconds. The worker pool is terminated before the error
            propagates.
    """
    logging.warning('Entering scrape_urls()')
    if not urls:
        # Nothing to scrape; this also avoids Pool(processes=0), which raises ValueError.
        logging.warning('Exiting scrape_urls()')
        return [], []
    loop = asyncio.get_running_loop()
    # More worker processes than URLs would only add start-up cost.
    pool = multiprocessing.Pool(processes=min(multiprocessing.cpu_count(), len(urls)))
    try:
        pending = [pool.apply_async(scrape_text, [url]) for url in urls]
        scraped_texts = []
        scrape_errors = []
        for async_result in pending:
            # AsyncResult.get() blocks. Waiting in the default thread executor
            # keeps the event loop free to serve other requests meanwhile.
            text, error = await loop.run_in_executor(None, async_result.get, 120)
            scraped_texts.append(text)
            scrape_errors.append(error)
    except BaseException:
        # Timeout, scraper exception or request cancellation: kill the workers
        # (a stuck one would otherwise linger) instead of leaking the pool.
        pool.terminate()
        pool.join()
        raise
    pool.close()
    pool.join()
    logging.warning('Exiting scrape_urls()')
    return scraped_texts, scrape_errors
# Markdown shown at the top of the generated API docs page.
description = (
    "API to classify news articles into categories from their URLs.\n\n"
    "Categories = ASTROLOGY, BUSINESS, EDUCATION, ENTERTAINMENT, HEALTH, NATION, SCIENCE, SPORTS, TECHNOLOGY, WEATHER, WORLD"
)

app = FastAPI(
    title='News Classifier API',
    description=description,
    version="0.0.1",
    contact={"name": "Author: KSV Muralidhar", "url": "https://ksvmuralidhar.in"},
    license_info={"name": "License: MIT", "identifier": "MIT"},
    # Swagger UI option: a depth of -1 hides the "Schemas" (models) section.
    swagger_ui_parameters={"defaultModelsExpandDepth": -1},
)
class URLList(BaseModel):
    # JSON request body of POST /classify/. Documented with comments rather
    # than a docstring because pydantic copies a model docstring into the
    # generated JSON/OpenAPI schema.
    urls: List[str] = Field(default=..., description="List of URLs of news articles to classify")  # required
    key: str = Field(default=..., description="Authentication Key")  # required; checked by authenticate_key()
class Categories(BaseModel):
    # One per-URL prediction entry in the /classify/ success body. Comments,
    # not a docstring, so the published schema is left unchanged.
    label: str = Field(default=..., description="category label")
    calibrated_prediction_proba: float = Field(
        default=...,
        description="calibrated prediction probability (confidence)",
    )
class SuccessfulResponse(BaseModel):
    # Documents the HTTP 200 body of POST /classify/. The endpoint serialises
    # its reply by hand, so this model only describes it in the API docs; the
    # field names must mirror the keys classify() emits. (Comments, not a
    # docstring, keep the published schema free of an extra description.)
    urls: List[str] = Field(..., description="List of URLs of news articles inputted by the user")
    scraped_texts: List[str] = Field(..., description="List of scraped text from input URLs")
    scrape_errors: List[str] = Field(..., description="List of errors raised during scraping. One item for corresponding URL")
    # classify() returns a list with one {label, calibrated_prediction_proba}
    # entry per URL under the key "categories". URLs that could not be
    # scraped get label "" and probability 0.
    categories: List[Categories] = Field(..., description="List of category labels of news articles along with calibrated prediction_proba. One item for corresponding URL")
    classifier_error: str = Field("", description="Empty string as the response code is 200")
class AuthenticationError(BaseModel):
    # Documents the HTTP 401 body of POST /classify/ (schema only; the
    # endpoint builds the JSON by hand). Field names mirror the emitted keys.
    urls: List[str] = Field(..., description="List of URLs of news articles inputted by the user")
    scraped_texts: str = Field("", description="Empty string as authentication failed")
    scrape_errors: str = Field("", description="Empty string as authentication failed")
    # The endpoint emits this key as "categories" (plural).
    categories: str = Field("", description="Empty string as authentication failed")
    classifier_error: str = Field("Error: Authentication error: Invalid API key.")
class ClassifierError(BaseModel):
    # Documents the HTTP 500 body of POST /classify/ (schema only; the
    # endpoint builds the JSON by hand). Field names mirror the emitted keys.
    # NOTE: if the failure happens before scraping completes, the endpoint
    # sends "" instead of a list for scraped_texts / scrape_errors.
    urls: List[str] = Field(..., description="List of URLs of news articles inputted by the user")
    scraped_texts: List[str] = Field(..., description="List of scraped text from input URLs")
    scrape_errors: List[str] = Field(..., description="List of errors raised during scraping. One item for corresponding URL")
    # The endpoint emits this key as "categories" (plural).
    categories: str = Field("", description="Empty string as classifier encountered an error")
    classifier_error: str = Field("Error: Classifier Error with a message describing the error")
class InputValidationError(BaseModel):
    # Documents the HTTP 422 body produced by the RequestValidationError
    # handler (schema only). Field names mirror the keys the handler emits.
    urls: List[str] = Field(..., description="List of URLs of news articles inputted by the user")
    scraped_texts: str = Field("", description="Empty string as validation failed")
    scrape_errors: str = Field("", description="Empty string as validation failed")
    # The handler emits this key as "categories" (plural).
    categories: str = Field("", description="Empty string as validation failed")
    classifier_error: str = Field("Validation Error with a message describing the error")
class NewsClassifierAPIAuthenticationError(Exception):
    """Raised when the API key sent by the client is rejected.

    classify() maps this exception to an HTTP 401 response.
    """
class NewsClassifierAPIScrapingError(Exception):
    """Raised when none of the requested URLs yielded any scraped text.

    classify() reports it, like every non-authentication failure, as HTTP 500.
    """
def authenticate_key(api_key: str):
    """Check a client-supplied API key against the ``API_KEY`` environment variable.

    Returns None on success.

    Raises:
        NewsClassifierAPIAuthenticationError: if ``API_KEY`` is unset, the key
            is not a string, or the key does not match.
    """
    expected_key = os.getenv('API_KEY')
    # hmac.compare_digest takes time independent of where the inputs differ,
    # so response timing does not reveal how much of a guessed key was right.
    # Comparing UTF-8 bytes avoids the TypeError compare_digest raises for
    # non-ASCII str arguments.
    if (
        expected_key is None
        or not isinstance(api_key, str)
        or not hmac.compare_digest(api_key.encode("utf-8"), expected_key.encode("utf-8"))
    ):
        raise NewsClassifierAPIAuthenticationError("Authentication error: Invalid API key.")
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Return request-validation failures (HTTP 422) in the API's JSON error shape.

    Each validation error is rendered as ``"<last loc element>: <msg>"``, and
    the messages are joined with ``"; "``.
    """
    # /classify/ is a POST whose URLs arrive in the JSON body, which
    # request.query_params never contains. FastAPI attaches the received body
    # to the exception as `exc.body`, so prefer it. Query params remain a
    # fallback. Only echo `urls` back when the client actually sent a list.
    body = getattr(exc, "body", None)
    if isinstance(body, dict) and isinstance(body.get("urls"), list):
        urls = body["urls"]
    else:
        urls = request.query_params.getlist("urls")
    error_messages = []
    for error in exc.errors():
        loc = error.get("loc") or ()
        # Report only the innermost location element (e.g. the field name);
        # an empty loc used to raise IndexError inside this handler.
        field = str(loc[-1]) if loc else ""
        error_messages.append(f"{field}: {error['msg']}")
    error_message = "; ".join(error_messages)
    response_json = {'urls': urls, 'scraped_texts': '', 'scrape_errors': '', 'categories': "",
                     'classifier_error': f'Validation Error: {error_message}'}
    json_str = json.dumps(response_json, indent=5)  # convert dict to JSON str
    return Response(content=json_str, media_type='application/json', status_code=422)
# Load the model artifacts once, at import time, so every request served by
# this process reuses the same objects (classify() reads these globals).
(calibrated_model, label_encoder) = load_model()
@app.post("/classify/", tags=["Classify"], response_model=SuccessfulResponse,
          responses={
              401: {"model": AuthenticationError, "description": "Authentication Error: Returned when the entered API key is incorrect"},
              500: {"model": ClassifierError, "description": "Classifier Error: Returned when the API couldn't classify even a single article"},
              422: {"model": InputValidationError, "description": "Validation Error: Returned when the payload data doesn't satisfy the data type requirements"}
          })
async def classify(q: URLList):
    """
    Get categories of news articles by passing the list of URLs as input.
    - **urls**: List of URLs (required)
    - **key**: Authentication key (required)
    """
    # Implementation notes are kept out of the docstring above, because FastAPI
    # publishes that docstring as the endpoint description.
    # - The reply is always a hand-built JSON Response, so `response_model` /
    #   `responses` only document it. The success body is ONE object, not a list.
    # - Status is 401 for a rejected key, 500 for any other failure, 200 otherwise.
    logging.warning("Entering classify()")
    # Values reported in the error body if a failure occurs before they are filled in.
    urls = ""
    scraped_texts = ""
    scrape_errors = ""
    try:
        urls = q.urls
        authenticate_key(q.key)
        scraped_texts, scrape_errors = await scrape_urls(urls)
        # Give up only if no URL produced any text. This also covers an empty
        # URL list, which previously crashed with an IndexError.
        if all(text == "" for text in scraped_texts):
            raise NewsClassifierAPIScrapingError("Scrape Error: Couldn't scrape text from any of the URLs")
        labels, probs = await predict_news_classes(urls, scraped_texts, calibrated_model, label_encoder)
        # URLs whose scrape produced no text get an empty label with probability 0.
        label_prob = [
            {"label": "", "calibrated_prediction_proba": 0}
            if text == "" else {"label": label, "calibrated_prediction_proba": prob}
            for label, prob, text in zip(labels, probs, scraped_texts)
        ]
        status_code = 200
        response_json = {'urls': urls, 'scraped_texts': scraped_texts, 'scrape_errors': scrape_errors,
                         'categories': label_prob, 'classifier_error': ''}
    except Exception as e:  # API boundary: every failure is reported as a JSON error body
        if isinstance(e, NewsClassifierAPIAuthenticationError):
            status_code = 401
        else:
            status_code = 500
            # Keep the traceback server-side; the client only sees str(e).
            logging.exception("classify() failed")
        response_json = {'urls': urls, 'scraped_texts': scraped_texts, 'scrape_errors': scrape_errors,
                         'categories': "", 'classifier_error': f'Error: {e}'}
    json_str = json.dumps(response_json, indent=5)  # convert dict to JSON str
    return Response(content=json_str, media_type='application/json', status_code=status_code)
if __name__ == '__main__':
    # uvicorn supports `workers > 1` only when the app is given as a
    # "module:attribute" import string (passed an app object, it logs a
    # warning and exits), because each worker process imports the app itself.
    # The module name is derived from this file so it stays correct if the
    # file is renamed.
    # NOTE(review): every worker therefore imports this module and runs the
    # module-level load_model() call again; the supervising process already
    # ran it once.
    _app_module = os.path.splitext(os.path.basename(__file__))[0]
    uvicorn.run(app=f"{_app_module}:app", host='0.0.0.0', port=7860, workers=3)
|