import logging
import os

from dotenv import load_dotenv

from langchain.chains import HypotheticalDocumentEmbedder, RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import DirectoryLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
from langchain.retrievers import BM25Retriever, EnsembleRetriever
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma

# Setting up log configuration. Note: the Logs/ directory must already exist,
# otherwise logging.basicConfig raises FileNotFoundError.
logging.basicConfig(
    filename='Logs/chatbot.log',   # name of the log file
    level=logging.INFO,            # use logging.DEBUG for more detailed logs
    format='%(asctime)s - %(levelname)s - %(message)s'
)

# Shared QA prompt; the same template was previously duplicated in each
# retriever-setup method, so it is defined once here instead.
QA_TEMPLATE = """Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say "I am unable to answer your query; for more information, contact your HRBP" - don't try to make up an answer.
Use three sentences maximum and keep the answer as concise as possible.
{context}
Question: {question}
Helpful Answer:"""


class Jine:
    def __init__(self, OPENAI_API_KEY, VECTOR_STORE_DIRECTORY, VECTOR_STORE_CHECK,
                 DATA_DIRECTORY, DEBUG, USE_HYDE=False):
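        """Configuration is injected by the caller rather than read here.

        OPENAI_API_KEY         -- OpenAI API key
        VECTOR_STORE_DIRECTORY -- where the Chroma vectorstore is persisted
        VECTOR_STORE_CHECK     -- truthy if a persisted vectorstore already exists
        DATA_DIRECTORY         -- directory of source documents to index
        DEBUG                  -- debug flag (stored but not otherwise used yet)
        USE_HYDE               -- embed with HyDE (hypothetical document embeddings)
        """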
        self.OPENAI_API_KEY = OPENAI_API_KEY
        self.DATA_DIRECTORY = DATA_DIRECTORY
        self.VECTOR_STORE_DIRECTORY = VECTOR_STORE_DIRECTORY
        self.VECTOR_STORE_CHECK = VECTOR_STORE_CHECK
        self.DEBUG = DEBUG
        self.vectorstore = None
        self.bot = None
        self.USE_HYDE = USE_HYDE

    def create_vectorstore(self):
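        """Load the persisted vectorstore if one exists; otherwise build it from DATA_DIRECTORY."""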
        if self.VECTOR_STORE_CHECK:
            print("Loading vectorstore")
            self.load_vectorstore()
        else:
            print("Creating vectorstore")
            docs = DirectoryLoader(self.DATA_DIRECTORY).load()
            text_splitter = RecursiveCharacterTextSplitter(chunk_size=400, chunk_overlap=10)
            all_splits = text_splitter.split_documents(docs)
            if self.USE_HYDE:
                # HyDE: the LLM writes a hypothetical answer document for each query,
                # and that document (not the raw query) is embedded for similarity search.
                base_embeddings = OpenAIEmbeddings()
                llm = OpenAI()
                embeddings_hyde = HypotheticalDocumentEmbedder.from_llm(llm, base_embeddings, "web_search")
                self.vectorstore = Chroma.from_documents(documents=all_splits, embedding=embeddings_hyde,
                                                         persist_directory=self.VECTOR_STORE_DIRECTORY)
            else:
                self.vectorstore = Chroma.from_documents(documents=all_splits, embedding=OpenAIEmbeddings(),
                                                         persist_directory=self.VECTOR_STORE_DIRECTORY)
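        # Note: depending on the Chroma/langchain versions in use, it may be
        # necessary to call self.vectorstore.persist() after building the store
        # so the index actually survives a restart.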

    def multi_query_retriever(self):
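        """Wire self.bot to a MultiQueryRetriever.

        MultiQueryRetriever asks an LLM to rephrase the user's question several
        ways, runs each variant against the vectorstore, and returns the union
        of the retrieved documents.
        """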
        retriever_from_llm = MultiQueryRetriever.from_llm(retriever=self.vectorstore.as_retriever(),
                                                          llm=ChatOpenAI(temperature=0))
        QA_CHAIN_PROMPT = PromptTemplate.from_template(QA_TEMPLATE)
        llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
        self.bot = RetrievalQA.from_chain_type(
            llm,
            retriever=retriever_from_llm,
            chain_type_kwargs={"prompt": QA_CHAIN_PROMPT}
        )

    def single_query_retriever(self):
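        """Wire self.bot to a plain similarity-search retriever over the vectorstore."""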
        QA_CHAIN_PROMPT = PromptTemplate.from_template(QA_TEMPLATE)
        llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
        self.bot = RetrievalQA.from_chain_type(
            llm,
            retriever=self.vectorstore.as_retriever(),
            chain_type_kwargs={"prompt": QA_CHAIN_PROMPT})

    def load_vectorstore(self):
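        """Open the persisted Chroma store.

        The embedding function must match the one used when the store was
        created, so USE_HYDE has to be set consistently across runs.
        """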
        if self.USE_HYDE:
            print("Using HyDE embeddings vectorstore")
            base_embeddings = OpenAIEmbeddings()
            llm = OpenAI()
            embeddings_hyde = HypotheticalDocumentEmbedder.from_llm(llm, base_embeddings, "web_search")
            self.vectorstore = Chroma(persist_directory=self.VECTOR_STORE_DIRECTORY, embedding_function=embeddings_hyde)
        else:
            print("Using simple embeddings vectorstore")
            self.vectorstore = Chroma(persist_directory=self.VECTOR_STORE_DIRECTORY, embedding_function=OpenAIEmbeddings())

    def log(self, user_question, chatbot_reply):
        # Log the user's question and the chatbot's reply as one exchange.
        logging.info(f"User: {user_question}")
        logging.info(f"JIN-e: {chatbot_reply}")

    def load_model(self):
        self.create_vectorstore()
        # Alternative retrievers; swap one of these in instead if desired:
        # self.multi_query_retriever()
        # self.single_query_retriever()
        self.create_ensemble_retriever()

    def chat(self, user_question):
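        """Answer one user question through the active RetrievalQA chain and log the exchange."""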
        result = self.bot({"query": user_question})
        response = result["result"]
        self.log(user_question, response)
        return response

    # Ensemble retriever: sparse keyword (BM25) + dense vector retrieval combined.
    def create_ensemble_retriever(self):
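        """Wire self.bot to an EnsembleRetriever.

        BM25 keyword matching and dense vector similarity each contribute their
        top-2 documents; EnsembleRetriever fuses the two ranked lists with
        weighted reciprocal rank fusion, weighting both retrievers equally here.
        """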
        QA_CHAIN_PROMPT = PromptTemplate.from_template(QA_TEMPLATE)
        print("=" * 80)
        print("Loading documents for the ensemble retriever")
        print("=" * 80)
        # BM25 works on the raw splits, so the documents are loaded and split
        # again here; reusing the splits from create_vectorstore would avoid this.
        docs = DirectoryLoader(self.DATA_DIRECTORY).load()
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=400, chunk_overlap=10)
        all_splits = text_splitter.split_documents(docs)
        bm25_retriever = BM25Retriever.from_documents(all_splits)
        # Getting only the two most relevant documents from each retriever.
        bm25_retriever.k = 2
        ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever,
                                                           self.vectorstore.as_retriever(search_kwargs={"k": 2})],
                                               weights=[0.5, 0.5])
        llm = ChatOpenAI(model_name="gpt-3.5-turbo-16k", temperature=0)
        self.bot = RetrievalQA.from_chain_type(
            llm,
            retriever=ensemble_retriever,
            chain_type_kwargs={"prompt": QA_CHAIN_PROMPT})


if __name__ == "__main__":
    # Read configuration from the environment (.env file).
    load_dotenv()
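    # A .env file along these lines is assumed (key names come from the code
    # below; the values are purely illustrative):
    #   OPENAI_API_KEY=sk-...
    #   DATA_DIRECTORY=data
    #   VECTOR_STORE_DIRECTORY=vectorstore
    #   VECTOR_STORE_CHECK=false
    #   DEBUG=false
    #   USE_HYDE=false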
    OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
    DATA_DIRECTORY = os.getenv("DATA_DIRECTORY")
    VECTOR_STORE_DIRECTORY = os.getenv("VECTOR_STORE_DIRECTORY")  # was misspelled "VECTOR_STORE_DIRCTORY"; keep the .env key in sync
    # os.getenv returns strings, so the string "False" would be truthy; convert explicitly.
    VECTOR_STORE_CHECK = os.getenv("VECTOR_STORE_CHECK", "").lower() in ("1", "true", "yes")
    DEBUG = os.getenv("DEBUG", "").lower() in ("1", "true", "yes")
    USE_HYDE = os.getenv("USE_HYDE", "").lower() in ("1", "true", "yes")

    # Initialize Jine and start chatting.
    jine = Jine(OPENAI_API_KEY, VECTOR_STORE_DIRECTORY, VECTOR_STORE_CHECK,
                DATA_DIRECTORY, DEBUG, USE_HYDE=USE_HYDE)
    jine.load_model()
    while True:
        user_question = input("You: ")
        if user_question.lower() in ["exit", "quit"]:
            break
        response = jine.chat(user_question)
        print("JIN-e:", response)