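"""Retrieve the stored screening question most similar to a user's message.

Questions for each category (depression, adhd, anxiety) are loaded from plain
text files, embedded with a sentence-transformer model, and indexed in a
Chroma vector store for similarity search.
"""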
import os

from langchain_community.document_loaders import TextLoader
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings,
)
from langchain_community.vectorstores import Chroma
from langchain_text_splitters import CharacterTextSplitter


class QuestionRetriever:
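    """Retrieves the stored screening question most similar to a user's query.

    Question files are expected in a "data" subdirectory of the current
    working directory (see load_documents).
    """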

    def load_documents(self, file_name):
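        """Load a question file from the "data" subdirectory of the working directory."""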
        current_directory = os.getcwd()
        data_directory = os.path.join(current_directory, "data")
        file_path = os.path.join(data_directory, file_name)
        loader = TextLoader(file_path)
        documents = loader.load()
        return documents

    def store_data_in_vector_db(self, documents):
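        """Split the documents into chunks, embed them, and index them in Chroma."""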
        # Each question file is split on newlines into small chunks.
        text_splitter = CharacterTextSplitter(
            chunk_size=100, chunk_overlap=0, separator="\n"
        )
        docs = text_splitter.split_documents(documents)

        # Embed each chunk and index it in a Chroma vector store.
        embedding_function = SentenceTransformerEmbeddings(
            model_name="all-MiniLM-L6-v2"
        )
        db = Chroma.from_documents(docs, embedding_function)
        return db

    def get_response(self, user_query, predicted_mental_category):
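        """Return the stored question most similar to user_query for the given category."""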
        if predicted_mental_category == "depression":
            documents = self.load_documents("depression_questions.txt")
        elif predicted_mental_category == "adhd":
            documents = self.load_documents("adhd_questions.txt")
        elif predicted_mental_category == "anxiety":
            documents = self.load_documents("anxiety_questions.txt")
        else:
            print("Sorry, the allowed values for predicted_mental_category are ['depression', 'adhd', 'anxiety'].")
            return None

        db = self.store_data_in_vector_db(documents)

        # Take the best match; if it is identical to the user's query, fall
        # back to the next-best match so a new question is returned.
        docs = db.similarity_search(user_query)
        most_similar_question = docs[0].page_content.split("\n")[0]
        if user_query == most_similar_question and len(docs) > 1:
            most_similar_question = docs[1].page_content.split("\n")[0]

        print(most_similar_question)
        return most_similar_question


if __name__ == "__main__":
    model = QuestionRetriever()
    user_input = input("User: ")

    # Hard-coded category for a quick manual test; in practice this would
    # presumably come from an upstream classifier.
    predicted_mental_condition = "depression"
    response = model.get_response(user_input, predicted_mental_condition)
    print("Chatbot:", response)