rdose commited on
Commit
1314a1a
·
1 Parent(s): 9fd0daa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +114 -104
app.py CHANGED
@@ -1,12 +1,21 @@
1
 
 
 
 
 
 
 
2
 
3
  import numpy as np
 
 
4
  import onnxruntime
5
  import onnx
6
  import gradio as gr
7
- import requests
8
- import json
9
 
 
 
 
10
 
11
  try:
12
  from extractnet import Extractor
@@ -27,28 +36,20 @@ except ImportError:
27
 
28
  print('[i] Using',EXTRACTOR_NET)
29
 
30
- import math
31
- from transformers import AutoTokenizer
32
  import spacy
33
- import os
34
- from transformers import pipeline
35
- import itertools
36
- import pandas as pd
37
 
38
- # from bertopic import BERTopic
39
- # from huggingface_hub import hf_hub_url, cached_download
40
 
41
- # import nltk
42
- # nltk.download('stopwords')
43
- # nltk.download('wordnet')
44
- # nltk.download('omw-1.4')
45
- # from nltk.corpus import stopwords
46
- # from nltk.stem import WordNetLemmatizer
47
- # from nltk.stem import PorterStemmer
48
 
49
- # from unicodedata import normalize
50
 
51
- # import re
52
 
53
 
54
  OUT_HEADERS = ['E','S','G']
@@ -57,82 +58,85 @@ DF_SP500 = pd.read_csv('SP500_constituents.zip',compression=dict(method='zip'))
57
  MODEL_TRANSFORMER_BASED = "distilbert-base-uncased"
58
  MODEL_ONNX_FNAME = "ESG_classifier_batch.onnx"
59
  MODEL_SENTIMENT_ANALYSIS = "ProsusAI/finbert"
60
-
61
-
62
- # BERTOPIC_REPO_ID = "oMateos2020/BERTopic-paraphrase-MiniLM-L3-v2-51topics-guided-model3"
63
- # BERTOPIC_FILENAME = "BERTopic-paraphrase-MiniLM-L3-v2-51topics-guided-model3"
64
- # bertopic_model = BERTopic.load(cached_download(hf_hub_url(BERTOPIC_REPO_ID , BERTOPIC_FILENAME )), embedding_model="paraphrase-MiniLM-L3-v2")
65
-
66
- # def _topic_sanitize_word(text):
67
- # """Función realiza una primera limpieza-normalización del texto a traves de expresiones regex"""
68
- # text = re.sub(r'@[\w_]+|#[\w_]+|https?://[\w_./]+', '', text) # Elimina menciones y URL, esto sería más para Tweets pero por si hay alguna mención o URL al ser criticas web
69
- # text = re.sub('\S*@\S*\s?', '', text) # Elimina correos electronicos
70
- # text = re.sub(r'\((\d+)\)', '', text) #Elimina numeros entre parentesis
71
- # text = re.sub(r'^\d+', '', text) #Elimina numeros sueltos
72
- # text = re.sub(r'\n', '', text) #Elimina saltos de linea
73
- # text = re.sub('\s+', ' ', text) # Elimina espacios en blanco adicionales
74
- # text = re.sub(r'[“”]', '', text) # Elimina caracter citas
75
- # text = re.sub(r'[()]', '', text) # Elimina parentesis
76
- # text = re.sub('\.', '', text) # Elimina punto
77
- # text = re.sub('\,', '', text) # Elimina coma
78
- # text = re.sub('’s', '', text) # Elimina posesivos
79
- # #text = re.sub(r'-+', '', text) # Quita guiones para unir palabras compuestas (normalizaría algunos casos, exmujer y ex-mujer, todos a exmujer)
80
- # text = re.sub(r'\.{3}', ' ', text) # Reemplaza puntos suspensivos
81
- # # Esta exp regular se ha incluido "a mano" tras ver que era necesaria para algunos ejemplos
82
- # text = re.sub(r"([\.\?])", r"\1 ", text) # Introduce espacio despues de punto e interrogacion
83
- # # -> NFD (Normalization Form Canonical Decomposition) y eliminar diacríticos
84
- # text = re.sub(r"([^n\u0300-\u036f]|n(?!\u0303(?![\u0300-\u036f])))[\u0300-\u036f]+", r"\1",
85
- # normalize( "NFD", text), 0, re.I) # Eliminación de diacriticos (acentos y variantes puntuadas de caracteres por su forma simple excepto la 'ñ')
86
- # # -> NFC (Normalization Form Canonical Composition)
87
- # text = normalize( 'NFC', text)
88
-
89
- # return text.lower().strip()
90
-
91
- # def _topic_clean_text(text, lemmatize=True, stem=True):
92
- # words = text.split()
93
- # non_stopwords = [word for word in words if word not in stopwords.words('english')]
94
- # clean_text = [_topic_sanitize_word(word) for word in non_stopwords]
95
- # if lemmatize:
96
- # lemmatizer = WordNetLemmatizer()
97
- # clean_text = [lemmatizer.lemmatize(word) for word in clean_text]
98
- # if stem:
99
- # ps =PorterStemmer()
100
- # clean_text = [ps.stem(word) for word in clean_text]
101
-
102
- # return ' '.join(clean_text).strip()
103
-
104
-
105
- # #SECTOR_LIST = list(DF_SP500.Sector.unique())
106
- # SECTOR_LIST = ['Industry',
107
- # 'Health',
108
- # 'Technology',
109
- # 'Communication',
110
- # 'Consumer Staples',
111
- # 'Consumer Discretionary',
112
- # 'Utilities',
113
- # 'Financials',
114
- # 'Materials',
115
- # 'Real Estate',
116
- # 'Energy']
117
-
118
- # SECTOR_TOPICS = []
119
- # for sector in SECTOR_LIST:
120
- # topics, _ = bertopic_model.find_topics(_topic_clean_text(sector), top_n=5)
121
- # SECTOR_TOPICS.append(topics)
122
-
123
- # def _topic2sector(pred_topics):
124
- # out = []
125
- # for pred_topic in pred_topics:
126
- # relevant_sectors = []
127
- # for i in range(len(SECTOR_LIST)):
128
- # if pred_topic in SECTOR_TOPICS[i]:
129
- # relevant_sectors.append(list(DF_SP500.Sector.unique())[i])
130
- # out.append(relevant_sectors)
131
- # return out
132
-
133
- # def _inference_topic_match(text):
134
- # out, _ = bertopic_model.transform([_topic_clean_text(t) for t in text])
135
- # return out
 
 
 
136
 
137
  def get_company_sectors(extracted_names, threshold=0.95):
138
  '''
@@ -184,7 +188,6 @@ def filter_spans(spans, keep_longest=True):
184
  return result
185
 
186
 
187
-
188
  def _inference_ner_spancat(text, limit_outputs=10):
189
  nlp = spacy.load("en_pipeline")
190
  out = []
@@ -264,7 +267,7 @@ def _inference_classifier(text):
264
 
265
  return sigmoid(ort_outs[0])
266
 
267
- def inference(input_batch,isurl,use_archive,limit_companies=10):
268
  url_list = [] #Only used if isurl
269
  input_batch_content = []
270
  # if file_in.name is not "":
@@ -285,7 +288,8 @@ def inference(input_batch,isurl,use_archive,limit_companies=10):
285
  if isurl:
286
  print("[i] Data is URL")
287
  if use_archive:
288
- print("[i] Use chached URL from archive.org")
 
289
  for row_in in input_batch_r:
290
  if isinstance(row_in , list):
291
  url = row_in[0]
@@ -324,9 +328,10 @@ def inference(input_batch,isurl,use_archive,limit_companies=10):
324
  sentiment = _inference_sentiment_model_pipeline(input_batch_content )
325
  print("[i] Running NER using custom spancat inference...")
326
  ner_labels = _inference_ner_spancat(input_batch_content ,limit_outputs=limit_companies)
327
- # print("[i] BERTopic...")
328
- # topics = _inference_topic_match(input_batch_content)
329
-
 
330
  df = pd.DataFrame(prob_outs,columns =['E','S','G'])
331
  if isurl:
332
  df['URL'] = url_list
@@ -334,6 +339,7 @@ def inference(input_batch,isurl,use_archive,limit_companies=10):
334
  df['content_id'] = range(1, len(input_batch_r)+1)
335
  df['sent_lbl'] = [d['label'] for d in sentiment ]
336
  df['sent_score'] = [d['score'] for d in sentiment ]
 
337
  #df['sector_pred'] = pd.DataFrame(_topic2sector(topics)).iloc[:, 0]
338
  print("[i] Pandas output shape:",df.shape)
339
 
@@ -343,7 +349,9 @@ def inference(input_batch,isurl,use_archive,limit_companies=10):
343
  for idx in range(len(df.index)):
344
  if ner_labels[idx]: #not empty
345
  for ner in ner_labels[idx]:
346
-
 
 
347
  df = pd.concat( [df, df.loc[[idx]].assign(company=ner[0], sector=ner[1])], join='outer', ignore_index=True) #axis=0
348
 
349
  return df #ner_labels, {'E':float(prob_outs[0]),"S":float(prob_outs[1]),"G":float(prob_outs[2])},{sentiment['label']:float(sentiment['score'])},"**Summary:**\n\n" + summary
@@ -359,6 +367,7 @@ API input parameters:
359
  - List: list of text. Either list of Url of the news (english) or list of extracted news contents
360
  - 'Data type': int. 0=list is of extracted news contents, 1=list is of urls.
361
  - `use_archive`: boolean. The model will extract the archived version in archive.org of the url indicated. This is useful with old news and to bypass news behind paywall
 
362
  - `limit_companies`: integer. Number of found relevant companies to report.
363
 
364
  """
@@ -370,11 +379,12 @@ examples = [[ [['https://www.bbc.com/news/uk-62732447'],
370
  ['https://www.bbc.com/news/world-europe-62766867'],
371
  ['https://www.bbc.com/news/business-62524031'],
372
  ['https://www.bbc.com/news/business-62728621'],
373
- ['https://www.bbc.com/news/science-environment-62680423']],'url',False,5]]
374
  demo = gr.Interface(fn=inference,
375
  inputs=[gr.Dataframe(label='input batch', col_count=1, datatype='str', type='array', wrap=True),
376
  gr.Dropdown(label='data type', choices=['text','url'], type='index', value='url'),
377
- gr.Checkbox(label='if url parse cached in archive.org'),
 
378
  gr.Slider(minimum=1, maximum=10, step=1, label='Limit NER output', value=5)],
379
  outputs=[gr.Dataframe(label='output raw', col_count=1, type='pandas', wrap=True, header=OUT_HEADERS)],
380
  #gr.Label(label='Company'),
 
1
 
2
+ import os
3
+ import re
4
+ import math
5
+ import requests
6
+ import json
7
+ import itertools
8
 
9
  import numpy as np
10
+ import pandas as pd
11
+
12
  import onnxruntime
13
  import onnx
14
  import gradio as gr
 
 
15
 
16
+ from huggingface_hub import hf_hub_url, cached_download
17
+ from transformers import AutoTokenizer
18
+ from transformers import pipeline
19
 
20
  try:
21
  from extractnet import Extractor
 
36
 
37
  print('[i] Using',EXTRACTOR_NET)
38
 
 
 
39
  import spacy
 
 
 
 
40
 
41
+ from bertopic import BERTopic
 
42
 
43
+ import nltk
44
+ nltk.download('stopwords')
45
+ nltk.download('wordnet')
46
+ nltk.download('omw-1.4')
47
+ from nltk.corpus import stopwords
48
+ from nltk.stem import WordNetLemmatizer
49
+ from nltk.stem import PorterStemmer
50
 
51
+ from unicodedata import normalize
52
 
 
53
 
54
 
55
  OUT_HEADERS = ['E','S','G']
 
58
  MODEL_TRANSFORMER_BASED = "distilbert-base-uncased"
59
  MODEL_ONNX_FNAME = "ESG_classifier_batch.onnx"
60
  MODEL_SENTIMENT_ANALYSIS = "ProsusAI/finbert"
61
+ #MODEL3
62
+ #BERTOPIC_REPO_ID = "oMateos2020/BERTopic-paraphrase-MiniLM-L3-v2-51topics-guided-model3"
63
+ #BERTOPIC_FILENAME = "BERTopic-paraphrase-MiniLM-L3-v2-51topics-guided-model3"
64
+ #bertopic_model = BERTopic.load(cached_download(hf_hub_url(BERTOPIC_REPO_ID , BERTOPIC_FILENAME )), embedding_model="paraphrase-MiniLM-L3-v2")
65
+
66
+ BERTOPIC_REPO_ID = "oMateos2020/BERTopic-distilbert-base-nli-mean-tokens"
67
+ BERTOPIC_FILENAME = "BERTopic-distilbert-base-nli-mean-tokens"
68
+ bertopic_model = BERTopic.load(cached_download(hf_hub_url(BERTOPIC_REPO_ID , BERTOPIC_FILENAME )))
69
+
70
+ #SECTOR_LIST = list(DF_SP500.Sector.unique())
71
+ SECTOR_LIST = ['Industry',
72
+ 'Health',
73
+ 'Technology',
74
+ 'Communication',
75
+ 'Consumer Staples',
76
+ 'Consumer Discretionary',
77
+ 'Utilities',
78
+ 'Financials',
79
+ 'Materials',
80
+ 'Real Estate',
81
+ 'Energy']
82
+
83
+
84
+ def _topic_sanitize_word(text):
85
+ """Función realiza una primera limpieza-normalización del texto a traves de expresiones regex"""
86
+ text = re.sub(r'@[\w_]+|#[\w_]+|https?://[\w_./]+', '', text) # Elimina menciones y URL, esto sería más para Tweets pero por si hay alguna mención o URL al ser criticas web
87
+ text = re.sub('\S*@\S*\s?', '', text) # Elimina correos electronicos
88
+ text = re.sub(r'\((\d+)\)', '', text) #Elimina numeros entre parentesis
89
+ text = re.sub(r'^\d+', '', text) #Elimina numeros sueltos
90
+ text = re.sub(r'\n', '', text) #Elimina saltos de linea
91
+ text = re.sub('\s+', ' ', text) # Elimina espacios en blanco adicionales
92
+ text = re.sub(r'[“”]', '', text) # Elimina caracter citas
93
+ text = re.sub(r'[()]', '', text) # Elimina parentesis
94
+ text = re.sub('\.', '', text) # Elimina punto
95
+ text = re.sub('\,', '', text) # Elimina coma
96
+ text = re.sub('’s', '', text) # Elimina posesivos
97
+ #text = re.sub(r'-+', '', text) # Quita guiones para unir palabras compuestas (normalizaría algunos casos, exmujer y ex-mujer, todos a exmujer)
98
+ text = re.sub(r'\.{3}', ' ', text) # Reemplaza puntos suspensivos
99
+ # Esta exp regular se ha incluido "a mano" tras ver que era necesaria para algunos ejemplos
100
+ text = re.sub(r"([\.\?])", r"\1 ", text) # Introduce espacio despues de punto e interrogacion
101
+ # -> NFD (Normalization Form Canonical Decomposition) y eliminar diacríticos
102
+ text = re.sub(r"([^n\u0300-\u036f]|n(?!\u0303(?![\u0300-\u036f])))[\u0300-\u036f]+", r"\1",
103
+ normalize( "NFD", text), 0, re.I) # Eliminación de diacriticos (acentos y variantes puntuadas de caracteres por su forma simple excepto la 'ñ')
104
+ # -> NFC (Normalization Form Canonical Composition)
105
+ text = normalize( 'NFC', text)
106
+
107
+ return text.lower().strip()
108
+
109
+ def _topic_clean_text(text, lemmatize=True, stem=True):
110
+ words = text.split()
111
+ non_stopwords = [word for word in words if word not in stopwords.words('english')]
112
+ clean_text = [_topic_sanitize_word(word) for word in non_stopwords]
113
+ if lemmatize:
114
+ lemmatizer = WordNetLemmatizer()
115
+ clean_text = [lemmatizer.lemmatize(word) for word in clean_text]
116
+ if stem:
117
+ ps =PorterStemmer()
118
+ clean_text = [ps.stem(word) for word in clean_text]
119
+
120
+ return ' '.join(clean_text).strip()
121
+
122
+ SECTOR_TOPICS = []
123
+ for sector in SECTOR_LIST:
124
+ topics, _ = bertopic_model.find_topics(_topic_clean_text(sector), top_n=5)
125
+ SECTOR_TOPICS.append(topics)
126
+
127
+ def _topic2sector(pred_topics):
128
+ out = []
129
+ for pred_topic in pred_topics:
130
+ relevant_sectors = []
131
+ for i in range(len(SECTOR_LIST)):
132
+ if pred_topic in SECTOR_TOPICS[i]:
133
+ relevant_sectors.append(list(DF_SP500.Sector.unique())[i])
134
+ out.append(relevant_sectors)
135
+ return out
136
+
137
+ def _inference_topic_match(text):
138
+ out, _ = bertopic_model.transform([_topic_clean_text(t) for t in text])
139
+ return out
140
 
141
  def get_company_sectors(extracted_names, threshold=0.95):
142
  '''
 
188
  return result
189
 
190
 
 
191
  def _inference_ner_spancat(text, limit_outputs=10):
192
  nlp = spacy.load("en_pipeline")
193
  out = []
 
267
 
268
  return sigmoid(ort_outs[0])
269
 
270
+ def inference(input_batch,isurl,use_archive,filt_companies_topic,limit_companies=10):
271
  url_list = [] #Only used if isurl
272
  input_batch_content = []
273
  # if file_in.name is not "":
 
288
  if isurl:
289
  print("[i] Data is URL")
290
  if use_archive:
291
+ print("[i] Use chached URL from archive.org")
292
+ print("[i] Extracting contents using",EXTRACTOR_NET)
293
  for row_in in input_batch_r:
294
  if isinstance(row_in , list):
295
  url = row_in[0]
 
328
  sentiment = _inference_sentiment_model_pipeline(input_batch_content )
329
  print("[i] Running NER using custom spancat inference...")
330
  ner_labels = _inference_ner_spancat(input_batch_content ,limit_outputs=limit_companies)
331
+ print("[i] Extracting topic using custom BERTopic...")
332
+ topics = _inference_topic_match(input_batch_content)
333
+ news_sectors = _topic2sector(topics)
334
+
335
  df = pd.DataFrame(prob_outs,columns =['E','S','G'])
336
  if isurl:
337
  df['URL'] = url_list
 
339
  df['content_id'] = range(1, len(input_batch_r)+1)
340
  df['sent_lbl'] = [d['label'] for d in sentiment ]
341
  df['sent_score'] = [d['score'] for d in sentiment ]
342
+ df['topic'] = pd.DataFrame(news_sectors).iloc[:, 0]
343
  #df['sector_pred'] = pd.DataFrame(_topic2sector(topics)).iloc[:, 0]
344
  print("[i] Pandas output shape:",df.shape)
345
 
 
349
  for idx in range(len(df.index)):
350
  if ner_labels[idx]: #not empty
351
  for ner in ner_labels[idx]:
352
+ if filt_companies_topic:
353
+ if news_sectors[idx] != ner[1]:
354
+ continue
355
  df = pd.concat( [df, df.loc[[idx]].assign(company=ner[0], sector=ner[1])], join='outer', ignore_index=True) #axis=0
356
 
357
  return df #ner_labels, {'E':float(prob_outs[0]),"S":float(prob_outs[1]),"G":float(prob_outs[2])},{sentiment['label']:float(sentiment['score'])},"**Summary:**\n\n" + summary
 
367
  - List: list of text. Either list of Url of the news (english) or list of extracted news contents
368
  - 'Data type': int. 0=list is of extracted news contents, 1=list is of urls.
369
  - `use_archive`: boolean. The model will extract the archived version in archive.org of the url indicated. This is useful with old news and to bypass news behind paywall
370
+ - `filter_companies`: boolean. Filter companies by news' topic
371
  - `limit_companies`: integer. Number of found relevant companies to report.
372
 
373
  """
 
379
  ['https://www.bbc.com/news/world-europe-62766867'],
380
  ['https://www.bbc.com/news/business-62524031'],
381
  ['https://www.bbc.com/news/business-62728621'],
382
+ ['https://www.bbc.com/news/science-environment-62680423']],'url',False,False,5]]
383
  demo = gr.Interface(fn=inference,
384
  inputs=[gr.Dataframe(label='input batch', col_count=1, datatype='str', type='array', wrap=True),
385
  gr.Dropdown(label='data type', choices=['text','url'], type='index', value='url'),
386
+ gr.Checkbox(label='Parse cached in archive.org'),
387
+ gr.Checkbox(label='Filter out companies by topic'),
388
  gr.Slider(minimum=1, maximum=10, step=1, label='Limit NER output', value=5)],
389
  outputs=[gr.Dataframe(label='output raw', col_count=1, type='pandas', wrap=True, header=OUT_HEADERS)],
390
  #gr.Label(label='Company'),