ksvmuralidhar commited on
Commit
83d8595
·
verified ·
1 Parent(s): bdb5934

Upload 10 files

Browse files
Dockerfile ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+ WORKDIR /code
3
+ COPY ./requirements.txt /code/requirements.txt
4
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
5
+ RUN apt update && apt install -y ffmpeg
6
+ RUN apt -y install wget
7
+ RUN apt -y install unzip
8
+
9
+ RUN apt-get install -y \
10
+ gnupg \
11
+ ca-certificates \
12
+ libglib2.0-0 \
13
+ libsm6 \
14
+ libxext6 \
15
+ libxrender-dev \
16
+ libfontconfig1 \
17
+ libnss3 \
18
+ libatk-bridge2.0-0 \
19
+ libatk1.0-0 \
20
+ libatspi2.0-0 \
21
+ libcups2 \
22
+ libcurl4 \
23
+ libgtk-3-0 \
24
+ libnspr4 \
25
+ libxcomposite1 \
26
+ libxdamage1 \
27
+ xdg-utils \
28
+ fonts-liberation \
29
+ libu2f-udev \
30
+ && rm -rf /var/lib/apt/lists/*
31
+
32
+ RUN wget https://dl.google.com/linux/direct/google-chrome-stable_current_amd64.deb && \
33
+ dpkg -i google-chrome-stable_current_amd64.deb && \
34
+ apt-get -f install -y && \
35
+ rm google-chrome-stable_current_amd64.deb
36
+
37
+ RUN useradd -m -u 1000 user
38
+ USER user
39
+ ENV HOME=/home/user \
40
+ PATH=/home/user/.local/bin:$PATH \
41
+ CHROMEDRIVERURL=https://storage.googleapis.com/chrome-for-testing-public/127.0.6533.119/linux64/chromedriver-linux64.zip \
42
+ CHROMEDRIVERFILENAME=chromedriver-linux64.zip
43
+
44
+
45
+ WORKDIR $HOME/app
46
+
47
+ COPY --chown=user . $HOME/app
48
+
49
+ RUN wget -P $HOME/app $CHROMEDRIVERURL
50
+ RUN unzip $HOME/app/$CHROMEDRIVERFILENAME
51
+ RUN rm $HOME/app/$CHROMEDRIVERFILENAME
52
+
53
+ RUN chmod +x $HOME/app/chromedriver-linux64/chromedriver
54
+
55
+ RUN ls -ltr
56
+
57
+ EXPOSE 7860
58
+ ENTRYPOINT ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "3"]
api.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cloudpickle
2
+ import os
3
+ import tensorflow as tf
4
+ from scraper import scrape_text
5
+ from fastapi import FastAPI, Response, Request
6
+ from typing import List, Dict
7
+ from pydantic import BaseModel, Field
8
+ from fastapi.exceptions import RequestValidationError
9
+ import uvicorn
10
+ import json
11
+ import logging
12
+ import multiprocessing
13
+ from news_classifier import predict_news_classes
14
+
15
+
16
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
17
+ os.environ["TF_USE_LEGACY_KERAS"] = "1"
18
+
19
+
20
+ def load_model():
21
+ logging.warning('Entering load transformer')
22
+ with open("classification_models/label_encoder.bin", "rb") as model_file_obj:
23
+ label_encoder = cloudpickle.load(model_file_obj)
24
+
25
+ with open("classification_models/calibrated_model.bin", "rb") as model_file_obj:
26
+ calibrated_model = cloudpickle.load(model_file_obj)
27
+
28
+ tflite_model_path = os.path.join("classification_models", "model.tflite")
29
+ calibrated_model.estimator.tflite_model_path = tflite_model_path
30
+ logging.warning('Exiting load transformer')
31
+ return calibrated_model, label_encoder
32
+
33
+
34
+ async def scrape_urls(urls):
35
+ logging.warning('Entering scrape_urls()')
36
+ pool = multiprocessing.Pool(processes=multiprocessing.cpu_count())
37
+
38
+ results = []
39
+ for url in urls:
40
+ f = pool.apply_async(scrape_text, [url]) # asynchronously scraping text
41
+ results.append(f) # appending result to results
42
+
43
+ scraped_texts = []
44
+ scrape_errors = []
45
+ for f in results:
46
+ t, e = f.get(timeout=120)
47
+ scraped_texts.append(t)
48
+ scrape_errors.append(e)
49
+ pool.close()
50
+ pool.join()
51
+ logging.warning('Exiting scrape_urls()')
52
+ return scraped_texts, scrape_errors
53
+
54
+
55
+ description = '''API to classify news articles into categories from their URLs.\n
56
+ Categories = ASTROLOGY, BUSINESS, EDUCATION, ENTERTAINMENT, HEALTH, NATION, SCIENCE, SPORTS, TECHNOLOGY, WEATHER, WORLD'''
57
+ app = FastAPI(title='News Classifier API',
58
+ description=description,
59
+ version="0.0.1",
60
+ contact={
61
+ "name": "Author: KSV Muralidhar",
62
+ "url": "https://ksvmuralidhar.in"
63
+ },
64
+ license_info={
65
+ "name": "License: MIT",
66
+ "identifier": "MIT"
67
+ },
68
+ swagger_ui_parameters={"defaultModelsExpandDepth": -1})
69
+
70
+
71
+ class URLList(BaseModel):
72
+ urls: List[str] = Field(..., description="List of URLs of news articles to classify")
73
+ key: str = Field(..., description="Authentication Key")
74
+
75
+ class Categories(BaseModel):
76
+ label: str = Field(..., description="category label")
77
+ calibrated_prediction_proba: float = Field(...,
78
+ description="calibrated prediction probability (confidence)")
79
+
80
+ class SuccessfulResponse(BaseModel):
81
+ urls: List[str] = Field(..., description="List of URLs of news articles inputted by the user")
82
+ scraped_texts: List[str] = Field(..., description="List of scraped text from input URLs")
83
+ scrape_errors: List[str] = Field(..., description="List of errors raised during scraping. One item for corresponding URL")
84
+ category: Categories = Field(..., description="Dict of category label of news articles along with calibrated prediction_proba")
85
+ classifier_error: str = Field("", description="Empty string as the response code is 200")
86
+
87
+ class AuthenticationError(BaseModel):
88
+ urls: List[str] = Field(..., description="List of URLs of news articles inputted by the user")
89
+ scraped_texts: str = Field("", description="Empty string as authentication failed")
90
+ scrape_errors: str = Field("", description="Empty string as authentication failed")
91
+ category: str = Field("", description="Empty string as authentication failed")
92
+ classifier_error: str = Field("Error: Authentication error: Invalid API key.")
93
+
94
+ class ClassifierError(BaseModel):
95
+ urls: List[str] = Field(..., description="List of URLs of news articles inputted by the user")
96
+ scraped_texts: List[str] = Field(..., description="List of scraped text from input URLs")
97
+ scrape_errors: List[str] = Field(..., description="List of errors raised during scraping. One item for corresponding URL")
98
+ category: str = Field("", description="Empty string as classifier encountered an error")
99
+ classifier_error: str = Field("Error: Classifier Error with a message describing the error")
100
+
101
+ class InputValidationError(BaseModel):
102
+ urls: List[str] = Field(..., description="List of URLs of news articles inputted by the user")
103
+ scraped_texts: str = Field("", description="Empty string as validation failed")
104
+ scrape_errors: str = Field("", description="Empty string as validation failed")
105
+ category: str = Field("", description="Empty string as validation failed")
106
+ classifier_error: str = Field("Validation Error with a message describing the error")
107
+
108
+
109
+ class NewsClassifierAPIAuthenticationError(Exception):
110
+ pass
111
+
112
+ class NewsClassifierAPIScrapingError(Exception):
113
+ pass
114
+
115
+
116
+ def authenticate_key(api_key: str):
117
+ if api_key != os.getenv('API_KEY'):
118
+ raise NewsClassifierAPIAuthenticationError("Authentication error: Invalid API key.")
119
+
120
+
121
+ @app.exception_handler(RequestValidationError)
122
+ async def validation_exception_handler(request: Request, exc: RequestValidationError):
123
+ urls = request.query_params.getlist("urls")
124
+ error_details = exc.errors()
125
+ error_messages = []
126
+ for error in error_details:
127
+ loc = [*map(str, error['loc'])][-1]
128
+ msg = error['msg']
129
+ error_messages.append(f"{loc}: {msg}")
130
+ error_message = "; ".join(error_messages) if error_messages else ""
131
+ response_json = {'urls': urls, 'scraped_texts': '', 'scrape_errors': '', 'categories': "", 'classifier_error': f'Validation Error: {error_message}'}
132
+ json_str = json.dumps(response_json, indent=5) # convert dict to JSON str
133
+ return Response(content=json_str, media_type='application/json', status_code=422)
134
+
135
+
136
+ calibrated_model, label_encoder = load_model()
137
+
138
+ @app.post("/classify/", tags=["Classify"], response_model=List[SuccessfulResponse],
139
+ responses={
140
+ 401: {"model": AuthenticationError, "description": "Authentication Error: Returned when the entered API key is incorrect"},
141
+ 500: {"model": ClassifierError, "description": "Classifier Error: Returned when the API couldn't classify even a single article"},
142
+ 422: {"model": InputValidationError, "description": "Validation Error: Returned when the payload data doesn't satisfy the data type requirements"}
143
+ })
144
+ async def classify(q: URLList):
145
+ """
146
+ Get categories of news articles by passing the list of URLs as input.
147
+ - **urls**: List of URLs (required)
148
+ - **key**: Authentication key (required)
149
+ """
150
+ try:
151
+ logging.warning("Entering classify()")
152
+ urls = ""
153
+ scraped_texts = ""
154
+ scrape_errors = ""
155
+ labels = ""
156
+ probs = 0
157
+ request_json = q.json()
158
+ request_json = json.loads(request_json)
159
+ urls = request_json['urls']
160
+ api_key = request_json['key']
161
+ _ = authenticate_key(api_key)
162
+ scraped_texts, scrape_errors = await scrape_urls(urls)
163
+
164
+ unique_scraped_texts = [*set(scraped_texts)]
165
+ if (unique_scraped_texts[0] == "") and (len(unique_scraped_texts) == 1):
166
+ raise NewsClassifierAPIScrapingError("Scrape Error: Couldn't scrape text from any of the URLs")
167
+
168
+ labels, probs = await predict_news_classes(urls, scraped_texts, calibrated_model, label_encoder)
169
+ label_prob = [{"label": "", "calibrated_prediction_proba": 0}
170
+ if t == "" else {"label": l, "calibrated_prediction_proba": p}
171
+ for l, p, t in zip(labels, probs, scraped_texts)]
172
+ status_code = 200
173
+ response_json = {'urls': urls, 'scraped_texts': scraped_texts, 'scrape_errors': scrape_errors, 'categories': label_prob, 'classifer_error': ''}
174
+ except Exception as e:
175
+ status_code = 500
176
+ if e.__class__.__name__ == "NewsClassifierAPIAuthenticationError":
177
+ status_code = 401
178
+ response_json = {'urls': urls, 'scraped_texts': scraped_texts, 'scrape_errors': scrape_errors, 'categories': "", 'classifier_error': f'Error: {e}'}
179
+
180
+ json_str = json.dumps(response_json, indent=5) # convert dict to JSON str
181
+ return Response(content=json_str, media_type='application/json', status_code=status_code)
182
+
183
+
184
+ if __name__ == '__main__':
185
+ uvicorn.run(app=app, host='0.0.0.0', port=7860, workers=3)
calibrated_classifier.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sklearn.dummy import DummyClassifier
2
+ from tqdm import tqdm
3
+ import multiprocessing
4
+ import numpy as np
5
+ import tensorflow as tf
6
+ from transformers import DistilBertTokenizerFast
7
+
8
+
9
+
10
+ class PredictProba(DummyClassifier):
11
+ def __init__(self, tflite_model_path: str, classes_: list, n_tokens: int):
12
+ self.classes_ = classes_ # required attribute for an estimator to be used in calibration classifier
13
+ self.n_tokens = n_tokens
14
+ self.tflite_model_path = tflite_model_path
15
+
16
+
17
+ def fit(self, x, y):
18
+ print('called fit')
19
+ return self # fit method is required for an estimator to be used in calibration classifier
20
+
21
+ @staticmethod
22
+ def get_token_batches(attention_mask, input_ids, batch_size: int=8):
23
+ n_texts = len(attention_mask)
24
+ n_batches = int(np.ceil(n_texts / batch_size))
25
+ if n_texts <= batch_size:
26
+ n_batches = 1
27
+
28
+ attention_mask_batches = []
29
+ input_ids_batches = []
30
+
31
+ for i in range(n_batches):
32
+ if i != n_batches-1:
33
+ attention_mask_batches.append(attention_mask[i*batch_size: batch_size*(i+1)])
34
+ input_ids_batches.append(input_ids[i*batch_size: batch_size*(i+1)])
35
+ else:
36
+ attention_mask_batches.append(attention_mask[i*batch_size:])
37
+ input_ids_batches.append(input_ids[i*batch_size:])
38
+
39
+ return attention_mask_batches, input_ids_batches
40
+
41
+
42
+ def get_batch_inference(self, batch_size, attention_mask, input_ids):
43
+ interpreter = tf.lite.Interpreter(model_path=self.tflite_model_path)
44
+ interpreter.allocate_tensors()
45
+ input_details = interpreter.get_input_details()
46
+ output_details = interpreter.get_output_details()[0]
47
+ interpreter.resize_tensor_input(input_details[0]['index'],[batch_size, self.n_tokens])
48
+ interpreter.resize_tensor_input(input_details[1]['index'],[batch_size, self.n_tokens])
49
+ interpreter.resize_tensor_input(output_details['index'],[batch_size, len(self.classes_)])
50
+ interpreter.allocate_tensors()
51
+ interpreter.set_tensor(input_details[0]["index"], attention_mask)
52
+ interpreter.set_tensor(input_details[1]["index"], input_ids)
53
+ interpreter.invoke()
54
+ tflite_pred = interpreter.get_tensor(output_details["index"])
55
+ return tflite_pred
56
+
57
+ def inference(self, texts):
58
+ model_checkpoint = "distilbert-base-uncased"
59
+ tokenizer = DistilBertTokenizerFast.from_pretrained(model_checkpoint)
60
+ tokens = tokenizer(texts, max_length=self.n_tokens, padding="max_length",
61
+ truncation=True, return_tensors="tf")
62
+ attention_mask, input_ids = tokens['attention_mask'], tokens['input_ids']
63
+ attention_mask_batches, input_ids_batches = self.get_token_batches(attention_mask, input_ids)
64
+
65
+
66
+
67
+ pool = multiprocessing.Pool(processes=multiprocessing.cpu_count())
68
+ results = []
69
+ for attention_mask, input_ids in zip(attention_mask_batches, input_ids_batches):
70
+ f = pool.apply_async(self.get_batch_inference, args=(len(attention_mask), attention_mask, input_ids))
71
+ results.append(f)
72
+
73
+ all_predictions = np.array([])
74
+ for n_batch in tqdm(range(len(results))):
75
+ tflite_pred = results[n_batch].get(timeout=360)
76
+ if n_batch == 0:
77
+ all_predictions = tflite_pred
78
+ else:
79
+ all_predictions = np.concatenate((all_predictions, tflite_pred), axis=0)
80
+ return all_predictions
81
+
82
+ def predict_proba(self, X, y=None):
83
+ predict_prob = self.inference(X)
84
+ return predict_prob
85
+
classification_models/calibrated_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b154508f547e10b14021fb7b004dc6c25558fbe1c6942706cfa843b6976a2ac2
3
+ size 4293
classification_models/label_encoder.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:92b26f07332ebdd93c8f8f2e8378ecba05415eb3ae6713bd9b1f4289d921c26f
3
+ size 370
classification_models/model.tflite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:15140df1981b0d4edd6009ac340a481e344ffb663fc449ae2fec1e69ee931615
3
+ size 67002528
config.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ SCRAPER_TIMEOUT = 20
2
+ CHROME_DRIVER_PATH = "./chromedriver-linux64/chromedriver"
news_classifier.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import tensorflow as tf
3
+ import logging
4
+
5
+
6
+ def find_path(url):
7
+ if url == '':
8
+ return ''
9
+ url = url.replace("-/-", "-")
10
+ url_split = url.replace("https://", "")
11
+ url_split = url_split.replace("www.", "")
12
+ url_split = url_split.strip()
13
+ url = url.replace("//", "/")
14
+ url = url.replace("https/timesofindia-indiatimes-com", "")
15
+ url_split = url_split.split("/")
16
+ url_split = [u for u in url_split if (u != "") and
17
+ (u != "articleshow") and
18
+ (u.find(".cms")==-1) and
19
+ (u.find(".ece")==-1) and
20
+ (u.find(".htm")==-1) and
21
+ (len(u.split('-')) <= 5) and
22
+ (u.find(" ") == -1)
23
+ ]
24
+ if len(url_split) > 2:
25
+ url_split = "/".join(url_split[1:])
26
+ else:
27
+ if len(url_split) > 0:
28
+ url_split = url_split[-1]
29
+ else:
30
+ url_split = '-'
31
+ return url_split
32
+
33
+
34
+ async def parse_prediction(tflite_pred, label_encoder):
35
+ tflite_pred_argmax = np.argmax(tflite_pred, axis=1)
36
+ tflite_pred_label = label_encoder.inverse_transform(tflite_pred_argmax)
37
+ tflite_pred_prob = np.max(tflite_pred, axis=1)
38
+ return tflite_pred_label, tflite_pred_prob
39
+
40
+
41
+ async def model_inference(text: list, calibrated_model, label_encoder):
42
+ logging.info('Entering news_classifier.model_inference()')
43
+
44
+ logging.info(f'Samples to predict: {len(text)}')
45
+ if text != "":
46
+ tflite_pred = calibrated_model.predict_proba(text)
47
+ tflite_pred = await parse_prediction(tflite_pred, label_encoder)
48
+ logging.info('Exiting news_classifier.model_inference()')
49
+ return tflite_pred
50
+
51
+
52
+ async def predict_news_classes(urls: list, texts: list, calibrated_model, label_encoder):
53
+ url_paths = [*map(find_path, urls)]
54
+ paths_texts = [f"{p}. {t}" for p, t in zip(url_paths, texts)]
55
+ label, prob = await model_inference(paths_texts, calibrated_model, label_encoder)
56
+ return label, prob
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ transformers==4.39.3
2
+ tensorflow==2.15.0
3
+ unidecode
4
+ tf-keras==2.15.0
5
+ selenium==4.19.0
6
+ fastapi
7
+ pydantic
8
+ uvicorn
9
+ undetected-chromedriver
10
+ scikit-learn==1.2.2
11
+ cloudpickle
12
+ numpy==1.24.3
scraper.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from selenium import webdriver
2
+ from selenium.webdriver.common.by import By
3
+ import undetected_chromedriver as uc
4
+ import re
5
+ import logging
6
+ import os
7
+ import time
8
+ import random
9
+ from config import SCRAPER_TIMEOUT, CHROME_DRIVER_PATH
10
+
11
+
12
+ def get_text(url, n_words=15):
13
+ try:
14
+ driver = None
15
+ logging.warning(f"Initiated Scraping {url}")
16
+ user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36"
17
+ options = uc.ChromeOptions()
18
+ options.add_argument("--headless")
19
+ options.add_argument(f"user-agent={user_agent}")
20
+ options.add_argument("--blink-settings=imagesEnabled=false")
21
+ options.add_argument("--disable-images")
22
+ options.add_argument("--disable-blink-features=AutomationControlled")
23
+ options.add_argument("--disable-dev-shm-usage")
24
+
25
+ # options.add_argument("--disable-extensions")
26
+ # options.add_argument("--autoplay-policy=no-user-gesture-required")
27
+ # options.add_argument("--disable-infobars")
28
+ # options.add_argument("--disable-gpu")
29
+
30
+ driver = uc.Chrome(version_main=127, options=options, driver_executable_path=CHROME_DRIVER_PATH)
31
+ time.sleep(random.uniform(0.5, 1.5))
32
+ driver.set_page_load_timeout(SCRAPER_TIMEOUT)
33
+ driver.set_script_timeout(SCRAPER_TIMEOUT)
34
+ driver.implicitly_wait(3)
35
+ driver.get(url)
36
+ elem = driver.find_element(By.TAG_NAME, "body").text
37
+ sents = elem.split("\n")
38
+ sentence_list = []
39
+ for sent in sents:
40
+ sent = sent.strip()
41
+ if (len(sent.split()) >= n_words) and (len(re.findall(r"^\w.+[^\w\)\s]$", sent))>0):
42
+ sentence_list.append(sent)
43
+ driver.close()
44
+ driver.quit()
45
+ logging.warning("Closed Webdriver")
46
+ logging.warning("Successfully scraped text")
47
+ if len(sentence_list) < 3:
48
+ raise Exception("Found nothing to scrape.")
49
+ return "\n".join(sentence_list), ""
50
+ except Exception as e:
51
+ logging.warning(str(e))
52
+ if driver:
53
+ driver.close()
54
+ driver.quit()
55
+ logging.warning("Closed Webdriver")
56
+ err_msg = str(e).split('\n')[0]
57
+ return "", err_msg
58
+
59
+
60
+ def scrape_text(url, n_words=15,max_retries=2):
61
+ scraped_text = ""
62
+ scrape_error = ""
63
+ try:
64
+ n_tries = 1
65
+ while (n_tries <= max_retries) and (scraped_text == ""):
66
+ scraped_text, scrape_error = get_text(url=url, n_words=n_words)
67
+ n_tries += 1
68
+ return scraped_text, scrape_error
69
+ except Exception as e:
70
+ err_msg = str(e).split('\n')[0]
71
+ return "", err_msg