"""Mental Health Chatbot: a Gradio chat UI over a LangChain RAG pipeline
backed by Chroma and a Hugging Face-hosted Gemma model."""
import os
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain_community.llms import HuggingFaceHub
from langchain.prompts import PromptTemplate
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.messages import HumanMessage, AIMessage
import gradio as gr
import wandb

# Build the knowledge base from Data/0.txt through Data/11.txt.
loaders = []
folder_path = "Data"
for i in range(12):
    file_path = os.path.join(folder_path, f"{i}.txt")
    loaders.append(TextLoader(file_path))
docs = []
for loader in loaders:
    docs.extend(loader.load())
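# Remote embeddings via the HF Inference API (requires HF_TOKEN in the
# environment); all-mpnet-base-v2 produces 768-dimensional sentence vectors.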
HF_TOKEN = os.getenv("HF_TOKEN")
embeddings = HuggingFaceInferenceAPIEmbeddings(
    api_key=HF_TOKEN,
    model_name="sentence-transformers/all-mpnet-base-v2"
)
vectordb = Chroma.from_documents(
    documents=docs,
    embedding=embeddings
)
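# Note: with no persist_directory argument, this Chroma index lives only in
# memory and is rebuilt on every startup. A persistent variant (hypothetical
# path) would be:
#   vectordb = Chroma.from_documents(docs, embeddings, persist_directory="chroma_db")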
llm = HuggingFaceHub(
    repo_id="google/gemma-1.1-7b-it",
    task="text-generation",
    model_kwargs={
        "max_new_tokens": 512,
        "top_k": 5,
        "temperature": 0.1,
        "repetition_penalty": 1.03,
    },
    huggingfacehub_api_token=HF_TOKEN
)
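# Decoding settings: a low temperature (0.1) keeps answers focused and mostly
# deterministic, while repetition_penalty > 1 discourages the model from
# looping on the same phrase.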
template = """
You are a Mental Health Chatbot. Help the user with their mental health concerns. 
Use the context below to answer the questions {context} 
Question: {question} 
Helpful Answer:"""
        
QA_CHAIN_PROMPT = PromptTemplate(input_variables=["context", "question"],template=template)
memory = ConversationBufferMemory(
    memory_key="chat_history",
    return_messages=True
)
retriever = vectordb.as_retriever()
# Note: this chain is constructed but never invoked; the app answers through
# rag_chain below.
qa = ConversationalRetrievalChain.from_llm(
    llm,
    retriever=retriever,
    memory=memory,
)
contextualize_q_system_prompt = """
Given a chat history and the latest user question 
which might reference context in the chat history, 
formulate a standalone question 
which can be understood without the chat history. 
Do NOT answer the question, just reformulate it if needed and otherwise return it as is."""
contextualize_q_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", contextualize_q_system_prompt),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{question}"),
    ]
)
contextualize_q_chain = contextualize_q_prompt | llm | StrOutputParser()
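# The condensing chain rewrites a follow-up such as "what about its symptoms?"
# into a standalone question using the accumulated chat history.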
def contextualized_question(input: dict):
    # Condense the question only when history exists; otherwise pass it through.
    if input.get("chat_history"):
        return contextualize_q_chain
    return input["question"]
def format_docs(docs):
    # Join the retrieved documents into one context string for the prompt
    # (without this, raw Document reprs would be interpolated into {context}).
    return "\n\n".join(doc.page_content for doc in docs)
rag_chain = (
    RunnablePassthrough.assign(
        context=contextualized_question | retriever | format_docs
    )
    | QA_CHAIN_PROMPT
    | llm
)
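# Optional sanity check before wiring up the UI (hypothetical question):
#   print(rag_chain.invoke({"question": "How can I manage stress?", "chat_history": []}))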
# Enable Weights & Biases tracing of LangChain runs; the W&B API key is read
# from the "key" environment variable.
wandb.login(key=os.getenv("key"))
os.environ["LANGCHAIN_WANDB_TRACING"] = "true"
os.environ["WANDB_PROJECT"] = "Mental_Health_ChatBot"
print("Welcome to the Mental Health Chatbot. How can I help you today?")
chat_history = []
def predict(message, history):
    # Gradio supplies `history`, but this app keeps its own chat_history so the
    # condensing chain sees proper message objects.
    ai_msg = rag_chain.invoke({"question": message, "chat_history": chat_history})
    # HuggingFaceHub returns the prompt plus the completion; keep only the text
    # after the answer marker, falling back to the full string if it is absent.
    idx = ai_msg.find("Helpful Answer:")
    answer = ai_msg[idx + len("Helpful Answer:"):].strip() if idx != -1 else ai_msg
    chat_history.extend([HumanMessage(content=message), AIMessage(content=answer)])
    return answer
# Launch the chat UI; each user message is routed through predict().
gr.ChatInterface(predict).launch()