MindfulMedia_Mentor / bm25_retreive_question.py
jaelin215's picture
Upload 14 files
bd9870c verified
raw
history blame
3.48 kB
#---
#- Author: Jaelin Lee
#- Date: Mar 23, 2024
#- Description: Similarity search using BM25. Based on user input, retrieve most relevant info from knowledge base.
#- How it works: Tokenize the user input text using NLTK. Then score it against every item in the knowledge base using BM25 (a TF-IDF-based ranking function). Get the index of the highest-scoring item with `argmax()`, and use that index to retrieve the item from the knowledge base.
#---
from rank_bm25 import BM25Okapi
import nltk
from nltk.tokenize import word_tokenize
# Download the NLTK tokenizer data that word_tokenize depends on.
# Older NLTK releases load the "punkt" resource; recent releases load
# "punkt_tab" instead, so both are requested. A resource that cannot be
# fetched is reported by the downloader and does not stop the import.
nltk.download('punkt')
nltk.download('punkt_tab')
class QuestionRetriever:
    """Retrieve the stored question most similar to a user query using BM25.

    One list of questions is loaded per mental-health category when the
    object is created. ``get_response`` ranks the questions of the requested
    category against the user's text and returns the best-scoring one.
    """

    # Directory the question files are read from when no ``data_dir`` is
    # given. It is resolved relative to the current working directory.
    DEFAULT_DATA_DIR = "src/model_building/RL/data"

    # Category label accepted by ``get_response`` -> (name of the instance
    # attribute holding the loaded questions, file name inside the data dir).
    CATEGORY_SOURCES = {
        "depression": ("depression_questions", "depression_questions.txt"),
        "adhd": ("adhd_questions", "adhd_questions.txt"),
        "anxiety": ("anxiety_questions", "anxiety_questions.txt"),
        "social isolation": ("social_isolation_questions", "social_isolation.txt"),
        "cyberbullying": ("cyberbullying_questions", "cyberbullying.txt"),
        "social media addiction": (
            "social_media_addiction_questions",
            "socialmediaaddiction.txt",
        ),
    }

    def __init__(self, data_dir=DEFAULT_DATA_DIR):
        """Load every category's question list from ``data_dir``.

        Args:
            data_dir: Directory containing the question text files. Defaults
                to the project's data directory, relative to the current
                working directory.

        Raises:
            OSError: If one of the question files cannot be opened.
        """
        import os

        for attribute, file_name in self.CATEGORY_SOURCES.values():
            path = os.path.join(data_dir, file_name)
            setattr(self, attribute, self.load_questions_from_file(path))

    def load_questions_from_file(self, filename):
        """Return the non-blank lines of ``filename`` as a list of questions.

        Surrounding whitespace is stripped from every line. Blank lines are
        dropped so that they never become empty documents in the corpus.
        """
        # An explicit encoding keeps decoding independent of the platform.
        with open(filename, "r", encoding="utf-8") as file:
            stripped_lines = (line.strip() for line in file)
            return [question for question in stripped_lines if question]

    def _bm25_index(self, category, knowledge_base):
        """Return a BM25 index for ``knowledge_base``, building it if needed.

        The index is cached per category together with a snapshot of the
        questions it was built from, and rebuilt when the list has changed.
        """
        # setdefault also covers instances created without running __init__.
        cache = vars(self).setdefault("_bm25_cache", {})
        snapshot = tuple(knowledge_base)
        entry = cache.get(category)
        if entry is None or entry[0] != snapshot:
            # Lowercase the documents so matching is case-insensitive.
            tokenized_docs = [word_tokenize(doc.lower()) for doc in snapshot]
            entry = (snapshot, BM25Okapi(tokenized_docs))
            cache[category] = entry
        return entry[1]

    def get_response(self, user_query, predicted_mental_category):
        """Return the question that best matches ``user_query``.

        Args:
            user_query: Free text typed by the user.
            predicted_mental_category: Category label selecting the question
                list, e.g. ``"depression"`` or ``"social isolation"``.
                Surrounding whitespace and letter case are ignored.

        Returns:
            The highest-scoring question of that category, or ``None`` when
            the category is unknown or has no questions. When no question
            scores above the others, the first question is returned because
            ``argmax`` picks the first maximum.
        """
        if isinstance(predicted_mental_category, str):
            category = predicted_mental_category.strip().lower()
        else:
            # Anything that is not a string cannot name a category.
            category = None

        source = self.CATEGORY_SOURCES.get(category)
        if source is None:
            print("Sorry, I didn't understand that.")
            return None

        knowledge_base = getattr(self, source[0], None)
        if not knowledge_base:
            # Nothing to rank; BM25 cannot be built on an empty corpus.
            return None

        bm25 = self._bm25_index(category, knowledge_base)
        # Lowercase the query to match the lowercased documents.
        tokenized_query = word_tokenize(user_query.lower())
        doc_scores = bm25.get_scores(tokenized_query)
        most_relevant_doc_index = int(doc_scores.argmax())
        return knowledge_base[most_relevant_doc_index]
if __name__ == "__main__":
    # Manual smoke test: read one line from the terminal and print the
    # question retrieved for a fixed category.
    retriever = QuestionRetriever()
    category = "cyberbullying"
    query = input("User: ")
    print("Chatbot:", retriever.get_response(query, category))