rdose committed
Commit 0107ad0 · 1 Parent(s): 0afcea0

Update app.py

Files changed (1):
  1. app.py +61 -49
app.py CHANGED

@@ -17,6 +17,8 @@ MODEL_ONNX_FNAME = "ESG_classifier.onnx"
 MODEL_SENTIMENT_ANALYSIS = "ProsusAI/finbert"
 MODEL_SUMMARY_PEGASUS = "oMateos2020/pegasus-newsroom-cnn_full-adafactor-bs6"
 
+
+
 #API_HF_SENTIMENT_URL = "https://api-inference.huggingface.co/models/cardiffnlp/twitter-roberta-base-sentiment"
 
 def _inference_ner_spancat(text, summary, penalty=0.5, normalise=True, limit_outputs=10):
@@ -51,32 +53,24 @@ def _inference_sentiment_model_pipeline(text):
 #    response = requests.post(API_HF_SENTIMENT_URL, headers={"Authorization": os.environ['hf_api_token']}, json=payload)
 #    return response.json()
 
-def convert_listwords_text(list_words):
-    text = ""
-    for word in list_words:
-        text = text + " " + word
-    return text
-
-def clean_text(text):
-    nlp = spacy.load("en_core_web_sm")
-    nlp.max_length = 2000000
-    if (text != ""):
-        list_word = []
-
-        for token in nlp(text):
-            if (not token.is_punct
-                and not token.is_stop
-                and not token.like_url
-                and not token.is_space
-                and not token.like_email
-                #and not token.like_num
-                and not token.pos_ == "CONJ"):
-
-                list_word.append(token.lemma_)
-
-        return convert_listwords_text(list_words=list_word)
-    else:
-        return -1
+def _lematise_text(text):
+    nlp = spacy.load("en_core_web_sm", disable=['ner'])
+    text_out = []
+    for doc in nlp.pipe(text):  # see https://spacy.io/models#design
+        new_text = ""
+        for token in doc:
+            if (not token.is_punct
+                and not token.is_stop
+                and not token.like_url
+                and not token.is_space
+                and not token.like_email
+                #and not token.like_num
+                and not token.pos_ == "CONJ"):
+
+                new_text = new_text + " " + token.lemma_
+
+        text_out.append(new_text)
+    return text_out
 
 def sigmoid(x):
     return 1 / (1 + np.exp(-x))
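
The new `_lematise_text` takes an iterable of texts and streams them through `nlp.pipe`, which batches documents through the spaCy pipeline instead of re-running it once per call, returning one lemmatised, filtered string per input. A minimal standalone sketch of the same pattern (the sample texts are illustrative; assumes `en_core_web_sm` is installed):

    import spacy

    # Sketch of the batched lemmatisation pattern above; assumes
    # `python -m spacy download en_core_web_sm` has been run.
    nlp = spacy.load("en_core_web_sm", disable=['ner'])  # NER not needed for lemmas

    texts = ["Firms pledged to cut emissions.",          # illustrative inputs
             "The board discussed governance reforms."]

    lemmatised = []
    for doc in nlp.pipe(texts):                          # one pipeline pass per batch
        lemmatised.append(" ".join(
            tok.lemma_ for tok in doc
            if not (tok.is_punct or tok.is_stop or tok.is_space)))

    print(lemmatised)  # e.g. ['firm pledge cut emission', 'board discuss governance reform']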
@@ -103,7 +97,7 @@ def is_in_archive(url):
 
 def _inference_classifier(text):
     tokenizer = AutoTokenizer.from_pretrained(MODEL_TRANSFORMER_BASED)
-    inputs = tokenizer(clean_text(text), return_tensors="np", padding="max_length", truncation=True)  #this assumes head-only!
+    inputs = tokenizer(_lematise_text(text), return_tensors="np", padding="max_length", truncation=True)  #this assumes head-only!
     ort_session = onnxruntime.InferenceSession(MODEL_ONNX_FNAME)
     onnx_model = onnx.load(MODEL_ONNX_FNAME)
     onnx.checker.check_model(onnx_model)
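
With a list of strings, the tokenizer returns batched NumPy arrays keyed by input name, which ONNX Runtime can consume directly. A hedged sketch of that handoff, reusing the `MODEL_TRANSFORMER_BASED` and `MODEL_ONNX_FNAME` constants defined at the top of app.py (the actual `ort_session.run` call is outside this hunk, so the feed below is an assumption):

    from transformers import AutoTokenizer
    import onnxruntime

    # Sketch: tokenize a batch to NumPy and hand it to ONNX Runtime.
    # The ONNX graph's input names must match the tokenizer's keys
    # (input_ids, attention_mask, ...) for this feed to work.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_TRANSFORMER_BASED)
    inputs = tokenizer(["first cleaned text", "second cleaned text"],
                       return_tensors="np", padding="max_length", truncation=True)
    # inputs["input_ids"].shape == (batch_size, tokenizer.model_max_length)

    ort_session = onnxruntime.InferenceSession(MODEL_ONNX_FNAME)
    ort_outs = ort_session.run(None, dict(inputs))  # logits for the whole batch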
@@ -113,20 +107,27 @@ def _inference_classifier(text):
 
     return sigmoid(ort_outs[0])[0]
 
-def inference(url, use_archive, limit_companies=10):
-    if use_archive:
-        archive = is_in_archive(url)
-        if archive['archived']:
-            url = archive['url']
-    #Extract the data from url
-    extracted = Extractor().extract(requests.get(url).text)
-    prob_outs = _inference_classifier(extracted['content'])
+def inference(input_batch, isurl, use_archive, limit_companies=10):
+    input_batch_content = []
+    if isurl:
+        for url in input_batch:
+            if use_archive:
+                archive = is_in_archive(url)
+                if archive['archived']:
+                    url = archive['url']
+            #Extract the data from url
+            extracted = Extractor().extract(requests.get(url).text)
+            input_batch_content.append(extracted['content'])
+    else:
+        input_batch_content = input_batch
+
+    prob_outs = _inference_classifier(input_batch_content)
     #sentiment = _inference_sentiment_model_via_api_query({"inputs": extracted['content']})
-    sentiment = _inference_sentiment_model_pipeline(extracted['content'])[0]
-    summary = _inference_summary_model_pipeline(extracted['content'])[0]['generated_text']
-    ner_labels = _inference_ner_spancat(extracted['content'], summary, penalty=0.8, limit_outputs=limit_companies)
+    #sentiment = _inference_sentiment_model_pipeline(input_batch_content)[0]
+    #summary = _inference_summary_model_pipeline(input_batch_content)[0]['generated_text']
+    #ner_labels = _inference_ner_spancat(input_batch_content, summary, penalty=0.8, limit_outputs=limit_companies)
 
-    return ner_labels, {'E': float(prob_outs[0]), "S": float(prob_outs[1]), "G": float(prob_outs[2])}, {sentiment['label']: float(sentiment['score'])}, "**Summary:**\n\n" + summary
+    return prob_outs  #ner_labels, {'E':float(prob_outs[0]),"S":float(prob_outs[1]),"G":float(prob_outs[2])},{sentiment['label']:float(sentiment['score'])},"**Summary:**\n\n" + summary
 
 title = "ESG API Demo"
 description = """This is a demonstration of the full ESG pipeline backend where given a URL (english, news) the news contents are extracted, using extractnet, and fed to three models:
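
Since `inference` now accepts a batch plus an `isurl` flag rather than a single URL, it can be exercised directly with raw text, bypassing both the URL extraction branch and the Gradio UI. A hypothetical smoke test (sample texts invented):

    # Hypothetical direct call, not part of the commit: isurl=0 selects the
    # text branch, so input_batch is used as-is (a list of raw strings).
    texts = ["Company A pledged to halve emissions by 2030.",
             "A regulator fined Company B over missing board disclosures."]
    prob_outs = inference(texts, isurl=0, use_archive=False, limit_companies=5)
    # prob_outs holds the sigmoid scores from _inference_classifier; the
    # commented-out return above shows the intended E/S/G labelling.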
@@ -141,14 +142,25 @@ API input parameters:
 - `limit_companies`: integer. Number of found relevant companies to report.
 
 """
-examples = [['https://www.bbc.com/news/uk-62732447',False,5],
-            ['https://www.bbc.com/news/business-62747401',False,5],
-            ['https://www.bbc.com/news/technology-62744858',False,5],
-            ['https://www.bbc.com/news/science-environment-62758811',False,5],
-            ['https://www.theguardian.com/business/2022/sep/02/nord-stream-1-gazprom-announces-indefinite-shutdown-of-pipeline',False,5],
-            ['https://www.bbc.com/news/world-europe-62766867',False,5],
-            ['https://www.bbc.com/news/business-62524031',False,5],
-            ['https://www.bbc.com/news/business-62728621',False,5],
-            ['https://www.bbc.com/news/science-environment-62680423',False,5]]
-demo = gr.Interface(fn=inference, inputs=[gr.Textbox(label='URL'), gr.Checkbox(label='grab cached from archive.org'), gr.Slider(minimum=1, maximum=10, step=1, label='Limit NER output')], outputs=[gr.Label(label='Company'), gr.Label(label='ESG'), gr.Label(label='Sentiment'), gr.Markdown()], title=title, description=description, examples=examples)
+#examples = [['https://www.bbc.com/news/uk-62732447',False,5],
+#            ['https://www.bbc.com/news/business-62747401',False,5],
+#            ['https://www.bbc.com/news/technology-62744858',False,5],
+#            ['https://www.bbc.com/news/science-environment-62758811',False,5],
+#            ['https://www.theguardian.com/business/2022/sep/02/nord-stream-1-gazprom-announces-indefinite-shutdown-of-pipeline',False,5],
+#            ['https://www.bbc.com/news/world-europe-62766867',False,5],
+#            ['https://www.bbc.com/news/business-62524031',False,5],
+#            ['https://www.bbc.com/news/business-62728621',False,5],
+#            ['https://www.bbc.com/news/science-environment-62680423',False,5]]
+demo = gr.Interface(fn=inference,
+                    inputs=[gr.Dataframe(label='input batch', col_count=1, datatype='str', type='array', wrap=True),
+                            gr.Dropdown(label='data type', choices=['text','url'], type='index'),
+                            gr.Checkbox(label='if url parse cached in archive.org'),
+                            gr.Slider(minimum=1, maximum=10, step=1, label='Limit NER output')],
+                    outputs=[gr.Dataframe(label='output raw', col_count=1, datatype='number', type='array', wrap=True)],
+                            #gr.Label(label='Company'),
+                            #gr.Label(label='ESG'),
+                            #gr.Label(label='Sentiment'),
+                            #gr.Markdown()],
+                    title=title,
+                    description=description)  #, examples=examples)
 demo.launch()
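
Two wiring details in the interface above are easy to miss. `gr.Dropdown(choices=['text','url'], type='index')` passes the index of the selection to `inference`, so 'text' maps to 0 (falsy) and 'url' to 1 (truthy), which is exactly what `if isurl:` relies on. And depending on the Gradio version, `gr.Dataframe(type='array')` may deliver the sheet as a list of rows, so a single-column batch can arrive as `[[text1], [text2], ...]`; a defensive flatten before use would look like:

    # Hypothetical guard, not in the commit: unwrap one-column Dataframe rows
    # in case the component passes [[text1], [text2], ...] instead of a flat list.
    input_batch = [row[0] if isinstance(row, list) else row for row in input_batch]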
 