ksvmuralidhar committed
Commit: 32ded02
Parent: 27dff41

Update api.py

Files changed (1)
  1. api.py +186 -185
api.py CHANGED
@@ -1,185 +1,186 @@
- import cloudpickle
- import os
- import tensorflow as tf
- from scraper import scrape_text
- from fastapi import FastAPI, Response, Request
- from typing import List, Dict
- from pydantic import BaseModel, Field
- from fastapi.exceptions import RequestValidationError
- import uvicorn
- import json
- import logging
- import multiprocessing
- from news_classifier import predict_news_classes
-
-
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
- os.environ["TF_USE_LEGACY_KERAS"] = "1"
-
-
- def load_model():
-     logging.warning('Entering load transformer')
-     with open("classification_models/label_encoder.bin", "rb") as model_file_obj:
-         label_encoder = cloudpickle.load(model_file_obj)
-
-     with open("classification_models/calibrated_model.bin", "rb") as model_file_obj:
-         calibrated_model = cloudpickle.load(model_file_obj)
-
-     tflite_model_path = os.path.join("classification_models", "model.tflite")
-     calibrated_model.estimator.tflite_model_path = tflite_model_path
-     logging.warning('Exiting load transformer')
-     return calibrated_model, label_encoder
-
-
- async def scrape_urls(urls):
-     logging.warning('Entering scrape_urls()')
-     pool = multiprocessing.Pool(processes=multiprocessing.cpu_count())
-
-     results = []
-     for url in urls:
-         f = pool.apply_async(scrape_text, [url])  # asynchronously scraping text
-         results.append(f)  # appending result to results
-
-     scraped_texts = []
-     scrape_errors = []
-     for f in results:
-         t, e = f.get(timeout=120)
-         scraped_texts.append(t)
-         scrape_errors.append(e)
-     pool.close()
-     pool.join()
-     logging.warning('Exiting scrape_urls()')
-     return scraped_texts, scrape_errors
-
-
- description = '''API to classify news articles into categories from their URLs.\n
- Categories = ASTROLOGY, BUSINESS, EDUCATION, ENTERTAINMENT, HEALTH, NATION, SCIENCE, SPORTS, TECHNOLOGY, WEATHER, WORLD'''
- app = FastAPI(title='News Classifier API',
-               description=description,
-               version="0.0.1",
-               contact={
-                   "name": "Author: KSV Muralidhar",
-                   "url": "https://ksvmuralidhar.in"
-               },
-               license_info={
-                   "name": "License: MIT",
-                   "identifier": "MIT"
-               },
-               swagger_ui_parameters={"defaultModelsExpandDepth": -1})
-
-
- class URLList(BaseModel):
-     urls: List[str] = Field(..., description="List of URLs of news articles to classify")
-     key: str = Field(..., description="Authentication Key")
-
- class Categories(BaseModel):
-     label: str = Field(..., description="category label")
-     calibrated_prediction_proba: float = Field(...,
-                                                description="calibrated prediction probability (confidence)")
-
- class SuccessfulResponse(BaseModel):
-     urls: List[str] = Field(..., description="List of URLs of news articles inputted by the user")
-     scraped_texts: List[str] = Field(..., description="List of scraped text from input URLs")
-     scrape_errors: List[str] = Field(..., description="List of errors raised during scraping. One item for corresponding URL")
-     category: Categories = Field(..., description="Dict of category label of news articles along with calibrated prediction_proba")
-     classifier_error: str = Field("", description="Empty string as the response code is 200")
-
- class AuthenticationError(BaseModel):
-     urls: List[str] = Field(..., description="List of URLs of news articles inputted by the user")
-     scraped_texts: str = Field("", description="Empty string as authentication failed")
-     scrape_errors: str = Field("", description="Empty string as authentication failed")
-     category: str = Field("", description="Empty string as authentication failed")
-     classifier_error: str = Field("Error: Authentication error: Invalid API key.")
-
- class ClassifierError(BaseModel):
-     urls: List[str] = Field(..., description="List of URLs of news articles inputted by the user")
-     scraped_texts: List[str] = Field(..., description="List of scraped text from input URLs")
-     scrape_errors: List[str] = Field(..., description="List of errors raised during scraping. One item for corresponding URL")
-     category: str = Field("", description="Empty string as classifier encountered an error")
-     classifier_error: str = Field("Error: Classifier Error with a message describing the error")
-
- class InputValidationError(BaseModel):
-     urls: List[str] = Field(..., description="List of URLs of news articles inputted by the user")
-     scraped_texts: str = Field("", description="Empty string as validation failed")
-     scrape_errors: str = Field("", description="Empty string as validation failed")
-     category: str = Field("", description="Empty string as validation failed")
-     classifier_error: str = Field("Validation Error with a message describing the error")
-
-
- class NewsClassifierAPIAuthenticationError(Exception):
-     pass
-
- class NewsClassifierAPIScrapingError(Exception):
-     pass
-
-
- def authenticate_key(api_key: str):
-     if api_key != os.getenv('API_KEY'):
-         raise NewsClassifierAPIAuthenticationError("Authentication error: Invalid API key.")
-
-
- @app.exception_handler(RequestValidationError)
- async def validation_exception_handler(request: Request, exc: RequestValidationError):
-     urls = request.query_params.getlist("urls")
-     error_details = exc.errors()
-     error_messages = []
-     for error in error_details:
-         loc = [*map(str, error['loc'])][-1]
-         msg = error['msg']
-         error_messages.append(f"{loc}: {msg}")
-     error_message = "; ".join(error_messages) if error_messages else ""
-     response_json = {'urls': urls, 'scraped_texts': '', 'scrape_errors': '', 'categories': "", 'classifier_error': f'Validation Error: {error_message}'}
-     json_str = json.dumps(response_json, indent=5)  # convert dict to JSON str
-     return Response(content=json_str, media_type='application/json', status_code=422)
-
-
- calibrated_model, label_encoder = load_model()
-
- @app.post("/classify/", tags=["Classify"], response_model=List[SuccessfulResponse],
-           responses={
-               401: {"model": AuthenticationError, "description": "Authentication Error: Returned when the entered API key is incorrect"},
-               500: {"model": ClassifierError, "description": "Classifier Error: Returned when the API couldn't classify even a single article"},
-               422: {"model": InputValidationError, "description": "Validation Error: Returned when the payload data doesn't satisfy the data type requirements"}
-           })
- async def classify(q: URLList):
-     """
-     Get categories of news articles by passing the list of URLs as input.
-     - **urls**: List of URLs (required)
-     - **key**: Authentication key (required)
-     """
-     try:
-         logging.warning("Entering classify()")
-         urls = ""
-         scraped_texts = ""
-         scrape_errors = ""
-         labels = ""
-         probs = 0
-         request_json = q.json()
-         request_json = json.loads(request_json)
-         urls = request_json['urls']
-         api_key = request_json['key']
-         _ = authenticate_key(api_key)
-         scraped_texts, scrape_errors = await scrape_urls(urls)
-
-         unique_scraped_texts = [*set(scraped_texts)]
-         if (unique_scraped_texts[0] == "") and (len(unique_scraped_texts) == 1):
-             raise NewsClassifierAPIScrapingError("Scrape Error: Couldn't scrape text from any of the URLs")
-
-         labels, probs = await predict_news_classes(urls, scraped_texts, calibrated_model, label_encoder)
-         label_prob = [{"label": "", "calibrated_prediction_proba": 0}
-                       if t == "" else {"label": l, "calibrated_prediction_proba": p}
-                       for l, p, t in zip(labels, probs, scraped_texts)]
-         status_code = 200
-         response_json = {'urls': urls, 'scraped_texts': scraped_texts, 'scrape_errors': scrape_errors, 'categories': label_prob, 'classifer_error': ''}
-     except Exception as e:
-         status_code = 500
-         if e.__class__.__name__ == "NewsClassifierAPIAuthenticationError":
-             status_code = 401
-         response_json = {'urls': urls, 'scraped_texts': scraped_texts, 'scrape_errors': scrape_errors, 'categories': "", 'classifier_error': f'Error: {e}'}
-
-     json_str = json.dumps(response_json, indent=5)  # convert dict to JSON str
-     return Response(content=json_str, media_type='application/json', status_code=status_code)
-
-
- if __name__ == '__main__':
-     uvicorn.run(app=app, host='0.0.0.0', port=7860, workers=3)
 
 
+ import cloudpickle
+ import os
+ import tensorflow as tf
+ from scraper import scrape_text
+ from fastapi import FastAPI, Response, Request
+ from typing import List, Dict
+ from pydantic import BaseModel, Field
+ from fastapi.exceptions import RequestValidationError
+ import uvicorn
+ import json
+ import logging
+ import multiprocessing
+ from news_classifier import predict_news_classes
+ from config import SCRAPER_MAX_RETRIES
+
+
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ os.environ["TF_USE_LEGACY_KERAS"] = "1"
+
+
+ def load_model():
+     logging.warning('Entering load transformer')
+     with open("classification_models/label_encoder.bin", "rb") as model_file_obj:
+         label_encoder = cloudpickle.load(model_file_obj)
+
+     with open("classification_models/calibrated_model.bin", "rb") as model_file_obj:
+         calibrated_model = cloudpickle.load(model_file_obj)
+
+     tflite_model_path = os.path.join("classification_models", "model.tflite")
+     calibrated_model.estimator.tflite_model_path = tflite_model_path
+     logging.warning('Exiting load transformer')
+     return calibrated_model, label_encoder
+
+
+ async def scrape_urls(urls):
+     logging.warning('Entering scrape_urls()')
+     pool = multiprocessing.Pool(processes=multiprocessing.cpu_count())
+
+     results = []
+     for url in urls:
+         f = pool.apply_async(scrape_text, [url, SCRAPER_MAX_RETRIES])  # asynchronously scraping text
+         results.append(f)  # appending result to results
+
+     scraped_texts = []
+     scrape_errors = []
+     for f in results:
+         t, e = f.get(timeout=120)
+         scraped_texts.append(t)
+         scrape_errors.append(e)
+     pool.close()
+     pool.join()
+     logging.warning('Exiting scrape_urls()')
+     return scraped_texts, scrape_errors
+
+
+ description = '''API to classify news articles into categories from their URLs.\n
+ Categories = ASTROLOGY, BUSINESS, EDUCATION, ENTERTAINMENT, HEALTH, NATION, SCIENCE, SPORTS, TECHNOLOGY, WEATHER, WORLD'''
+ app = FastAPI(title='News Classifier API',
+               description=description,
+               version="0.0.1",
+               contact={
+                   "name": "Author: KSV Muralidhar",
+                   "url": "https://ksvmuralidhar.in"
+               },
+               license_info={
+                   "name": "License: MIT",
+                   "identifier": "MIT"
+               },
+               swagger_ui_parameters={"defaultModelsExpandDepth": -1})
+
+
+ class URLList(BaseModel):
+     urls: List[str] = Field(..., description="List of URLs of news articles to classify")
+     key: str = Field(..., description="Authentication Key")
+
+ class Categories(BaseModel):
+     label: str = Field(..., description="category label")
+     calibrated_prediction_proba: float = Field(...,
+                                                description="calibrated prediction probability (confidence)")
+
+ class SuccessfulResponse(BaseModel):
+     urls: List[str] = Field(..., description="List of URLs of news articles inputted by the user")
+     scraped_texts: List[str] = Field(..., description="List of scraped text from input URLs")
+     scrape_errors: List[str] = Field(..., description="List of errors raised during scraping. One item for corresponding URL")
+     category: Categories = Field(..., description="Dict of category label of news articles along with calibrated prediction_proba")
+     classifier_error: str = Field("", description="Empty string as the response code is 200")
+
+ class AuthenticationError(BaseModel):
+     urls: List[str] = Field(..., description="List of URLs of news articles inputted by the user")
+     scraped_texts: str = Field("", description="Empty string as authentication failed")
+     scrape_errors: str = Field("", description="Empty string as authentication failed")
+     category: str = Field("", description="Empty string as authentication failed")
+     classifier_error: str = Field("Error: Authentication error: Invalid API key.")
+
+ class ClassifierError(BaseModel):
+     urls: List[str] = Field(..., description="List of URLs of news articles inputted by the user")
+     scraped_texts: List[str] = Field(..., description="List of scraped text from input URLs")
+     scrape_errors: List[str] = Field(..., description="List of errors raised during scraping. One item for corresponding URL")
+     category: str = Field("", description="Empty string as classifier encountered an error")
+     classifier_error: str = Field("Error: Classifier Error with a message describing the error")
+
+ class InputValidationError(BaseModel):
+     urls: List[str] = Field(..., description="List of URLs of news articles inputted by the user")
+     scraped_texts: str = Field("", description="Empty string as validation failed")
+     scrape_errors: str = Field("", description="Empty string as validation failed")
+     category: str = Field("", description="Empty string as validation failed")
+     classifier_error: str = Field("Validation Error with a message describing the error")
+
+
+ class NewsClassifierAPIAuthenticationError(Exception):
+     pass
+
+ class NewsClassifierAPIScrapingError(Exception):
+     pass
+
+
+ def authenticate_key(api_key: str):
+     if api_key != os.getenv('API_KEY'):
+         raise NewsClassifierAPIAuthenticationError("Authentication error: Invalid API key.")
+
+
+ @app.exception_handler(RequestValidationError)
+ async def validation_exception_handler(request: Request, exc: RequestValidationError):
+     urls = request.query_params.getlist("urls")
+     error_details = exc.errors()
+     error_messages = []
+     for error in error_details:
+         loc = [*map(str, error['loc'])][-1]
+         msg = error['msg']
+         error_messages.append(f"{loc}: {msg}")
+     error_message = "; ".join(error_messages) if error_messages else ""
+     response_json = {'urls': urls, 'scraped_texts': '', 'scrape_errors': '', 'categories': "", 'classifier_error': f'Validation Error: {error_message}'}
+     json_str = json.dumps(response_json, indent=5)  # convert dict to JSON str
+     return Response(content=json_str, media_type='application/json', status_code=422)
+
+
+ calibrated_model, label_encoder = load_model()
+
+ @app.post("/classify/", tags=["Classify"], response_model=List[SuccessfulResponse],
+           responses={
+               401: {"model": AuthenticationError, "description": "Authentication Error: Returned when the entered API key is incorrect"},
+               500: {"model": ClassifierError, "description": "Classifier Error: Returned when the API couldn't classify even a single article"},
+               422: {"model": InputValidationError, "description": "Validation Error: Returned when the payload data doesn't satisfy the data type requirements"}
+           })
+ async def classify(q: URLList):
+     """
+     Get categories of news articles by passing the list of URLs as input.
+     - **urls**: List of URLs (required)
+     - **key**: Authentication key (required)
+     """
+     try:
+         logging.warning("Entering classify()")
+         urls = ""
+         scraped_texts = ""
+         scrape_errors = ""
+         labels = ""
+         probs = 0
+         request_json = q.json()
+         request_json = json.loads(request_json)
+         urls = request_json['urls']
+         api_key = request_json['key']
+         _ = authenticate_key(api_key)
+         scraped_texts, scrape_errors = await scrape_urls(urls)
+
+         unique_scraped_texts = [*set(scraped_texts)]
+         if (len(unique_scraped_texts) == 1) and (unique_scraped_texts[0] == ""):  # length check first avoids IndexError on an empty URL list
+             raise NewsClassifierAPIScrapingError("Scrape Error: Couldn't scrape text from any of the URLs")
+
+         labels, probs = await predict_news_classes(urls, scraped_texts, calibrated_model, label_encoder)
+         label_prob = [{"label": "", "calibrated_prediction_proba": 0}
+                       if t == "" else {"label": l, "calibrated_prediction_proba": p}
+                       for l, p, t in zip(labels, probs, scraped_texts)]
+         status_code = 200
+         response_json = {'urls': urls, 'scraped_texts': scraped_texts, 'scrape_errors': scrape_errors, 'categories': label_prob, 'classifier_error': ''}
+     except Exception as e:
+         status_code = 500
+         if e.__class__.__name__ == "NewsClassifierAPIAuthenticationError":
+             status_code = 401
+         response_json = {'urls': urls, 'scraped_texts': scraped_texts, 'scrape_errors': scrape_errors, 'categories': "", 'classifier_error': f'Error: {e}'}
+
+     json_str = json.dumps(response_json, indent=5)  # convert dict to JSON str
+     return Response(content=json_str, media_type='application/json', status_code=status_code)
+
+
+ if __name__ == '__main__':
+     uvicorn.run(app=app, host='0.0.0.0', port=7860, workers=3)
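
The commit threads SCRAPER_MAX_RETRIES from config.py into scrape_text, but neither config.py nor scraper.py is part of this diff. The following is only a minimal sketch of what a retry-aware scrape_text(url, max_retries) might look like: the name, the two-argument call, and the (text, error) return shape are implied by api.py above, while the default value and the requests/BeautifulSoup fetching details are assumptions.

# Hypothetical sketch of scraper.scrape_text -- NOT part of this commit.
# Must be a top-level function so multiprocessing workers can pickle it.
import requests
from bs4 import BeautifulSoup

SCRAPER_MAX_RETRIES = 3  # assumed default; the real value lives in config.py


def scrape_text(url: str, max_retries: int = SCRAPER_MAX_RETRIES):
    error = ""
    for attempt in range(1, max_retries + 1):
        try:
            resp = requests.get(url, timeout=10)
            resp.raise_for_status()
            soup = BeautifulSoup(resp.text, "html.parser")
            # Join paragraph text; the real scraper likely extracts more carefully.
            text = " ".join(p.get_text(strip=True) for p in soup.find_all("p"))
            return text, ""
        except Exception as e:  # retry on any fetch/parse failure
            error = f"attempt {attempt}/{max_retries} failed: {e}"
    return "", error  # empty text signals failure, as classify() expects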
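
For reference, a minimal client sketch against the /classify/ endpoint defined above, assuming the service runs locally on port 7860 as in the uvicorn.run call; "<API_KEY>" is a placeholder that must match the server's API_KEY environment variable.

# Minimal client sketch for POST /classify/ (assumes a local deployment).
import requests

payload = {
    "urls": ["https://example.com/some-news-article"],  # hypothetical article URL
    "key": "<API_KEY>",
}
resp = requests.post("http://localhost:7860/classify/", json=payload)
print(resp.status_code)
print(resp.json())  # urls, scraped_texts, scrape_errors, categories, classifier_error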