Spaces:
Sleeping
Sleeping
from itertools import chain | |
from random import choice | |
from typing import Any, Dict, List, Optional, Tuple | |
from datasets import Dataset | |
def adjust_predictions(refs, preds, choices):
    """Pad each prediction list so it is at least as long as its reference.

    Args:
        refs: Sequence of reference label lists.
        preds: Sequence of predicted label lists, paired positionally with
            ``refs``. Pairs beyond the shorter of the two sequences are
            ignored (``zip`` semantics).
        choices: Non-empty sequence of filler values. Every missing slot is
            filled with an independent ``random.choice`` draw, so passing a
            single-element list yields a fixed filler token.

    Returns:
        A new list containing one new list per (ref, pred) pair. The input
        prediction lists are never modified. Predictions that are already as
        long as, or longer than, their reference are copied unchanged (they
        are not truncated).
    """
    adjusted_preds = []
    for ref, pred in zip(refs, preds):
        # Work on a copy: the prediction lists usually live inside the
        # caller's records, and padding them in place would silently alter
        # that data for any later use.
        padded = list(pred)
        # A non-positive count yields an empty range, i.e. no padding.
        missing_count = len(ref) - len(padded)
        padded.extend(choice(choices) for _ in range(missing_count))
        adjusted_preds.append(padded)
    return adjusted_preds
def extract_aspects(data, specific_key, specific_val):
    """Collect ``item[specific_key][specific_val]`` for every item in ``data``.

    Each item is indexed first by ``specific_key`` and the result is then
    indexed by ``specific_val``; the looked-up values are returned in input
    order. A missing key propagates as the usual lookup error.
    """
    selected = []
    for record in data:
        group = record[specific_key]
        selected.append(group[specific_val])
    return selected
def absa_term_preprocess(references, predictions, subtask_key, subtask_value):
    """
    Build flat, length-aligned label lists for one ABSA subtask.

    Args:
        references (List[Dict]): Ground-truth records. Each record is read as
            ``record[subtask_key][subtask_value]`` for the labels and
            ``record[subtask_key]["polarity"]`` for the sentiments.
        predictions (List[Dict]): Predicted records, read with the same keys.
        subtask_key: Top-level key selecting the subtask inside each record.
        subtask_value: Key of the label list inside the subtask entry.

    Returns:
        Tuple[List, List, List, List]: Flattened true labels, flattened
        padded predicted labels, flattened true polarities and flattened
        padded predicted polarities. Missing predicted labels are filled
        with "NONE"; missing predicted polarities are filled with a randomly
        chosen sentiment. Padding is delegated to ``adjust_predictions``.
    """
    filler_labels = ["NONE"]  # stands in for a label that was not predicted
    filler_sentiments = ["positive", "negative", "neutral", "conflict"]

    # Pull the per-record label and polarity lists out of both inputs.
    gold_labels = extract_aspects(references, subtask_key, subtask_value)
    predicted_labels = extract_aspects(predictions, subtask_key, subtask_value)
    gold_sentiments = extract_aspects(references, subtask_key, "polarity")
    predicted_sentiments = extract_aspects(predictions, subtask_key, "polarity")

    # Labels are padded before polarities so random draws happen in a fixed order.
    padded_labels = adjust_predictions(
        gold_labels, predicted_labels, filler_labels
    )
    padded_sentiments = adjust_predictions(
        gold_sentiments, predicted_sentiments, filler_sentiments
    )

    flat_gold_labels = flatten_list(gold_labels)
    flat_padded_labels = flatten_list(padded_labels)
    flat_gold_sentiments = flatten_list(gold_sentiments)
    flat_padded_sentiments = flatten_list(padded_sentiments)
    return (
        flat_gold_labels,
        flat_padded_labels,
        flat_gold_sentiments,
        flat_padded_sentiments,
    )
def flatten_list(nested_list):
    """Return the elements of every inner iterable, concatenated in order.

    Only one level of nesting is removed.
    """
    return [element for inner in nested_list for element in inner]
def extract_pred_terms(
    all_predictions: List[Dict[str, Dict[str, str]]]
) -> List[List]:
    """Return, for each prediction, the keys of all its inner mappings.

    Every prediction is a mapping whose values are themselves mappings; the
    keys of those inner mappings are gathered into one list per prediction,
    following the iteration order of the outer and then the inner mapping.
    """
    collected = []
    for prediction in all_predictions:
        names = []
        for term_to_sentiment in prediction.values():
            names.extend(term_to_sentiment.keys())
        collected.append(names)
    return collected
def merge_aspects_and_categories(aspects, categories):
    """Combine per-sample aspect and category mappings into one record each.

    ``aspects`` and ``categories`` are paired positionally (``zip``). Each
    aspect entry maps a category key to a mapping of term -> polarity; each
    category entry maps a category key to a polarity.

    Returns:
        A list with one dict per pair, shaped as
        ``{"aspects": {"term": [...], "polarity": [...]},
        "category": {"category": [...], "polarity": [...]}}``.
        Terms and their polarities are listed in aspect iteration order.
        Categories list first the aspect keys that also appear in the
        category entry, then the remaining category keys.
    """
    merged = []
    for aspect_map, category_map in zip(aspects, categories):
        # Flatten every (term, polarity) pair across all aspect categories.
        term_pairs = [
            pair for terms in aspect_map.values() for pair in terms.items()
        ]

        # An insertion-ordered dict reproduces the original ordering:
        # aspect-derived categories first, leftover categories appended after.
        category_polarity = {
            name: category_map[name]
            for name in aspect_map
            if name in category_map
        }
        for name, polarity in category_map.items():
            category_polarity.setdefault(name, polarity)

        merged.append(
            {
                "aspects": {
                    "term": [term for term, _ in term_pairs],
                    "polarity": [polarity for _, polarity in term_pairs],
                },
                "category": {
                    "category": list(category_polarity),
                    "polarity": list(category_polarity.values()),
                },
            }
        )
    return merged