Spaces:
Running
Running
wissamantoun
committed on
Commit
•
c6d8cfb
1
Parent(s):
7903030
removed ensembling in SA (RAM issues)
Browse files: backend/services.py +35 -35
backend/services.py
CHANGED
@@ -203,30 +203,30 @@ class SentimentAnalyzer:
|
|
203 |
def __init__(self):
|
204 |
self.sa_models = [
|
205 |
"sa_trial5_1",
|
206 |
-
"sa_no_aoa_in_neutral",
|
207 |
-
"sa_cnnbert",
|
208 |
-
"sa_sarcasm",
|
209 |
-
"sar_trial10",
|
210 |
-
"sa_no_AOA",
|
211 |
]
|
212 |
download_models(self.sa_models)
|
213 |
# fmt: off
|
214 |
self.processors = {
|
215 |
"sa_trial5_1": Trial5ArabicPreprocessor(model_name='UBC-NLP/MARBERT'),
|
216 |
-
"sa_no_aoa_in_neutral": NewArabicPreprocessorBalanced(model_name='UBC-NLP/MARBERT'),
|
217 |
-
"sa_cnnbert": CNNMarbertArabicPreprocessor(model_name='UBC-NLP/MARBERT'),
|
218 |
-
"sa_sarcasm": SarcasmArabicPreprocessor(model_name='UBC-NLP/MARBERT'),
|
219 |
-
"sar_trial10": SarcasmArabicPreprocessor(model_name='UBC-NLP/MARBERT'),
|
220 |
-
"sa_no_AOA": NewArabicPreprocessorBalanced(model_name='UBC-NLP/MARBERT'),
|
221 |
}
|
222 |
|
223 |
self.pipelines = {
|
224 |
"sa_trial5_1": [pipeline("sentiment-analysis", model="{}/train_{}/best_model".format("sa_trial5_1",i), device=-1,return_all_scores =True) for i in tqdm(range(0,5), desc=f"Loading pipeline for model: sa_trial5_1")],
|
225 |
-
"sa_no_aoa_in_neutral": [pipeline("sentiment-analysis", model="{}/train_{}/best_model".format("sa_no_aoa_in_neutral",i), device=-1,return_all_scores =True) for i in tqdm(range(0,5), desc=f"Loading pipeline for model: sa_no_aoa_in_neutral")],
|
226 |
-
"sa_cnnbert": [CNNTextClassificationPipeline("{}/train_{}/best_model".format("sa_cnnbert",i), device=-1, return_all_scores =True) for i in tqdm(range(0,5), desc=f"Loading pipeline for model: sa_cnnbert")],
|
227 |
-
"sa_sarcasm": [pipeline("sentiment-analysis", model="{}/train_{}/best_model".format("sa_sarcasm",i), device=-1,return_all_scores =True) for i in tqdm(range(0,5), desc=f"Loading pipeline for model: sa_sarcasm")],
|
228 |
-
"sar_trial10": [pipeline("sentiment-analysis", model="{}/train_{}/best_model".format("sar_trial10",i), device=-1,return_all_scores =True) for i in tqdm(range(0,5), desc=f"Loading pipeline for model: sar_trial10")],
|
229 |
-
"sa_no_AOA": [pipeline("sentiment-analysis", model="{}/train_{}/best_model".format("sa_no_AOA",i), device=-1,return_all_scores =True) for i in tqdm(range(0,5), desc=f"Loading pipeline for model: sa_no_AOA")],
|
230 |
}
|
231 |
# fmt: on
|
232 |
|
@@ -324,25 +324,25 @@ class SentimentAnalyzer:
|
|
324 |
|
325 |
def predict(self, texts: List[str]):
|
326 |
logger.info(f"Predicting for: {texts}")
|
327 |
-
(
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
) = self.get_preds_from_a_model(texts, "sa_no_aoa_in_neutral")
|
332 |
-
(
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
) = self.get_preds_from_a_model(texts, "sa_cnnbert")
|
337 |
trial5_label, trial5_score, trial5_score_list = self.get_preds_from_a_model(
|
338 |
texts, "sa_trial5_1"
|
339 |
)
|
340 |
-
no_aoa_label, no_aoa_score, no_aoa_score_list = self.get_preds_from_a_model(
|
341 |
-
|
342 |
-
)
|
343 |
-
sarcasm_label, sarcasm_score, sarcasm_score_list = self.get_preds_from_a_model(
|
344 |
-
|
345 |
-
)
|
346 |
|
347 |
id_label_map = {0: "Positive", 1: "Neutral", 2: "Negative"}
|
348 |
|
@@ -350,11 +350,11 @@ class SentimentAnalyzer:
|
|
350 |
final_ensemble_score = []
|
351 |
final_ensemble_all_score = []
|
352 |
for entry in zip(
|
353 |
-
new_balanced_score_list,
|
354 |
-
cnn_marbert_score_list,
|
355 |
trial5_score_list,
|
356 |
-
no_aoa_score_list,
|
357 |
-
sarcasm_score_list,
|
358 |
):
|
359 |
pos_score = 0
|
360 |
neu_score = 0
|
|
|
203 |
def __init__(self):
|
204 |
self.sa_models = [
|
205 |
"sa_trial5_1",
|
206 |
+
# "sa_no_aoa_in_neutral",
|
207 |
+
# "sa_cnnbert",
|
208 |
+
# "sa_sarcasm",
|
209 |
+
# "sar_trial10",
|
210 |
+
# "sa_no_AOA",
|
211 |
]
|
212 |
download_models(self.sa_models)
|
213 |
# fmt: off
|
214 |
self.processors = {
|
215 |
"sa_trial5_1": Trial5ArabicPreprocessor(model_name='UBC-NLP/MARBERT'),
|
216 |
+
# "sa_no_aoa_in_neutral": NewArabicPreprocessorBalanced(model_name='UBC-NLP/MARBERT'),
|
217 |
+
# "sa_cnnbert": CNNMarbertArabicPreprocessor(model_name='UBC-NLP/MARBERT'),
|
218 |
+
# "sa_sarcasm": SarcasmArabicPreprocessor(model_name='UBC-NLP/MARBERT'),
|
219 |
+
# "sar_trial10": SarcasmArabicPreprocessor(model_name='UBC-NLP/MARBERT'),
|
220 |
+
# "sa_no_AOA": NewArabicPreprocessorBalanced(model_name='UBC-NLP/MARBERT'),
|
221 |
}
|
222 |
|
223 |
self.pipelines = {
|
224 |
"sa_trial5_1": [pipeline("sentiment-analysis", model="{}/train_{}/best_model".format("sa_trial5_1",i), device=-1,return_all_scores =True) for i in tqdm(range(0,5), desc=f"Loading pipeline for model: sa_trial5_1")],
|
225 |
+
# "sa_no_aoa_in_neutral": [pipeline("sentiment-analysis", model="{}/train_{}/best_model".format("sa_no_aoa_in_neutral",i), device=-1,return_all_scores =True) for i in tqdm(range(0,5), desc=f"Loading pipeline for model: sa_no_aoa_in_neutral")],
|
226 |
+
# "sa_cnnbert": [CNNTextClassificationPipeline("{}/train_{}/best_model".format("sa_cnnbert",i), device=-1, return_all_scores =True) for i in tqdm(range(0,5), desc=f"Loading pipeline for model: sa_cnnbert")],
|
227 |
+
# "sa_sarcasm": [pipeline("sentiment-analysis", model="{}/train_{}/best_model".format("sa_sarcasm",i), device=-1,return_all_scores =True) for i in tqdm(range(0,5), desc=f"Loading pipeline for model: sa_sarcasm")],
|
228 |
+
# "sar_trial10": [pipeline("sentiment-analysis", model="{}/train_{}/best_model".format("sar_trial10",i), device=-1,return_all_scores =True) for i in tqdm(range(0,5), desc=f"Loading pipeline for model: sar_trial10")],
|
229 |
+
# "sa_no_AOA": [pipeline("sentiment-analysis", model="{}/train_{}/best_model".format("sa_no_AOA",i), device=-1,return_all_scores =True) for i in tqdm(range(0,5), desc=f"Loading pipeline for model: sa_no_AOA")],
|
230 |
}
|
231 |
# fmt: on
|
232 |
|
|
|
324 |
|
325 |
def predict(self, texts: List[str]):
|
326 |
logger.info(f"Predicting for: {texts}")
|
327 |
+
# (
|
328 |
+
# new_balanced_label,
|
329 |
+
# new_balanced_score,
|
330 |
+
# new_balanced_score_list,
|
331 |
+
# ) = self.get_preds_from_a_model(texts, "sa_no_aoa_in_neutral")
|
332 |
+
# (
|
333 |
+
# cnn_marbert_label,
|
334 |
+
# cnn_marbert_score,
|
335 |
+
# cnn_marbert_score_list,
|
336 |
+
# ) = self.get_preds_from_a_model(texts, "sa_cnnbert")
|
337 |
trial5_label, trial5_score, trial5_score_list = self.get_preds_from_a_model(
|
338 |
texts, "sa_trial5_1"
|
339 |
)
|
340 |
+
# no_aoa_label, no_aoa_score, no_aoa_score_list = self.get_preds_from_a_model(
|
341 |
+
# texts, "sa_no_AOA"
|
342 |
+
# )
|
343 |
+
# sarcasm_label, sarcasm_score, sarcasm_score_list = self.get_preds_from_a_model(
|
344 |
+
# texts, "sa_sarcasm"
|
345 |
+
# )
|
346 |
|
347 |
id_label_map = {0: "Positive", 1: "Neutral", 2: "Negative"}
|
348 |
|
|
|
350 |
final_ensemble_score = []
|
351 |
final_ensemble_all_score = []
|
352 |
for entry in zip(
|
353 |
+
# new_balanced_score_list,
|
354 |
+
# cnn_marbert_score_list,
|
355 |
trial5_score_list,
|
356 |
+
# no_aoa_score_list,
|
357 |
+
# sarcasm_score_list,
|
358 |
):
|
359 |
pos_score = 0
|
360 |
neu_score = 0
|