# jin-e / jine_v1.py
# Author: hamxahbhattii (commit 6330947: "added Jine")
import logging
import os
from langchain.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import DirectoryLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.chat_models import ChatOpenAI
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from dotenv import load_dotenv
from langchain.retrievers import BM25Retriever, EnsembleRetriever
from langchain.llms import OpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.chains import LLMChain, HypotheticalDocumentEmbedder
## Setting up Log configuration
# Path of the chat transcript log written by Jine.log().
LOG_FILE = 'Logs/chatbot.log'
# logging.basicConfig(filename=...) opens the file immediately and raises
# FileNotFoundError if the parent directory is missing, so create it first.
os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True)
logging.basicConfig(
    filename=LOG_FILE,  # Name of the log file
    level=logging.INFO,  # Logging level (use logging.DEBUG for more detailed logs)
    format='%(asctime)s - %(levelname)s - %(message)s',
)
class Jine:
    """Retrieval-augmented HR question-answering chatbot ("JIN-e").

    Documents under ``DATA_DIRECTORY`` are chunked and indexed in a Chroma
    vector store persisted at ``VECTOR_STORE_DIRECTORY``, optionally embedded
    through a HyDE (HypotheticalDocumentEmbedder) wrapper. :meth:`load_model`
    builds or loads the store and wires a hybrid BM25 + vector retriever into a
    RetrievalQA chain; :meth:`chat` then answers questions and logs each turn.
    """

    # Prompt shared by every QA chain built by this class.
    QA_TEMPLATE = """Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that "i am unable to answer your query, for more information contact your HRBP", don't try to make up an answer.
Use three sentences maximum and keep the answer as concise as possible.
{context}
Question: {question}
Helpful Answer:"""

    # Chunking parameters used when splitting the source documents.
    CHUNK_SIZE = 400
    CHUNK_OVERLAP = 10

    def __init__(self, OPENAI_API_KEY, VECTOR_STORE_DIRECTORY, VECTOR_STORE_CHECK,
                 DATA_DIRECTORY, DEBUG, USE_HYDE=False):
        """Store configuration; no documents or models are loaded here.

        Args:
            OPENAI_API_KEY: API key handed to every OpenAI client created here.
            VECTOR_STORE_DIRECTORY: Chroma persistence directory.
            VECTOR_STORE_CHECK: If truthy, load the existing store instead of
                rebuilding it. Pass a real bool: the string "False" is truthy.
            DATA_DIRECTORY: Directory of source documents.
            DEBUG: Stored on the instance; not read anywhere in this class.
            USE_HYDE: If true, embed through a HypotheticalDocumentEmbedder.
        """
        self.OPENAI_API_KEY = OPENAI_API_KEY
        self.DATA_DIRECTORY = DATA_DIRECTORY
        self.VECTOR_STORE_DIRECTORY = VECTOR_STORE_DIRECTORY
        self.VECTOR_STORE_CHECK = VECTOR_STORE_CHECK
        self.DEBUG = DEBUG
        self.vectorstore = None
        self.bot = None
        self.USE_HYDE = USE_HYDE
        # Chunked documents, loaded lazily and shared by vector-store creation
        # and the BM25 retriever so the data directory is read at most once.
        self._splits = None

    # ------------------------------------------------------------------ helpers

    def _load_splits(self):
        """Load and chunk the documents in DATA_DIRECTORY, caching the result."""
        if getattr(self, "_splits", None) is None:
            docs = DirectoryLoader(self.DATA_DIRECTORY).load()
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=self.CHUNK_SIZE, chunk_overlap=self.CHUNK_OVERLAP)
            self._splits = text_splitter.split_documents(docs)
        return self._splits

    def _build_embeddings(self):
        """Return OpenAI embeddings, wrapped with HyDE when USE_HYDE is set."""
        base_embeddings = OpenAIEmbeddings(openai_api_key=self.OPENAI_API_KEY)
        if not self.USE_HYDE:
            return base_embeddings
        llm = OpenAI(openai_api_key=self.OPENAI_API_KEY)
        return HypotheticalDocumentEmbedder.from_llm(llm, base_embeddings, "web_search")

    def _require_vectorstore(self):
        """Return the vector store, raising RuntimeError if it is not built yet."""
        if self.vectorstore is None:
            raise RuntimeError(
                "Vector store not initialised: call create_vectorstore() first.")
        return self.vectorstore

    def _build_qa_chain(self, retriever, model_name="gpt-3.5-turbo"):
        """Build a RetrievalQA chain over *retriever* using QA_TEMPLATE."""
        qa_chain_prompt = PromptTemplate.from_template(self.QA_TEMPLATE)
        llm = ChatOpenAI(model_name=model_name, temperature=0,
                         openai_api_key=self.OPENAI_API_KEY)
        return RetrievalQA.from_chain_type(
            llm,
            retriever=retriever,
            chain_type_kwargs={"prompt": qa_chain_prompt},
        )

    # ------------------------------------------------------------- vector store

    def create_vectorstore(self):
        """Load the persisted store, or build it from DATA_DIRECTORY.

        When VECTOR_STORE_CHECK is truthy the store is loaded via
        :meth:`load_vectorstore`; otherwise documents are chunked, embedded and
        written to VECTOR_STORE_DIRECTORY.
        """
        if self.VECTOR_STORE_CHECK:
            print("Loading Vectorstore")
            self.load_vectorstore()
            return
        print("Creating Vectorstore")
        self.vectorstore = Chroma.from_documents(
            documents=self._load_splits(),
            embedding=self._build_embeddings(),
            persist_directory=self.VECTOR_STORE_DIRECTORY,
        )

    def load_vectorstore(self):
        """Open the Chroma store persisted at VECTOR_STORE_DIRECTORY."""
        if self.USE_HYDE:
            print("Using HYDE embeddings vectorstore")
        else:
            print("Using Simple embeddings vectorstore")
        self.vectorstore = Chroma(persist_directory=self.VECTOR_STORE_DIRECTORY,
                                  embedding_function=self._build_embeddings())

    # --------------------------------------------------------------- retrievers

    def multi_query_retriever(self):
        """Set ``self.bot`` to a QA chain using an LLM-driven MultiQueryRetriever."""
        base_retriever = self._require_vectorstore().as_retriever()
        retriever_from_llm = MultiQueryRetriever.from_llm(
            retriever=base_retriever,
            llm=ChatOpenAI(temperature=0, openai_api_key=self.OPENAI_API_KEY),
        )
        self.bot = self._build_qa_chain(retriever_from_llm)

    def single_query_retriever(self):
        """Set ``self.bot`` to a QA chain over the plain vector-store retriever."""
        retriever = self._require_vectorstore().as_retriever()
        self.bot = self._build_qa_chain(retriever)

    ### Adding Ensemble retriever
    def create_ensemble_retriever(self, k=2, weights=None):
        """Set ``self.bot`` to a QA chain over a BM25 + vector ensemble retriever.

        Args:
            k: Number of documents each sub-retriever returns (default 2).
            weights: Ensemble weights for [BM25, vector]; defaults to [0.5, 0.5].
        """
        vectorstore = self._require_vectorstore()
        if weights is None:
            weights = [0.5, 0.5]
        print("=====================" * 10)
        print("Loading Documents for Ensemble Retriver")
        print("=====================" * 10)
        # BM25 is a lexical index, so it needs the raw chunks rather than the
        # embedded store; reuse the cached chunks when already loaded.
        bm25_retriever = BM25Retriever.from_documents(self._load_splits())
        bm25_retriever.k = k
        ensemble_retriever = EnsembleRetriever(
            retrievers=[bm25_retriever,
                        vectorstore.as_retriever(search_kwargs={"k": k})],
            weights=list(weights),
        )
        self.bot = self._build_qa_chain(ensemble_retriever,
                                        model_name="gpt-3.5-turbo-16k")

    # ----------------------------------------------------------------- runtime

    def log(self, user_question, chatbot_reply):
        """Record one question/answer turn in the chatbot log at INFO level."""
        logging.info("User: %s", user_question)
        logging.info("JIN-e: %s", chatbot_reply)

    def load_model(self):
        """Build or load the vector store, then wire up the ensemble QA chain.

        multi_query_retriever() and single_query_retriever() are alternative
        chain builders that can be called instead of the ensemble.
        """
        self.create_vectorstore()
        self.create_ensemble_retriever()

    def chat(self, user_question):
        """Answer *user_question* with the loaded QA chain and log the turn.

        Raises:
            RuntimeError: if no chain has been built yet (call load_model()).
        """
        if self.bot is None:
            raise RuntimeError("Model not loaded: call load_model() before chat().")
        result = self.bot({"query": user_question})
        response = result["result"]
        self.log(user_question, response)
        return response
def _env_flag(name, default=False):
    """Interpret environment variable *name* as a boolean.

    Environment values are strings, so a plain truthiness check would treat
    "False" or "0" as true. Recognised true values (case-insensitive, surrounding
    whitespace ignored): 1, true, yes, y, on. Unset or blank yields *default*.
    """
    value = os.getenv(name)
    if value is None or not value.strip():
        return default
    return value.strip().lower() in {"1", "true", "yes", "y", "on"}


if __name__ == "__main__":
    # Configuration comes from the environment (optionally via a .env file).
    load_dotenv()
    OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
    DATA_DIRECTORY = os.getenv("DATA_DIRECTORY")
    # Prefer the correctly spelled variable but fall back to the historical
    # misspelling "VECTOR_STORE_DIRCTORY" so existing .env files keep working.
    VECTOR_STORE_DIRECTORY = (os.getenv("VECTOR_STORE_DIRECTORY")
                              or os.getenv("VECTOR_STORE_DIRCTORY"))
    VECTOR_STORE_CHECK = _env_flag("VECTOR_STORE_CHECK")
    DEBUG = _env_flag("DEBUG")
    USE_HYDE = _env_flag("USE_HYDE")
    # Initialize Jine and start chatting
    jine = Jine(OPENAI_API_KEY, VECTOR_STORE_DIRECTORY, VECTOR_STORE_CHECK,
                DATA_DIRECTORY, DEBUG, USE_HYDE=USE_HYDE)
    jine.load_model()
    while True:
        try:
            user_question = input("You: ")
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C ends the session cleanly instead of a traceback.
            print()
            break
        question = user_question.strip()
        if question.lower() in ["exit", "quit"]:
            break
        if not question:
            # Don't spend an LLM call on an empty line.
            continue
        response = jine.chat(user_question)
        print("JIN-e:", response)