Spaces:
Build error
Build error
import numpy as np | |
import onnxruntime | |
import onnx | |
import gradio as gr | |
import requests | |
import json | |
from extractnet import Extractor | |
import math | |
from transformers import AutoTokenizer | |
import spacy | |
import os | |
from transformers import pipeline | |
import itertools | |
import pandas as pd | |
# Headers of the three classifier-probability columns of the output table.
OUT_HEADERS = ['E', 'S', 'G']

# S&P 500 constituents, loaded once when the module is imported.
# get_company_sectors() reads its 'Name' and 'Sector' columns.
DF_SP500 = pd.read_csv('SP500_constituents.zip', compression='zip')

# Tokenizer checkpoint used to encode text for the ONNX ESG classifier.
MODEL_TRANSFORMER_BASED = "distilbert-base-uncased"
# Exported ESG classifier graph (run with onnxruntime).
MODEL_ONNX_FNAME = "ESG_classifier_batch.onnx"
# Hugging Face checkpoint for the sentiment-analysis pipeline.
MODEL_SENTIMENT_ANALYSIS = "ProsusAI/finbert"
#MODEL_SUMMARY_PEGASUS = "oMateos2020/pegasus-newsroom-cnn_full-adafactor-bs6"
#API_HF_SENTIMENT_URL = "https://api-inference.huggingface.co/models/cardiffnlp/twitter-roberta-base-sentiment"
def get_company_sectors(extracted_names, threshold=0.95):
    '''Map free-text company names to S&P 500 constituents and their sectors.

    Each name in *extracted_names* is fuzzy-matched (thefuzz ``token_set_ratio``)
    against ``DF_SP500.Name``. The best candidate is accepted only if its score,
    rescaled from 0-100 to 0-1, is at least *threshold*.

    Args:
        extracted_names: iterable of company-name strings (e.g. NER output).
        threshold: minimum similarity in [0, 1] required to accept a match.

    Returns:
        List of ``(Name, Sector)`` tuples taken from the rows of ``DF_SP500``
        whose ``Name`` equals an accepted match (rows with a missing value are
        skipped). Each matched standard name is reported once, in first-seen
        order, even when several extracted names map to it.
    '''
    from thefuzz import process, fuzz

    matched_names = []
    for extracted_name in extracted_names:
        name_match = process.extractOne(extracted_name,
                                        DF_SP500.Name,
                                        scorer=fuzz.token_set_ratio)
        # extractOne returns None when no choice yields a result
        # (e.g. an empty constituents table); skip instead of crashing.
        if name_match is None:
            continue
        std_name, score = name_match[:2]
        if score / 100 >= threshold:
            matched_names.append(std_name)

    name_sector = DF_SP500[['Name', 'Sector']]
    output = []
    # dict.fromkeys drops duplicate standard names while keeping their order.
    for std_comp_name in dict.fromkeys(matched_names):
        rows = name_sector[name_sector.Name == std_comp_name].dropna()
        output += list(rows.itertuples(index=False, name=None))
    return output
def filter_spans(spans, keep_longest=True):
    """Filter a sequence of spans and remove duplicates or overlaps.

    Useful for creating named entities (where one token can only be part of
    one entity) or when merging spans with `Retokenizer.merge`. Adapted from
    ``spacy.util.filter_spans`` with an extra *keep_longest* switch.

    When spans overlap, the longest span is preferred (the shortest one if
    ``keep_longest`` is False); among equally long candidates the one that
    starts first wins.

    spans (Iterable[Span]): The spans to filter. Only ``.start`` and ``.end``
        (token offsets, end-exclusive) are read.
    keep_longest (bool): Specify whether to keep longer or shorter spans.
    RETURNS (List[Span]): The kept, non-overlapping spans sorted by start.
    """
    def sort_key(span):
        length = span.end - span.start
        return (-length if keep_longest else length, span.start)

    result = []
    seen_tokens = set()
    for span in sorted(spans, key=sort_key):
        token_range = range(span.start, span.end)
        # Check every token, not just the two boundaries: when shorter spans
        # are kept first, a longer candidate can start and end on free tokens
        # while still enclosing an already-kept span.
        if seen_tokens.isdisjoint(token_range):
            result.append(span)
            seen_tokens.update(token_range)
    result.sort(key=lambda span: span.start)
    return result
def _inference_ner_spancat(text, limit_outputs=10):
    '''Extract company names with the custom spaCy span categorizer and look up sectors.

    Args:
        text: batch of document strings (fed to ``nlp.pipe``).
        limit_outputs: maximum number of distinct company names kept per
            document before the sector lookup; ``None`` keeps all of them.

    Returns:
        One entry per document: the ``get_company_sectors`` result for the
        company names found in that document.
    '''
    nlp = spacy.load("en_pipeline")
    if limit_outputs is not None:
        # UI sliders may deliver floats; slice bounds must be integers.
        limit_outputs = int(limit_outputs)
    out = []
    for doc in nlp.pipe(text):
        spans = doc.spans["sc"]
        # Remove "'s" (possessives) so e.g. "Apple's" and "Apple" merge.
        names = [str(span).replace('\'s', '')
                 for span in filter_spans(spans, keep_longest=True)]
        # De-duplicate in document order. Going through a set here would make
        # the truncation keep an arbitrary, hash-seed dependent subset.
        company_list = list(dict.fromkeys(names))[:limit_outputs]
        out.append(get_company_sectors(company_list))
    return out
#def _inference_summary_model_pipeline(text): | |
# pipe = pipeline("text2text-generation", model=MODEL_SUMMARY_PEGASUS) | |
# return pipe(text,truncation='longest_first') | |
def _inference_sentiment_model_pipeline(text):
    """Classify the sentiment of *text* with the MODEL_SENTIMENT_ANALYSIS pipeline.

    Inputs are padded and truncated to at most 512 tokens. The pipeline's
    result is returned unchanged; inference() reads 'label' and 'score' from
    each returned item.
    """
    classifier = pipeline("sentiment-analysis", model=MODEL_SENTIMENT_ANALYSIS)
    return classifier(text, padding=True, truncation=True, max_length=512)
#def _inference_sentiment_model_via_api_query(payload): | |
# response = requests.post(API_HF_SENTIMENT_URL , headers={"Authorization": os.environ['hf_api_token']}, json=payload) | |
# return response.json() | |
def _lematise_text(text):
    '''Lemmatise a batch of documents, keeping only content-bearing tokens.

    Punctuation, stop words, URL-like and e-mail-like tokens, whitespace and
    coordinating conjunctions are dropped; each remaining token is replaced by
    its lemma.

    Args:
        text: batch of document strings (fed to ``nlp.pipe``).

    Returns:
        List with one string per document: the kept lemmas, each preceded by
        a single space (an empty string when nothing is kept).
    '''
    nlp = spacy.load("en_core_web_sm", disable=['ner'])
    # "CONJ" is the Universal Dependencies v1 tag; spaCy v2+/v3 models tag
    # coordinating conjunctions as "CCONJ", so checking "CONJ" alone never fires.
    # NOTE(review): if the ONNX classifier was trained on text produced with the
    # "CONJ"-only check, confirm this preprocessing change is acceptable.
    conjunction_tags = {"CONJ", "CCONJ"}
    text_out = []
    for doc in nlp.pipe(text):  # see https://spacy.io/models#design
        kept_lemmas = (
            token.lemma_
            for token in doc
            if not (token.is_punct
                    or token.is_stop
                    or token.like_url
                    or token.is_space
                    or token.like_email
                    # or token.like_num
                    or token.pos_ in conjunction_tags)
        )
        # join avoids quadratic string concatenation; same " lemma" format.
        text_out.append("".join(" " + lemma for lemma in kept_lemmas))
    return text_out
def sigmoid(x):
    """Apply the logistic function 1 / (1 + exp(-x)) element-wise (scalars or NumPy arrays)."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator
def to_numpy(tensor):
    """Convert a torch tensor to a NumPy array, detaching it from autograd when it requires grad."""
    if tensor.requires_grad:
        tensor = tensor.detach()
    return tensor.cpu().numpy()
def is_in_archive(url, timeout=10):
    '''Look *url* up with the archive.org Wayback Machine availability API.

    Args:
        url: page URL to look up.
        timeout: seconds to wait for archive.org before giving up.

    Returns:
        dict with keys
          'archived': the closest snapshot's ``available`` flag (False if none),
          'url': the closest snapshot's URL ("" if none),
          'error': 0 on success, -1 if the request or the response parsing failed.
    '''
    try:
        # Pass the URL through ``params`` so requests percent-encodes it; a raw
        # concatenated query string breaks on URLs containing '&', '#', etc.
        response = requests.get('https://archive.org/wayback/available',
                                params={'url': url}, timeout=timeout)
        snapshots = response.json()['archived_snapshots']
        if snapshots:
            closest = snapshots['closest']
            return {'archived': closest['available'], 'url': closest['url'], 'error': 0}
        return {'archived': False, 'url': "", 'error': 0}
    except (requests.RequestException, ValueError, KeyError, TypeError) as err:
        # Network failure, non-JSON body (ValueError) or unexpected payload shape.
        print(f"[E] Quering URL ({url}) from archive.org: {err}")
        return {'archived': False, 'url': "", 'error': -1}
#def _inference_ner(text): | |
# return labels | |
def _inference_classifier(text):
    '''Score a batch of documents with the ONNX ESG classifier.

    The documents are lemmatised with ``_lematise_text``, tokenised with the
    MODEL_TRANSFORMER_BASED tokenizer (padded to max length, truncated) and run
    through MODEL_ONNX_FNAME with ONNX Runtime.

    Args:
        text: batch (list) of document strings.

    Returns:
        ``sigmoid`` of the graph's first output. inference() builds a DataFrame
        from it with the E/S/G columns.
    '''
    # Validate the model file up front, before building a session on it.
    onnx.checker.check_model(onnx.load(MODEL_ONNX_FNAME))
    ort_session = onnxruntime.InferenceSession(MODEL_ONNX_FNAME)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_TRANSFORMER_BASED)
    # this assumes head-only!  NOTE(review): presumably the exported graph
    # expects fixed max-length inputs, hence padding="max_length" — confirm.
    inputs = tokenizer(_lematise_text(text), return_tensors="np",
                       padding="max_length", truncation=True)

    # Feed exactly the tensors the graph declares; any extra tokenizer output
    # (e.g. token_type_ids) would otherwise make run() reject the feed.
    input_feed = {node.name: inputs[node.name] for node in ort_session.get_inputs()}
    ort_outs = ort_session.run(None, input_feed=input_feed)
    return sigmoid(ort_outs[0])
def _first_cell(row_in):
    '''Return the first cell of a Dataframe row, or the value itself when it is not a list.'''
    return row_in[0] if isinstance(row_in, list) else row_in


def _fetch_url_contents(urls, use_archive, timeout=30):
    '''Download each URL (via its archive.org snapshot if requested and available) and extract the article text.'''
    extractor = Extractor()  # built once and reused for every page
    contents = []
    for url in urls:
        fetch_url = url
        if use_archive:
            archive = is_in_archive(url)
            if archive['archived']:
                fetch_url = archive['url']
        html = requests.get(fetch_url, timeout=timeout).text
        contents.append(extractor.extract(html)['content'])
    return contents


def inference(input_batch, isurl, use_archive, limit_companies=10):
    '''Run the full ESG pipeline (classifier, sentiment, NER) on a batch of news items.

    Args:
        input_batch: non-empty batch of rows, either the list of single-cell
            lists produced by ``gr.Dataframe(type='array')`` or a flat list of
            strings. Each row holds a news URL or the news text itself.
        isurl: truthy when rows are URLs (index of the 'data type' dropdown,
            where 1 = 'url').
        use_archive: when truthy and *isurl*, fetch the archive.org snapshot of
            each URL if one is available.
        limit_companies: maximum number of company names kept per document by NER.

    Returns:
        pandas DataFrame with the classifier columns OUT_HEADERS, then 'URL'
        (original URLs, if *isurl*) or 'content_id' (1-based, otherwise), then
        'sent_lbl' and 'sent_score' from the sentiment model. NER results are
        only printed, not returned.

    Raises:
        ValueError: if *input_batch* is empty.
    '''
    print("[i] Input is list")
    if len(input_batch) == 0:
        raise ValueError("input_batch array is empty")
    input_batch_r = input_batch
    print("[i] Input size:", len(input_batch_r))

    url_list = []  # only filled when isurl
    if isurl:
        print("[i] Data is URL")
        if use_archive:
            print("[i] Use chached URL from archive.org")
        url_list = [_first_cell(row_in) for row_in in input_batch_r]
        input_batch_content = _fetch_url_contents(url_list, use_archive)
    else:
        print("[i] Data is news contents")
        if isinstance(input_batch_r[0], list):
            print("[i] Data is list of lists format")
        else:
            print("[i] Data is single list format")
        # Unwrap per row so a batch mixing both formats is still handled.
        input_batch_content = [_first_cell(row_in) for row_in in input_batch_r]
    print("[i] Batch size:", len(input_batch_content))

    print("[i] Running ESG classifier inference...")
    prob_outs = _inference_classifier(input_batch_content)
    print("[i] Classifier output shape:", prob_outs.shape)

    print("[i] Running sentiment using", MODEL_SENTIMENT_ANALYSIS, "inference...")
    sentiment = _inference_sentiment_model_pipeline(input_batch_content)

    print("[i] Running NER using custom spancat inference...")
    # Slider values may arrive as floats; the NER helper slices with this value.
    ner_labels = _inference_ner_spancat(input_batch_content,
                                        limit_outputs=int(limit_companies))
    print(ner_labels)

    df = pd.DataFrame(prob_outs, columns=OUT_HEADERS)
    if isurl:
        df['URL'] = url_list
    else:
        df['content_id'] = range(1, len(input_batch_r) + 1)
    df['sent_lbl'] = [d['label'] for d in sentiment]
    df['sent_score'] = [d['score'] for d in sentiment]
    print("[i] Pandas output shape:", df.shape)
    return df
# ---- Gradio UI definition --------------------------------------------------
title = "ESG API Demo"
description = """This is a demonstration of the full ESG pipeline backend where given a list of URL (english, news) the news contents are extracted, using extractnet, and fed to three models:
- An off-the-shelf sentiment classification model (ProsusAI/finbert)
- A custom NER for the company extraction
- A custom ESG classifier for the ESG labeling of the news (the extracted text is also lemmatised prior to be fed to this classifier)
API input parameters:
- List: list of text. Either list of Url of the news (english) or list of extracted news contents
- 'Data type': int. 0=list is of extracted news contents, 1=list is of urls.
- `use_archive`: boolean. The model will extract the archived version in archive.org of the url indicated. This is useful with old news and to bypass news behind paywall
- `limit_companies`: integer. Number of found relevant companies to report.
"""
# One example row per inference() argument: batch (list of single-cell rows),
# data type, use_archive, limit_companies.
examples = [[ [['https://www.bbc.com/news/uk-62732447'],
               ['https://www.bbc.com/news/business-62747401'],
               ['https://www.bbc.com/news/technology-62744858'],
               ['https://www.bbc.com/news/science-environment-62758811'],
               ['https://www.theguardian.com/business/2022/sep/02/nord-stream-1-gazprom-announces-indefinite-shutdown-of-pipeline'],
               ['https://www.bbc.com/news/world-europe-62766867'],
               ['https://www.bbc.com/news/business-62524031'],
               ['https://www.bbc.com/news/business-62728621'],
               ['https://www.bbc.com/news/science-environment-62680423']], 'url', False, 5]]

demo = gr.Interface(fn=inference,
                    inputs=[gr.Dataframe(label='input batch', col_count=1, datatype='str', type='array', wrap=True),
                            # type='index' passes 0 ('text') or 1 ('url') to inference(isurl=...)
                            gr.Dropdown(label='data type', choices=['text', 'url'], type='index', value='url'),
                            gr.Checkbox(label='if url parse cached in archive.org'),
                            gr.Slider(minimum=1, maximum=10, step=1, label='Limit NER output', value=5)],
                    outputs=[gr.Dataframe(label='output raw', col_count=1, type='pandas', wrap=True, header=OUT_HEADERS)],
                    title=title,
                    description=description,
                    examples=examples)

# Only start the web server when run as a script, so importing this module
# (e.g. to call inference() from tests) has no server side effect.
if __name__ == "__main__":
    demo.launch()