"""Mental Health Chatbot: a retrieval-augmented (RAG) chatbot over local text
files, built with LangChain and served through a Gradio chat interface."""
import os

from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain_community.llms import HuggingFaceHub
from langchain.prompts import PromptTemplate
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.messages import AIMessage, HumanMessage
import gradio as gr
import wandb
# --- Knowledge base ---
# Load the twelve source documents (Data/0.txt .. Data/11.txt).
loaders = []
folder_path = "Data"
for i in range(12):
    file_path = os.path.join(folder_path, "{}.txt".format(i))
    loaders.append(TextLoader(file_path))

docs = []
for loader in loaders:
    docs.extend(loader.load())
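# Optional refinement (a sketch, not part of the pipeline above): splitting the
# documents into overlapping chunks usually gives sharper retrieval hits than
# embedding whole files. The chunk sizes are illustrative defaults.
#
#   from langchain.text_splitter import RecursiveCharacterTextSplitter
#   splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
#   docs = splitter.split_documents(docs)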
# Embed the documents via the Hugging Face Inference API and index them in Chroma.
HF_TOKEN = os.getenv("HF_TOKEN")
embeddings = HuggingFaceInferenceAPIEmbeddings(
    api_key=HF_TOKEN,
    model_name="sentence-transformers/all-mpnet-base-v2",
)
vectordb = Chroma.from_documents(
    documents=docs,
    embedding=embeddings,
)
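# Note (assumption): this index lives in memory and is rebuilt on every start.
# Chroma can persist it to disk instead, e.g. (illustrative directory name):
#   vectordb = Chroma.from_documents(docs, embeddings, persist_directory="chroma_db")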
# Text-generation model served through the Hugging Face Hub.
llm = HuggingFaceHub(
    repo_id="google/gemma-1.1-7b-it",
    task="text-generation",
    model_kwargs={
        "max_new_tokens": 512,
        "top_k": 5,
        "temperature": 0.1,
        "repetition_penalty": 1.03,
    },
    huggingfacehub_api_token=HF_TOKEN,
)
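# Optional smoke test: the wrapper returns plain generated text, so a one-off
# call verifies the token and endpoint before the chains are wired up.
#   print(llm.invoke("Say hello in one short sentence."))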
template = """
You are a Mental Health Chatbot. Help the user with their mental health concerns.
Use the context below to answer the questions {context}
Question: {question}
Helpful Answer:"""
QA_CHAIN_PROMPT = PromptTemplate(input_variables=["context", "question"],template=template)
# Conversation memory plus a prebuilt ConversationalRetrievalChain. Note that
# the Gradio app below drives rag_chain directly; this chain is a
# self-contained alternative that bundles history handling, retrieval, and
# generation in one object.
memory = ConversationBufferMemory(
    memory_key="chat_history",
    return_messages=True,
)
retriever = vectordb.as_retriever()
qa = ConversationalRetrievalChain.from_llm(
    llm,
    retriever=retriever,
    memory=memory,
)
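# Illustrative standalone usage of the prebuilt chain (question text is made up):
#   result = qa({"question": "How can I manage exam stress?"})
#   print(result["answer"])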
# Question rewriter: condenses the chat history plus the latest user message
# into a standalone question that the retriever can handle on its own.
contextualize_q_system_prompt = """
Given a chat history and the latest user question
which might reference context in the chat history,
formulate a standalone question
which can be understood without the chat history.
Do NOT answer the question, just reformulate it if needed and otherwise return it as is."""
contextualize_q_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", contextualize_q_system_prompt),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{question}"),
    ]
)
contextualize_q_chain = contextualize_q_prompt | llm | StrOutputParser()
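# Example of what the rewriter does (inputs are made up):
#   contextualize_q_chain.invoke({
#       "chat_history": [HumanMessage(content="I have trouble sleeping.")],
#       "question": "What can I do about it?",
#   })
# should come back as a standalone question such as
# "What can I do about my trouble sleeping?".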
def contextualized_question(input: dict):
    """Route: rewrite the question when history exists, otherwise pass it through."""
    if input.get("chat_history"):
        return contextualize_q_chain
    return input["question"]
# Full RAG pipeline: (optionally rewrite question) -> retrieve context -> prompt -> LLM.
rag_chain = (
    RunnablePassthrough.assign(
        context=contextualized_question | retriever
    )
    | QA_CHAIN_PROMPT
    | llm
)
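# Illustrative data flow for one call (the question is made up):
#   rag_chain.invoke({"question": "How do I cope with anxiety?", "chat_history": []})
# 1. chat_history is empty, so the question passes through unchanged;
# 2. the retriever fetches the best-matching documents into "context";
# 3. QA_CHAIN_PROMPT fills the template and llm generates the answer string.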
# Weights & Biases tracing for the LangChain runs; the W&B API key is read from
# the "key" environment variable.
wandb.login(key=os.getenv("key"))
os.environ["LANGCHAIN_WANDB_TRACING"] = "true"
os.environ["WANDB_PROJECT"] = "Mental_Health_ChatBot"
print("Welcome to the Mental Health Chatbot. How can I help you today?")

chat_history = []

def predict(message, history):
    """Gradio callback: answer one user message and remember the exchange."""
    ai_msg = rag_chain.invoke({"question": message, "chat_history": chat_history})
    chat_history.extend([HumanMessage(content=message), AIMessage(content=ai_msg)])
    # Strip everything before the generated answer; keep the full text if the
    # "Answer" marker is missing so no reply is lost.
    idx = ai_msg.find("Answer")
    return ai_msg[idx:] if idx != -1 else ai_msg

gr.ChatInterface(predict).launch()
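# Note (assumption): when running outside a hosted Space, launch(share=True)
# exposes a temporary public Gradio link, e.g.
#   gr.ChatInterface(predict).launch(share=True)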