wissamantoun committed
Commit
0558cbb
1 Parent(s): 854b7af

added Sentiment Analysis

Files changed (6):
  1. app.py +2 -0
  2. backend/sa.py +19 -0
  3. backend/sa_utils.py +510 -0
  4. backend/services.py +177 -0
  5. backend/utils.py +10 -0
  6. requirements.txt +3 -1
app.py CHANGED
@@ -4,6 +4,7 @@ import streamlit as st
  import backend.aragpt
  import backend.home
  import backend.processor
+ import backend.sa
  from backend.utils import get_current_ram_usage

  st.set_page_config(
@@ -14,6 +15,7 @@ PAGES = {
      "Home": backend.home,
      "Arabic Text Preprocessor": backend.processor,
      "Arabic Language Generation": backend.aragpt,
+     "Arabic Sentiment Analysis": backend.sa,
  }


backend/sa.py ADDED
@@ -0,0 +1,19 @@
+ import streamlit as st
+ from .services import SentimentAnalyzer
+ from functools import lru_cache
+
+ # @st.cache(allow_output_mutation=False, hash_funcs={Tokenizer: str})
+ @lru_cache(maxsize=1)
+ def load_sentiment_analyzer():
+     predictor = SentimentAnalyzer()
+     return predictor
+
+
+ predictor = load_sentiment_analyzer()
+
+
+ def write():
+     input_text = st.text_input("Enter your text here:", key="sa_input")
+     if st.button("Predict"):
+         with st.spinner("Predicting..."):
+             prediction, score, all_score = predictor.predict([input_text])
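
Note: as committed, `write()` computes `prediction, score, all_score` but never renders them. A minimal sketch of displaying the results after the call (the `st.write` formatting below is an assumption, not part of this commit):

    # Hypothetical display code -- assumes predict() returns three parallel
    # lists with one entry per input text (label, top score, per-class scores).
    st.write(f"Prediction: {prediction[0]}")
    st.write(f"Score: {score[0]:.3f}")
    pos, neu, neg = all_score[0]
    st.write(f"Positive: {pos:.3f} | Neutral: {neu:.3f} | Negative: {neg:.3f}")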
backend/sa_utils.py ADDED
@@ -0,0 +1,510 @@
+ import re
+ from contextlib import contextmanager
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from fuzzysearch import find_near_matches
+ from pyarabic import araby
+ from torch import nn
+ from transformers import AutoTokenizer, BertModel, BertPreTrainedModel, pipeline
+ from transformers.modeling_outputs import SequenceClassifierOutput
+
+ from .preprocess import ArabertPreprocessor, url_regexes, user_mention_regex
+
+ multiple_char_pattern = re.compile(r"(.)\1{2,}", re.DOTALL)
+
+ # ASAD-NEW_AraBERT_PREP-Balanced
+ class NewArabicPreprocessorBalanced(ArabertPreprocessor):
+     def __init__(
+         self,
+         model_name: str,
+         keep_emojis: bool = False,
+         remove_html_markup: bool = True,
+         replace_urls_emails_mentions: bool = True,
+         strip_tashkeel: bool = True,
+         strip_tatweel: bool = True,
+         insert_white_spaces: bool = True,
+         remove_non_digit_repetition: bool = True,
+         replace_slash_with_dash: bool = None,
+         map_hindi_numbers_to_arabic: bool = None,
+         apply_farasa_segmentation: bool = None,
+     ):
+         if "UBC-NLP" in model_name or "CAMeL-Lab" in model_name:
+             keep_emojis = True
+             remove_non_digit_repetition = True
+         super().__init__(
+             model_name=model_name,
+             keep_emojis=keep_emojis,
+             remove_html_markup=remove_html_markup,
+             replace_urls_emails_mentions=replace_urls_emails_mentions,
+             strip_tashkeel=strip_tashkeel,
+             strip_tatweel=strip_tatweel,
+             insert_white_spaces=insert_white_spaces,
+             remove_non_digit_repetition=remove_non_digit_repetition,
+             replace_slash_with_dash=replace_slash_with_dash,
+             map_hindi_numbers_to_arabic=map_hindi_numbers_to_arabic,
+             apply_farasa_segmentation=apply_farasa_segmentation,
+         )
+         self.true_model_name = model_name
+
+     def preprocess(self, text):
+         if "UBC-NLP" in self.true_model_name:
+             return self.ubc_prep(text)
+
+     def ubc_prep(self, text):
+         text = re.sub("\s", " ", text)
+         text = text.replace("\\n", " ")
+         text = text.replace("\\r", " ")
+         text = araby.strip_tashkeel(text)
+         text = araby.strip_tatweel(text)
+         # replace all possible URLs
+         for reg in url_regexes:
+             text = re.sub(reg, " URL ", text)
+         text = re.sub("(URL\s*)+", " URL ", text)
+         # replace mentions with USER
+         text = re.sub(user_mention_regex, " USER ", text)
+         text = re.sub("(USER\s*)+", " USER ", text)
+         # replace hashtags with HASHTAG
+         # text = re.sub(r"#[\w\d]+", " HASH TAG ", text)
+         text = text.replace("#", " HASH ")
+         text = text.replace("_", " ")
+         text = " ".join(text.split())
+         # text = re.sub("\B\\[Uu]\w+", "", text)
+         text = text.replace("\\U0001f97a", "🥺")
+         text = text.replace("\\U0001f928", "🤨")
+         text = text.replace("\\U0001f9d8", "😀")
+         text = text.replace("\\U0001f975", "😥")
+         text = text.replace("\\U0001f92f", "😲")
+         text = text.replace("\\U0001f92d", "🤭")
+         text = text.replace("\\U0001f9d1", "😐")
+         text = text.replace("\\U000e0067", "")
+         text = text.replace("\\U000e006e", "")
+         text = text.replace("\\U0001f90d", "♥")
+         text = text.replace("\\U0001f973", "🎉")
+         text = text.replace("\\U0001fa79", "")
+         text = text.replace("\\U0001f92b", "🤐")
+         text = text.replace("\\U0001f9da", "🦋")
+         text = text.replace("\\U0001f90e", "♥")
+         text = text.replace("\\U0001f9d0", "🧐")
+         text = text.replace("\\U0001f9cf", "")
+         text = text.replace("\\U0001f92c", "😠")
+         text = text.replace("\\U0001f9f8", "😸")
+         text = text.replace("\\U0001f9b6", "💩")
+         text = text.replace("\\U0001f932", "🤲")
+         text = text.replace("\\U0001f9e1", "🧡")
+         text = text.replace("\\U0001f974", "☹")
+         text = text.replace("\\U0001f91f", "")
+         text = text.replace("\\U0001f9fb", "💩")
+         text = text.replace("\\U0001f92a", "🤪")
+         text = text.replace("\\U0001f9fc", "")
+         text = text.replace("\\U000e0065", "")
+         text = text.replace("\\U0001f92e", "💩")
+         text = text.replace("\\U000e007f", "")
+         text = text.replace("\\U0001f970", "🥰")
+         text = text.replace("\\U0001f929", "🤩")
+         text = text.replace("\\U0001f6f9", "")
+         text = text.replace("🤍", "♥")
+         text = text.replace("🦠", "😷")
+         text = text.replace("🤢", "مقرف")
+         text = text.replace("🤮", "مقرف")
+         text = text.replace("🕠", "⌚")
+         text = text.replace("🤬", "😠")
+         text = text.replace("🤧", "😷")
+         text = text.replace("🥳", "🎉")
+         text = text.replace("🥵", "🔥")
+         text = text.replace("🥴", "☹")
+         text = text.replace("🤫", "🤐")
+         text = text.replace("🤥", "كذاب")
+         text = text.replace("\\u200d", " ")
+         text = text.replace("u200d", " ")
+         text = text.replace("\\u200c", " ")
+         text = text.replace("u200c", " ")
+         text = text.replace('"', "'")
+         text = text.replace("\\xa0", "")
+         text = text.replace("\\u2066", " ")
+         text = re.sub("\B\\\[Uu]\w+", "", text)
+         text = super(NewArabicPreprocessorBalanced, self).preprocess(text)
+
+         text = " ".join(text.split())
+         return text
+
+
+ """CNNMarbertArabicPreprocessor"""
+ # ASAD-CNN_MARBERT
+ class CNNMarbertArabicPreprocessor(ArabertPreprocessor):
+     def __init__(
+         self,
+         model_name,
+         keep_emojis=False,
+         remove_html_markup=True,
+         replace_urls_emails_mentions=True,
+         remove_elongations=True,
+     ):
+         if "UBC-NLP" in model_name or "CAMeL-Lab" in model_name:
+             keep_emojis = True
+             remove_elongations = False
+         super().__init__(
+             model_name,
+             keep_emojis,
+             remove_html_markup,
+             replace_urls_emails_mentions,
+             remove_elongations,
+         )
+         self.true_model_name = model_name
+
+     def preprocess(self, text):
+         if "UBC-NLP" in self.true_model_name:
+             return self.ubc_prep(text)
+
+     def ubc_prep(self, text):
+         text = re.sub("\s", " ", text)
+         text = text.replace("\\n", " ")
+         text = araby.strip_tashkeel(text)
+         text = araby.strip_tatweel(text)
+         # replace all possible URLs
+         for reg in url_regexes:
+             text = re.sub(reg, " URL ", text)
+         text = re.sub("(URL\s*)+", " URL ", text)
+         # replace mentions with USER
+         text = re.sub(user_mention_regex, " USER ", text)
+         text = re.sub("(USER\s*)+", " USER ", text)
+         # replace hashtags with HASHTAG
+         # text = re.sub(r"#[\w\d]+", " HASH TAG ", text)
+         text = text.replace("#", " HASH ")
+         text = text.replace("_", " ")
+         text = " ".join(text.split())
+         text = super(CNNMarbertArabicPreprocessor, self).preprocess(text)
+         text = text.replace("\u200d", " ")
+         text = text.replace("u200d", " ")
+         text = text.replace("\u200c", " ")
+         text = text.replace("u200c", " ")
+         text = text.replace('"', "'")
+         # text = re.sub('[\d\.]+', ' NUM ', text)
+         # text = re.sub('(NUM\s*)+', ' NUM ', text)
+         text = multiple_char_pattern.sub(r"\1\1", text)
+         text = " ".join(text.split())
+         return text
+
+
+ """Trial5ArabicPreprocessor"""
+
+
+ class Trial5ArabicPreprocessor(ArabertPreprocessor):
+     def __init__(
+         self,
+         model_name,
+         keep_emojis=False,
+         remove_html_markup=True,
+         replace_urls_emails_mentions=True,
+     ):
+         if "UBC-NLP" in model_name:
+             keep_emojis = True
+         super().__init__(
+             model_name, keep_emojis, remove_html_markup, replace_urls_emails_mentions
+         )
+         self.true_model_name = model_name
+
+     def preprocess(self, text):
+         if "UBC-NLP" in self.true_model_name:
+             return self.ubc_prep(text)
+
+     def ubc_prep(self, text):
+         text = re.sub("\s", " ", text)
+         text = text.replace("\\n", " ")
+         text = araby.strip_tashkeel(text)
+         text = araby.strip_tatweel(text)
+         # replace all possible URLs
+         for reg in url_regexes:
+             text = re.sub(reg, " URL ", text)
+         # replace mentions with USER
+         text = re.sub(user_mention_regex, " USER ", text)
+         # replace hashtags with HASHTAG
+         # text = re.sub(r"#[\w\d]+", " HASH TAG ", text)
+         text = text.replace("#", " HASH TAG ")
+         text = text.replace("_", " ")
+         text = " ".join(text.split())
+         text = super(Trial5ArabicPreprocessor, self).preprocess(text)
+         # text = text.replace("السلام عليكم"," ")
+         # text = text.replace(find_near_matches("السلام عليكم",text,max_deletions=3,max_l_dist=3)[0].matched," ")
+         return text
+
+
+ """SarcasmArabicPreprocessor"""
+
+
+ class SarcasmArabicPreprocessor(ArabertPreprocessor):
+     def __init__(
+         self,
+         model_name,
+         keep_emojis=False,
+         remove_html_markup=True,
+         replace_urls_emails_mentions=True,
+     ):
+         if "UBC-NLP" in model_name:
+             keep_emojis = True
+         super().__init__(
+             model_name, keep_emojis, remove_html_markup, replace_urls_emails_mentions
+         )
+         self.true_model_name = model_name
+
+     def preprocess(self, text):
+         if "UBC-NLP" in self.true_model_name:
+             return self.ubc_prep(text)
+         else:
+             return super(SarcasmArabicPreprocessor, self).preprocess(text)
+
+     def ubc_prep(self, text):
+         text = re.sub("\s", " ", text)
+         text = text.replace("\\n", " ")
+         text = araby.strip_tashkeel(text)
+         text = araby.strip_tatweel(text)
+         # replace all possible URLs
+         for reg in url_regexes:
+             text = re.sub(reg, " URL ", text)
+         # replace mentions with USER
+         text = re.sub(user_mention_regex, " USER ", text)
+         # replace hashtags with HASHTAG
+         # text = re.sub(r"#[\w\d]+", " HASH TAG ", text)
+         text = text.replace("#", " HASH TAG ")
+         text = text.replace("_", " ")
+         text = text.replace('"', " ")
+         text = " ".join(text.split())
+         text = super(SarcasmArabicPreprocessor, self).preprocess(text)
+         return text
+
+
+ """NoAOAArabicPreprocessor"""
+
+
+ class NoAOAArabicPreprocessor(ArabertPreprocessor):
+     def __init__(
+         self,
+         model_name,
+         keep_emojis=False,
+         remove_html_markup=True,
+         replace_urls_emails_mentions=True,
+     ):
+         if "UBC-NLP" in model_name:
+             keep_emojis = True
+         super().__init__(
+             model_name, keep_emojis, remove_html_markup, replace_urls_emails_mentions
+         )
+         self.true_model_name = model_name
+
+     def preprocess(self, text):
+         if "UBC-NLP" in self.true_model_name:
+             return self.ubc_prep(text)
+         else:
+             return super(NoAOAArabicPreprocessor, self).preprocess(text)
+
+     def ubc_prep(self, text):
+         text = re.sub("\s", " ", text)
+         text = text.replace("\\n", " ")
+         text = araby.strip_tashkeel(text)
+         text = araby.strip_tatweel(text)
+         # replace all possible URLs
+         for reg in url_regexes:
+             text = re.sub(reg, " URL ", text)
+         # replace mentions with USER
+         text = re.sub(user_mention_regex, " USER ", text)
+         # replace hashtags with HASHTAG
+         # text = re.sub(r"#[\w\d]+", " HASH TAG ", text)
+         text = text.replace("#", " HASH TAG ")
+         text = text.replace("_", " ")
+         text = " ".join(text.split())
+         text = super(NoAOAArabicPreprocessor, self).preprocess(text)
+         text = text.replace("السلام عليكم", " ")
+         text = text.replace("ورحمة الله وبركاته", " ")
+         matched = find_near_matches("السلام عليكم", text, max_deletions=3, max_l_dist=3)
+         if len(matched) > 0:
+             text = text.replace(matched[0].matched, " ")
+         matched = find_near_matches(
+             "ورحمة الله وبركاته", text, max_deletions=3, max_l_dist=3
+         )
+         if len(matched) > 0:
+             text = text.replace(matched[0].matched, " ")
+         return text
+
+
+ class CnnBertForSequenceClassification(BertPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+         self.num_labels = config.num_labels
+         self.config = config
+
+         self.bert = BertModel(config)
+
+         filter_sizes = [1, 2, 3, 4, 5]
+         num_filters = 32
+         self.convs1 = nn.ModuleList(
+             [nn.Conv2d(4, num_filters, (K, config.hidden_size)) for K in filter_sizes]
+         )
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+         self.classifier = nn.Linear(len(filter_sizes) * num_filters, config.num_labels)
+
+         self.init_weights()
+
+     def forward(
+         self,
+         input_ids=None,
+         attention_mask=None,
+         token_type_ids=None,
+         position_ids=None,
+         head_mask=None,
+         inputs_embeds=None,
+         labels=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+     ):
+
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         outputs = self.bert(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         x = outputs[2][-4:]
+
+         x = torch.stack(x, dim=1)
+         x = [F.relu(conv(x)).squeeze(3) for conv in self.convs1]
+         x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x]
+         x = torch.cat(x, 1)
+         x = self.dropout(x)
+         logits = self.classifier(x)
+
+         loss = None
+         if labels is not None:
+             if self.config.problem_type is None:
+                 if self.num_labels == 1:
+                     self.config.problem_type = "regression"
+                 elif self.num_labels > 1 and (
+                     labels.dtype == torch.long or labels.dtype == torch.int
+                 ):
+                     self.config.problem_type = "single_label_classification"
+                 else:
+                     self.config.problem_type = "multi_label_classification"
+
+             if self.config.problem_type == "regression":
+                 loss_fct = nn.MSELoss()
+                 if self.num_labels == 1:
+                     loss = loss_fct(logits.squeeze(), labels.squeeze())
+                 else:
+                     loss = loss_fct(logits, labels)
+             elif self.config.problem_type == "single_label_classification":
+                 loss_fct = nn.CrossEntropyLoss()
+                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+             elif self.config.problem_type == "multi_label_classification":
+                 loss_fct = nn.BCEWithLogitsLoss()
+                 loss = loss_fct(logits, labels)
+         if not return_dict:
+             output = (logits,) + outputs[2:]
+             return ((loss,) + output) if loss is not None else output
+
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=None,
+             attentions=outputs.attentions,
+         )
+
+
+ class CNNTextClassificationPipeline:
+     def __init__(self, model_path, device, return_all_scores=False):
+         self.model_path = model_path
+         self.model = CnnBertForSequenceClassification.from_pretrained(self.model_path)
+         # Special handling
+         self.device = torch.device("cpu" if device < 0 else f"cuda:{device}")
+         if self.device.type == "cuda":
+             self.model = self.model.to(self.device)
+         self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+         self.return_all_scores = return_all_scores
+
+     @contextmanager
+     def device_placement(self):
+         """
+         Context Manager allowing tensor allocation on the user-specified device in framework agnostic way.
+         Returns:
+             Context manager
+         Examples::
+             # Explicitly ask for tensor allocation on CUDA device :0
+             pipe = pipeline(..., device=0)
+             with pipe.device_placement():
+                 # Every framework specific tensor allocation will be done on the request device
+                 output = pipe(...)
+         """
+
+         if self.device.type == "cuda":
+             torch.cuda.set_device(self.device)
+
+         yield
+
+     def ensure_tensor_on_device(self, **inputs):
+         """
+         Ensure PyTorch tensors are on the specified device.
+         Args:
+             inputs (keyword arguments that should be :obj:`torch.Tensor`): The tensors to place on :obj:`self.device`.
+         Return:
+             :obj:`Dict[str, torch.Tensor]`: The same as :obj:`inputs` but on the proper device.
+         """
+         return {
+             name: tensor.to(self.device) if isinstance(tensor, torch.Tensor) else tensor
+             for name, tensor in inputs.items()
+         }
+
+     def __call__(self, text):
+         """
+         Classify the text(s) given as inputs.
+         Args:
+             args (:obj:`str` or :obj:`List[str]`):
+                 One or several texts (or one list of prompts) to classify.
+         Return:
+             A list or a list of list of :obj:`dict`: Each result comes as list of dictionaries with the following keys:
+             - **label** (:obj:`str`) -- The label predicted.
+             - **score** (:obj:`float`) -- The corresponding probability.
+             If ``self.return_all_scores=True``, one such dictionary is returned per label.
+         """
+         # outputs = super().__call__(*args, **kwargs)
+         inputs = self.tokenizer.batch_encode_plus(
+             text,
+             add_special_tokens=True,
+             max_length=64,
+             padding=True,
+             truncation="longest_first",
+             return_tensors="pt",
+         )
+
+         with torch.no_grad():
+             inputs = self.ensure_tensor_on_device(**inputs)
+             predictions = self.model(**inputs)[0].cpu()
+
+         predictions = predictions.numpy()
+
+         if self.model.config.num_labels == 1:
+             scores = 1.0 / (1.0 + np.exp(-predictions))
+         else:
+             scores = np.exp(predictions) / np.exp(predictions).sum(-1, keepdims=True)
+         if self.return_all_scores:
+             return [
+                 [
+                     {"label": self.model.config.id2label[i], "score": score.item()}
+                     for i, score in enumerate(item)
+                 ]
+                 for item in scores
+             ]
+         else:
+             return [
+                 {"label": self.model.config.id2label[item.argmax()], "score": item.max().item()}
+                 for item in scores
+             ]
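
Note: `CnnBertForSequenceClassification.forward` reads `outputs[2][-4:]`, i.e. the last four hidden-state tensors, so the underlying BERT must run with `output_hidden_states=True` (presumably baked into the fine-tuned checkpoints' config). A minimal loading sketch under that assumption (the checkpoint path is a placeholder):

    from transformers import AutoConfig

    ckpt = "sa_cnnbert/train_0/best_model"  # placeholder fold path
    config = AutoConfig.from_pretrained(ckpt, output_hidden_states=True)
    model = CnnBertForSequenceClassification.from_pretrained(ckpt, config=config)
    model.eval()  # inference only, as in CNNTextClassificationPipeline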
backend/services.py CHANGED
@@ -1,9 +1,17 @@
  import json
  import os
+ from typing import List
+
+ import more_itertools
+ import pandas as pd
  import requests
+ from tqdm.auto import tqdm
  from transformers import GPT2LMHeadModel, GPT2Tokenizer, pipeline, set_seed
+
  from .modeling_gpt2 import GPT2LMHeadModel as GROVERLMHeadModel
  from .preprocess import ArabertPreprocessor
+ from .sa_utils import *
+ from .utils import download_models

  # Taken and Modified from https://huggingface.co/spaces/flax-community/chef-transformer/blob/main/app.py
  class TextGeneration:
@@ -170,3 +178,172 @@ class TextGeneration:
              },
          }
          return self.query(payload, model_name)
+
+
+ class SentimentAnalyzer:
+     def __init__(self):
+         self.sa_models = [
+             "sa_trial5_1",
+             "sa_no_aoa_in_neutral",
+             "sa_cnnbert",
+             "sa_sarcasm",
+             "sar_trial10",
+             "sa_no_AOA",
+         ]
+         self.model_repos = download_models(self.sa_models)
+         # fmt: off
+         self.processors = {
+             "sa_trial5_1": Trial5ArabicPreprocessor(model_name='UBC-NLP/MARBERT'),
+             "sa_no_aoa_in_neutral": NewArabicPreprocessorBalanced(model_name='UBC-NLP/MARBERT'),
+             "sa_cnnbert": CNNMarbertArabicPreprocessor(model_name='UBC-NLP/MARBERT'),
+             "sa_sarcasm": SarcasmArabicPreprocessor(model_name='UBC-NLP/MARBERT'),
+             "sar_trial10": SarcasmArabicPreprocessor(model_name='UBC-NLP/MARBERT'),
+             "sa_no_AOA": NewArabicPreprocessorBalanced(model_name='UBC-NLP/MARBERT'),
+         }
+
+         self.pipelines = {
+             "sa_trial5_1": [pipeline("sentiment-analysis", model="{}/train_{}/best_model".format(self.model_repos["sa_trial5_1"], i), device=-1, return_all_scores=True) for i in range(0, 5)],
+             "sa_no_aoa_in_neutral": [pipeline("sentiment-analysis", model="{}/train_{}/best_model".format(self.model_repos["sa_no_aoa_in_neutral"], i), device=-1, return_all_scores=True) for i in range(0, 5)],
+             "sa_cnnbert": [CNNTextClassificationPipeline("{}/train_{}/best_model".format(self.model_repos["sa_cnnbert"], i), device=-1, return_all_scores=True) for i in range(0, 5)],
+             "sa_sarcasm": [pipeline("sentiment-analysis", model="{}/train_{}/best_model".format(self.model_repos["sa_sarcasm"], i), device=-1, return_all_scores=True) for i in range(0, 5)],
+             "sar_trial10": [pipeline("sentiment-analysis", model="{}/train_{}/best_model".format(self.model_repos["sar_trial10"], i), device=-1, return_all_scores=True) for i in range(0, 5)],
+             "sa_no_AOA": [pipeline("sentiment-analysis", model="{}/train_{}/best_model".format(self.model_repos["sa_no_aoa_in_neutral"], i), device=-1, return_all_scores=True) for i in range(0, 5)],
+         }
+         # fmt: on
+
+     def get_sarcasm_label(self, texts):
+         prep = self.processors["sar_trial10"]
+         prep_texts = [prep.preprocess(x) for x in texts]
+
+         preds_df = pd.DataFrame([])
+         for i in range(0, 5):
+             preds = []
+             for s in tqdm(more_itertools.chunked(list(prep_texts), 128)):
+                 preds.extend(self.pipelines["sar_trial10"][i](s))
+             preds_df[f"model_{i}"] = preds
+
+         final_labels = []
+         final_scores = []
+         for id, row in preds_df.iterrows():
+             pos_total = 0
+             neu_total = 0
+             for pred in row[:]:
+                 pos_total += pred[0]["score"]
+                 neu_total += pred[1]["score"]
+
+             pos_avg = pos_total / len(row[:])
+             neu_avg = neu_total / len(row[:])
+
+             final_labels.append(
+                 self.pipelines["sar_trial10"][0].model.config.id2label[
+                     np.argmax([pos_avg, neu_avg])
+                 ]
+             )
+             final_scores.append(np.max([pos_avg, neu_avg]))
+
+         return final_labels, final_scores
+
+     def get_preds_from_a_model(self, texts: List[str], model_name):
+         prep = self.processors[model_name]
+
+         prep_texts = [prep.preprocess(x) for x in texts]
+         if model_name == "sa_sarcasm":
+             sarcasm_label, _ = self.get_sarcasm_label(texts)
+             sarcastic_map = {"Not_Sarcastic": "غير ساخر", "Sarcastic": "ساخر"}
+             prep_texts = [
+                 sarcastic_map[l] + " [SEP] " + t for t, l in zip(prep_texts, sarcasm_label)
+             ]
+
+         preds_df = pd.DataFrame([])
+         for i in range(0, 5):
+             preds = []
+             for s in tqdm(more_itertools.chunked(list(prep_texts), 128)):
+                 preds.extend(self.pipelines[model_name][i](s))
+             preds_df[f"model_{i}"] = preds
+
+         final_labels = []
+         final_scores = []
+         final_scores_list = []
+         for id, row in preds_df.iterrows():
+             pos_total = 0
+             neg_total = 0
+             neu_total = 0
+             for pred in row[:]:
+                 pos_total += pred[0]["score"]
+                 neu_total += pred[1]["score"]
+                 neg_total += pred[2]["score"]
+
+             pos_avg = pos_total / 5
+             neu_avg = neu_total / 5
+             neg_avg = neg_total / 5
+
+             if model_name == "sa_no_aoa_in_neutral":
+                 final_labels.append(
+                     self.pipelines[model_name][0].model.config.id2label[
+                         np.argmax([neu_avg, neg_avg, pos_avg])
+                     ]
+                 )
+             else:
+                 final_labels.append(
+                     self.pipelines[model_name][0].model.config.id2label[
+                         np.argmax([pos_avg, neu_avg, neg_avg])
+                     ]
+                 )
+             final_scores.append(np.max([pos_avg, neu_avg, neg_avg]))
+             final_scores_list.append((pos_avg, neu_avg, neg_avg))
+
+         return final_labels, final_scores, final_scores_list
+
+     def predict(self, texts: List[str]):
+         (
+             new_balanced_label,
+             new_balanced_score,
+             new_balanced_score_list,
+         ) = self.get_preds_from_a_model(texts, "sa_no_aoa_in_neutral")
+         (
+             cnn_marbert_label,
+             cnn_marbert_score,
+             cnn_marbert_score_list,
+         ) = self.get_preds_from_a_model(texts, "sa_cnnbert")
+         trial5_label, trial5_score, trial5_score_list = self.get_preds_from_a_model(
+             texts, "sa_trial5_1"
+         )
+         no_aoa_label, no_aoa_score, no_aoa_score_list = self.get_preds_from_a_model(
+             texts, "sa_no_AOA"
+         )
+         sarcasm_label, sarcasm_score, sarcasm_score_list = self.get_preds_from_a_model(
+             texts, "sa_sarcasm"
+         )
+
+         id_label_map = {0: "Positive", 1: "Neutral", 2: "Negative"}
+
+         final_ensemble_prediction = []
+         final_ensemble_score = []
+         final_ensemble_all_score = []
+         for entry in zip(
+             new_balanced_score_list,
+             cnn_marbert_score_list,
+             trial5_score_list,
+             no_aoa_score_list,
+             sarcasm_score_list,
+         ):
+             pos_score = 0
+             neu_score = 0
+             neg_score = 0
+             for s in entry:
+                 pos_score += s[0] * 1.57
+                 neu_score += s[1] * 0.98
+                 neg_score += s[2] * 0.93
+
+                 # weighted 2
+                 # pos_score += s[0]*1.67
+                 # neu_score += s[1]
+                 # neg_score += s[2]*0.95
+
+             final_ensemble_prediction.append(
+                 id_label_map[np.argmax([pos_score, neu_score, neg_score])]
+             )
+             final_ensemble_score.append(np.max([pos_score, neu_score, neg_score]))
+             final_ensemble_all_score.append((pos_score, neu_score, neg_score))
+
+         return final_ensemble_prediction, final_ensemble_score, final_ensemble_all_score
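
For reference, `predict` averages five cross-validation folds per model, then sums the five per-model score triples with fixed class weights (1.57 / 0.98 / 0.93) before taking the argmax. A minimal usage sketch (the Arabic sample is arbitrary; constructing `SentimentAnalyzer` clones all six model repos, so the first run is slow):

    from backend.services import SentimentAnalyzer

    analyzer = SentimentAnalyzer()  # downloads the sa_* repos on first use
    labels, scores, all_scores = analyzer.predict(["خدمة ممتازة وسريعة"])
    print(labels[0], scores[0])  # e.g. "Positive" plus its weighted score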
backend/utils.py CHANGED
@@ -1,6 +1,16 @@
  import psutil
+ from huggingface_hub import Repository


  def get_current_ram_usage():
      ram = psutil.virtual_memory()
      return ram.available / 1024 / 1024 / 1024, ram.total / 1024 / 1024 / 1024
+
+
+ def download_models(models):
+     model_dirs = {}
+     for model in models:
+         model_dirs[model] = Repository(
+             model, clone_from=f"https://huggingface.co/researchaccount/{model}"
+         )
+     return model_dirs
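
Note: `download_models` returns `huggingface_hub.Repository` objects, while `SentimentAnalyzer` interpolates them directly into checkpoint paths ("{}/train_{}/best_model"). Since `Repository(model, clone_from=...)` clones into a local folder named after the model, returning the clone path explicitly may be the safer contract; a sketch of that variant (assuming the `Repository.local_dir` attribute):

    def download_models(models):
        model_dirs = {}
        for model in models:
            repo = Repository(
                model, clone_from=f"https://huggingface.co/researchaccount/{model}"
            )
            model_dirs[model] = repo.local_dir  # plain path string, safe for str.format
        return model_dirs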
requirements.txt CHANGED
@@ -7,4 +7,6 @@ emoji==1.4.2
  awesome_streamlit
  torch==1.9.0
  transformers==4.10.0
- psutil==5.8.0
+ psutil==5.8.0
+ fuzzysearch==0.7.3
+ more-itertools==8.9.0