from functools import lru_cache
from typing import Iterable, List, Optional

import torch

# Default word-list locations. They are relative paths, so they resolve
# against the process's current working directory.
_NEGATION_WORDS_PATH = './model/stopwords/negation_words.txt'
_EMOTION_WORDS_PATH = './model/stopwords/emotion_words.txt'

# Suffix attached to an emotion word that is followed closely by a negation
# word. `predict` looks for this exact text to decide whether to remap a label.
_NEGATION_MARKER = "(Negative context)"

# Number of words preceding a negation word that are searched for an emotion
# word.
_NEGATION_WINDOW = 3

# Label substitutions applied when the processed text carries the marker.
# Labels that are not keys here are left untouched.
_NEGATION_LABEL_MAP = {
    "Sadness": "Optimistic",
    "Optimistic": "Sadness",
    "Anger": "Optimistic",
}


def load_words_from_file(file_path: str) -> List[str]:
    """Load words from a text file and return them as a list.

    Each word should be on a separate line in the text file. Lines are
    returned exactly as written (minus the line terminator); no stripping or
    de-duplication is performed.

    Args:
        file_path: Path of the UTF-8 encoded word list.

    Returns:
        One list entry per line of the file.

    Raises:
        OSError: If the file cannot be opened.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        return file.read().splitlines()


@lru_cache(maxsize=None)
def _load_word_set(file_path: str) -> frozenset:
    """Return the words in *file_path* as a frozenset, reading the file once.

    The result is cached per path so repeated predictions do not hit the disk
    again. Changes to the file after the first read are not picked up.
    """
    return frozenset(load_words_from_file(file_path))


def preprocess_with_negation_v2(
    text: str,
    negation_words: Optional[Iterable[str]] = None,
    emotion_words: Optional[Iterable[str]] = None,
) -> str:
    """Mark emotion words that are followed closely by a negation word.

    The text is split on whitespace. For every token that is a negation word,
    the up to three tokens *before* it are inspected, nearest first; the first
    one that is an emotion word gets " (Negative context)" appended. Matching
    is exact (case- and punctuation-sensitive), and the result is re-joined
    with single spaces.

    NOTE(review): the search looks at words preceding the negation word, not
    following it -- confirm this ordering matches the target language.

    Args:
        text: Input sentence.
        negation_words: Optional iterable of negation words. When omitted,
            the list at ./model/stopwords/negation_words.txt is used.
        emotion_words: Optional iterable of emotion words. When omitted,
            the list at ./model/stopwords/emotion_words.txt is used.

    Returns:
        The text with detected emotion words annotated.

    Raises:
        OSError: If a default word list is needed and cannot be read.
    """
    # Sets give O(1) membership tests; the default lists are cached on disk
    # path so they are only read once per process.
    if negation_words is None:
        negation_set = _load_word_set(_NEGATION_WORDS_PATH)
    else:
        negation_set = frozenset(negation_words)
    if emotion_words is None:
        emotion_set = _load_word_set(_EMOTION_WORDS_PATH)
    else:
        emotion_set = frozenset(emotion_words)

    words = text.split()
    modified_words = list(words)  # Copy, so look-ups always see raw tokens.

    for i, word in enumerate(words):
        if word not in negation_set:
            continue
        for offset in range(1, _NEGATION_WINDOW + 1):
            target = i - offset
            if target < 0:
                # Ran off the start of the sentence; nothing further back.
                break
            if words[target] in emotion_set:
                # Built from the raw token, so a word negated twice is still
                # marked only once.
                modified_words[target] = f"{words[target]} {_NEGATION_MARKER}"
                break

    return " ".join(modified_words)


def predict(text, model, tokenizer):
    """Predict the sentiment for a given text with advanced negation handling.

    The text is first annotated by `preprocess_with_negation_v2`, then
    tokenized and run through the model without gradient tracking. If the
    annotated text contains "(Negative context)", the predicted label is
    remapped (Sadness <-> Optimistic, Anger -> Optimistic).

    Args:
        text: A single input sentence.
        model: Classification model. It must support `.to(device)`, be
            callable with the tokenizer output as keyword arguments, return
            an object with `.logits`, and expose `config.id2label`.
            Presumably a Hugging Face sequence-classification model -- TODO
            confirm.
        tokenizer: Callable accepting the text plus `padding`, `truncation`,
            `max_length` and `return_tensors`, whose result supports
            `.to(device)`.

    Returns:
        A tuple `(probs, pred_label_idx, pred_label)`: the softmax
        probabilities over the last dimension of the logits, the argmax index
        as a Python int, and the label name. `pred_label_idx` is always the
        model's raw argmax; only `pred_label` is remapped for negation.
    """
    # Ensure the model is on the correct device.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    # Annotate negated emotion words before tokenization.
    processed_text = preprocess_with_negation_v2(text)

    inputs = tokenizer(
        processed_text,
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    ).to(device)

    # Inference only: no gradients needed.
    with torch.no_grad():
        outputs = model(**inputs)

    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)

    # `.item()` requires a single-element result, i.e. one input text.
    pred_label_idx = probs.argmax(dim=-1).item()
    pred_label = model.config.id2label[pred_label_idx]

    # Adjust the prediction when a negated emotion word was detected.
    if _NEGATION_MARKER in processed_text:
        pred_label = _NEGATION_LABEL_MAP.get(pred_label, pred_label)

    return probs, pred_label_idx, pred_label