import os

from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain_community.llms import HuggingFaceHub
from langchain.prompts import PromptTemplate
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
import gradio as gr
import wandb
|
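# Build the knowledge base from twelve plain-text files (Data/0.txt .. Data/11.txt).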
loaders = []
folder_path = "Data"
for i in range(12):
    file_path = os.path.join(folder_path, "{}.txt".format(i))
    loaders.append(TextLoader(file_path))

docs = []
for loader in loaders:
    docs.extend(loader.load())
|
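# Embed every document with a hosted sentence-transformers model and index the
# vectors in a Chroma collection for similarity search.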
HF_TOKEN = os.getenv("HF_TOKEN")

embeddings = HuggingFaceInferenceAPIEmbeddings(
    api_key=HF_TOKEN,
    model_name="sentence-transformers/all-mpnet-base-v2",
)

vectordb = Chroma.from_documents(
    documents=docs,
    embedding=embeddings,
)
|
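# Gemma 1.1 7B (instruction-tuned) served through the Hugging Face Hub; the low
# temperature keeps answers focused rather than speculative.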
llm = HuggingFaceHub(
    repo_id="google/gemma-1.1-7b-it",
    task="text-generation",
    model_kwargs={
        "max_new_tokens": 512,
        "top_k": 5,
        "temperature": 0.1,
        "repetition_penalty": 1.03,
    },
    huggingfacehub_api_token=HF_TOKEN,
)
|
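# Prompt that grounds the answer in the retrieved context.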
template = """
You are a Mental Health Chatbot. Help the user with their mental health concerns.
Use the context below to answer the question.

Context: {context}

Question: {question}

Helpful Answer:"""

QA_CHAIN_PROMPT = PromptTemplate(input_variables=["context", "question"], template=template)
|
memory = ConversationBufferMemory(
    memory_key="chat_history",
    return_messages=True,
)
|
retriever = vectordb.as_retriever() |
|
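# Off-the-shelf conversational RAG chain; kept for reference, though the Gradio
# app below uses the hand-built LCEL rag_chain instead.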
qa = ConversationalRetrievalChain.from_llm(
    llm,
    retriever=retriever,
    memory=memory,
)
|
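# Rewrite follow-up questions into standalone ones so retrieval does not depend
# on the chat history.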
contextualize_q_system_prompt = """
Given a chat history and the latest user question
which might reference context in the chat history,
formulate a standalone question
which can be understood without the chat history.
Do NOT answer the question, just reformulate it if needed and otherwise return it as is."""
|
contextualize_q_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", contextualize_q_system_prompt),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{question}"),
    ]
)
|
contextualize_q_chain = contextualize_q_prompt | llm | StrOutputParser() |
|
def contextualized_question(input: dict):
    # With prior history, run the rewriting chain before retrieval; on the
    # first turn, pass the raw question straight through.
    if input.get("chat_history"):
        return contextualize_q_chain
    else:
        return input["question"]
|
rag_chain = (
    RunnablePassthrough.assign(
        context=contextualized_question | retriever
    )
    | QA_CHAIN_PROMPT
    | llm
)
|
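# Trace chain runs to Weights & Biases; the API key is read from the "key" env var.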
wandb.login(key=os.getenv("key"))
os.environ["LANGCHAIN_WANDB_TRACING"] = "true"
os.environ["WANDB_PROJECT"] = "Mental_Health_ChatBot"
|
print("Welcome to the Mental Health Chatbot. How can I help you today?") |
|
chat_history = [] |
|
def predict(message, history):
    ai_msg = rag_chain.invoke({"question": message, "chat_history": chat_history})
    # HuggingFaceHub returns the prompt together with the completion, so keep only
    # the text from "Answer" onwards (fall back to the full reply if the marker is absent).
    idx = ai_msg.find("Answer")
    answer = ai_msg[idx:] if idx != -1 else ai_msg
    chat_history.extend([HumanMessage(content=message), AIMessage(content=answer)])
    return answer
|
gr.ChatInterface(predict).launch() |