Spaces: Build error
Update app.py
app.py CHANGED
@@ -12,8 +12,10 @@ import os
 from transformers import pipeline
 import itertools
 import pandas as pd
+import thefuzz
 
 OUT_HEADERS = ['E','S','G']
+DF_SP500 = pd.read_csv('SP500_constituents.zip',compression=dict(method='zip'))
 
 MODEL_TRANSFORMER_BASED = "distilbert-base-uncased"
 MODEL_ONNX_FNAME = "ESG_classifier_batch.onnx"
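As context for the two added lines: DF_SP500 is read straight from a zipped CSV, and thefuzz is used further down for fuzzy name matching. Below is a minimal sketch of the intended lookup, assuming the archive contains a CSV with at least Name and Sector columns (the sample query and printed output are illustrative, not from the Space). One hedged observation: thefuzz's documented import style is `from thefuzz import fuzz, process`, and a bare `import thefuzz` does not necessarily expose the `process` submodule, which may be relevant to the Build error badge above.

import pandas as pd
from thefuzz import fuzz, process  # thefuzz's documented import style

# Assumed layout: a CSV inside the zip with 'Name' and 'Sector' columns.
DF_SP500 = pd.read_csv('SP500_constituents.zip', compression=dict(method='zip'))

# extractOne over a pandas Series returns (matched_value, score, index).
match = process.extractOne("Apple", DF_SP500.Name, scorer=fuzz.token_set_ratio)
print(match[:2])  # e.g. ('Apple Inc.', 100) -- illustrative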
@@ -24,24 +26,59 @@ MODEL_SENTIMENT_ANALYSIS = "ProsusAI/finbert"
 
 #API_HF_SENTIMENT_URL = "https://api-inference.huggingface.co/models/cardiffnlp/twitter-roberta-base-sentiment"
 
+def get_company_sectors(extracted_names, threshold=0.95):
+    '''Fuzzy-match extracted company names against DF_SP500 and return
+    (Name, Sector) tuples for matches above the similarity threshold.'''
+    output = []
+    standard_names_tuples = []
+    for extracted_name in extracted_names:
+        name_match = thefuzz.process.extractOne(extracted_name,
+                                                DF_SP500.Name,
+                                                scorer=thefuzz.fuzz.token_set_ratio)
+        similarity = name_match[1]/100
+        if similarity >= threshold:
+            standard_names_tuples.append(name_match[:2])
+
+    for std_comp_name, _ in standard_names_tuples:
+        sectors = list(DF_SP500[['Name','Sector']].where(DF_SP500.Name == std_comp_name).dropna().itertuples(index=False, name=None))
+        output += sectors
+    return output
+
+def filter_spans(spans, keep_longest=True):
+    """Filter a sequence of spans and remove duplicates or overlaps. Useful for
+    creating named entities (where one token can only be part of one entity) or
+    when merging spans with `Retokenizer.merge`. When spans overlap, the (first)
+    longest span is preferred over shorter spans.
+    spans (Iterable[Span]): The spans to filter.
+    keep_longest (bool): Specify whether to keep longer or shorter spans.
+    RETURNS (List[Span]): The filtered spans.
+    """
+    get_sort_key = lambda span: (span.end - span.start, -span.start)
+    sorted_spans = sorted(spans, key=get_sort_key, reverse=keep_longest)
+    #print(f'sorted_spans: {sorted_spans}')
+    result = []
+    seen_tokens = set()
+    for span in sorted_spans:
+        # Check for end - 1 here because boundaries are inclusive
+        if span.start not in seen_tokens and span.end - 1 not in seen_tokens:
+            result.append(span)
+            seen_tokens.update(range(span.start, span.end))
+    result = sorted(result, key=lambda span: span.start)
+    return result
+
+
+
 def _inference_ner_spancat(text, summary, penalty=0.5, normalise=True, limit_outputs=10):
     nlp = spacy.load("en_pipeline")
-
-
-
-
-
-
-
-        if str(comp_s) in comp_raw_text.keys():
-            comp_raw_text[str(comp_s)] = comp_raw_text[str(comp_s)] / penalty
-            temp_max = comp_raw_text[str(comp_s)] if comp_raw_text[str(comp_s)] > 1.0 else 0.0
-            exceeds_one = comp_raw_text[str(comp_s)] if temp_max > exceeds_one else exceeds_one
-        # This "exceeds_one" is a bit confusing: the penalty is reverted each time the company appears in the summary, so a score can exceed one when the company appears more than once. When any value exceeds one, normalisation divides all scores by that maximum.
-        if normalise and (exceeds_one > 1):
-            comp_raw_text = {k: v/exceeds_one for k, v in comp_raw_text.items()}
+    out = []
+    for doc in nlp.pipe(text):
+        spans = doc.spans["sc"]
+        #comp_raw_text = dict( sorted( dict(zip([str(x) for x in spans],[float(x)*penalty for x in spans.attrs['scores']])).items(), key=lambda x: x[1], reverse=True) )
+
+        company_list = list(set([str(span).replace('\'s', '') for span in filter_spans(spans, keep_longest=True)]))[:limit_outputs]
+        out.append(get_company_sectors(company_list))
 
-    return
+    return out
 
 #def _inference_summary_model_pipeline(text):
 #    pipe = pipeline("text2text-generation", model=MODEL_SUMMARY_PEGASUS)
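Taken together, the two new helpers implement the NER-to-sector lookup that the rewritten _inference_ner_spancat performs per document. A hedged usage sketch, assuming `en_pipeline` is the custom spaCy package with a span categorizer writing to doc.spans["sc"] as the diff suggests (the input sentence and printed result are invented):

import spacy

nlp = spacy.load("en_pipeline")  # custom spancat package loaded by the Space
doc = nlp("Apple and Microsoft announced new emissions targets.")

spans = doc.spans["sc"]                         # candidate company mentions
best = filter_spans(spans, keep_longest=True)   # drop overlapping spans
companies = list({str(s).replace("'s", "") for s in best})
print(get_company_sectors(companies))
# e.g. [('Apple Inc.', 'Information Technology'), ('Microsoft Corp.', 'Information Technology')]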
@@ -162,8 +199,10 @@ def inference(input_batch,isurl,use_archive,limit_companies=10):
     print("[i] Running sentiment using",MODEL_SENTIMENT_ANALYSIS ,"inference...")
     #sentiment = _inference_sentiment_model_via_api_query({"inputs": extracted['content']})
     sentiment = _inference_sentiment_model_pipeline(input_batch_content )
+    print("[i] Running NER using custom spancat inference...")
     #summary = _inference_summary_model_pipeline(input_batch_content )[0]['generated_text']
-
+    ner_labels = _inference_ner_spancat(input_batch_content ,limit_outputs=limit_companies)
+    print(ner_labels)
     df = pd.DataFrame(prob_outs,columns =['E','S','G'])
     if isurl:
         df['URL'] = url_list
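One caveat with the call added in this hunk: _inference_ner_spancat still declares `summary` as a required positional parameter, so calling it with only `input_batch_content` and `limit_outputs` would raise a TypeError as written. Like the import note above, this is a plausible (unconfirmed) source of the Build error. A sketch of the smallest fix, assuming the parameter is simply unused in the new body:

# Assumed fix: give the now-unused `summary` parameter a default so the
# call in inference() matches the signature.
def _inference_ner_spancat(text, summary=None, penalty=0.5, normalise=True, limit_outputs=10):
    ...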
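Finally, the tail of the hunk shows how inference() assembles its output frame from the classifier scores. A toy sketch of that data structure, with invented probabilities and URLs:

import pandas as pd

prob_outs = [[0.7, 0.2, 0.1], [0.1, 0.3, 0.6]]                 # invented E/S/G rows
url_list = ["https://example.com/a", "https://example.com/b"]  # illustrative

df = pd.DataFrame(prob_outs, columns=['E','S','G'])
df['URL'] = url_list
print(df)
#      E    S    G                    URL
# 0  0.7  0.2  0.1  https://example.com/a
# 1  0.1  0.3  0.6  https://example.com/b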