Spaces:
Sleeping
Sleeping
import nltk | |
from nltk.tokenize import word_tokenize | |
from langchain_community.document_loaders import TextLoader | |
from langchain_community.embeddings.sentence_transformer import ( | |
SentenceTransformerEmbeddings, | |
) | |
from langchain_community.vectorstores import Chroma | |
from langchain_text_splitters import CharacterTextSplitter | |
import os

# Download the NLTK "punkt" data used for tokenization. This runs every time
# the module is imported and may need network access on the first run.
# NOTE(review): word_tokenize is imported above but is not used in the visible
# part of this file -- confirm the download is still required.
nltk.download('punkt')
class QuestionRetriever:
    """Look up the stored question that is most similar to a user's message.

    Questions live in one text file per mental-health category under
    ``DATA_DIRECTORY``. A file is split into newline-separated chunks,
    embedded with a sentence-transformer model and indexed in a Chroma
    vector store, which is then queried by similarity.
    """

    # Directory (relative to the working directory) holding the question files.
    DATA_DIRECTORY = "data/"

    # Supported categories and the question file backing each of them.
    CATEGORY_FILES = {
        "depression": "depression_questions.txt",
        "adhd": "adhd_questions.txt",
        "anxiety": "anxiety_questions.txt",
    }

    # Vector stores already built in this process, keyed by category.
    # Deliberately shared by all instances: building a store reloads the
    # embedding model and re-embeds the whole file, so it is done only once
    # per category.
    _store_cache = {}

    def load_documents(self, file_name):
        """Load ``file_name`` from ``DATA_DIRECTORY`` with ``TextLoader``.

        Args:
            file_name: Name of the text file inside the data directory.

        Returns:
            Whatever ``TextLoader.load()`` returns for that file.
        """
        file_path = os.path.join(self.DATA_DIRECTORY, file_name)
        return TextLoader(file_path).load()

    def store_data_in_vector_db(self, documents, collection_name=None):
        """Split ``documents`` into chunks, embed them and index them in Chroma.

        Args:
            documents: Documents as returned by ``load_documents``.
            collection_name: Optional Chroma collection to write into. When
                omitted, Chroma's default collection is used, exactly as
                before this parameter existed.

        Returns:
            The Chroma vector store holding the embedded chunks.
        """
        text_splitter = CharacterTextSplitter(
            chunk_size=100, chunk_overlap=0, separator="\n"
        )
        chunks = text_splitter.split_documents(documents)
        # Open-source sentence-transformer embedding function.
        embedding_function = SentenceTransformerEmbeddings(
            model_name="all-MiniLM-L6-v2"
        )
        if collection_name is None:
            return Chroma.from_documents(chunks, embedding_function)
        return Chroma.from_documents(
            chunks, embedding_function, collection_name=collection_name
        )

    def get_response(self, user_query, predicted_mental_category):
        """Return the stored question closest to ``user_query``.

        Args:
            user_query: The user's message.
            predicted_mental_category: One of the keys of ``CATEGORY_FILES``.

        Returns:
            The first line of the best-matching chunk. If that line is
            identical to ``user_query``, the next match with a different first
            line is returned instead so the user's own question is not echoed
            back; when no such alternative exists the top match is returned.
            ``None`` is returned for an unsupported category or when the
            search yields no results.
        """
        file_name = self.CATEGORY_FILES.get(predicted_mental_category)
        if file_name is None:
            allowed = list(self.CATEGORY_FILES)
            print(f"Sorry, allowed predicted_mental_category is {allowed}.")
            return None

        db = self._store_cache.get(predicted_mental_category)
        if db is None:
            documents = self.load_documents(file_name)
            # A dedicated collection per category keeps the chunks of
            # different categories apart.
            # NOTE(review): Chroma's default in-memory collection appears to
            # be shared between from_documents() calls within one process, so
            # rebuilding on every call would pile up duplicate chunks --
            # confirm against the installed chromadb version.
            db = self.store_data_in_vector_db(
                documents,
                collection_name=f"{predicted_mental_category}_questions",
            )
            self._store_cache[predicted_mental_category] = db

        matches = db.similarity_search(user_query)
        if not matches:
            # Nothing indexed or nothing returned: report "no answer" instead
            # of failing with an IndexError.
            return None

        # A chunk may hold several lines; only its first line is the question.
        candidates = [match.page_content.split("\n")[0] for match in matches]
        most_similar_question = candidates[0]
        if most_similar_question == user_query:
            most_similar_question = next(
                (text for text in candidates[1:] if text != user_query),
                most_similar_question,
            )
        print(most_similar_question)
        return most_similar_question
if __name__ == "__main__":
    # Manual smoke test: read one message from the terminal and print the
    # question retrieved for a fixed "depression" category.
    retriever = QuestionRetriever()
    category = "depression"
    message = input("User: ")
    reply = retriever.get_response(message, category)
    print("Chatbot:", reply)