|
import subprocess |
|
|
|
import streamlit as st |
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
from langchain.vectorstores import Chroma, FAISS |
|
from langchain.embeddings import FastEmbedEmbeddings |
|
from langchain.memory import ConversationBufferMemory |
|
from langchain.callbacks.manager import CallbackManager |
|
from langchain.callbacks import StreamlitCallbackHandler |
|
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler |
|
from htmlTemplates import css, bot_template, user_template |
|
from langchain.llms import LlamaCpp, OpenAI, GooglePalm |
|
from langchain.document_loaders import PyPDFLoader, TextLoader, CSVLoader |
|
from langchain.chains import RetrievalQA |
|
from langchain.prompts import PromptTemplate |
|
from langchain import hub |
|
import tempfile |
|
import os |
|
import glob |
|
import shutil |
|
import time |
|
|
|
|
|
def get_pdf_text(pdf_docs):
    """Load an uploaded PDF into LangChain documents.

    The upload is first written to a scratch directory because the loader
    is constructed from a file path.

    Args:
        pdf_docs: Uploaded file object exposing ``name`` and ``getvalue()``
            (presumably a Streamlit ``UploadedFile`` -- confirm with caller).

    Returns:
        The result of ``PyPDFLoader(path).load()`` for the uploaded file.
    """
    # Use the context manager so the scratch directory is removed
    # deterministically instead of whenever the object is finalized.
    with tempfile.TemporaryDirectory() as temp_dir:
        # basename() keeps the write inside temp_dir even when the upload
        # name contains directory components or is an absolute path.
        temp_filepath = os.path.join(temp_dir, os.path.basename(pdf_docs.name))

        with open(temp_filepath, "wb") as f:
            f.write(pdf_docs.getvalue())

        # Load while the file still exists on disk.
        return PyPDFLoader(temp_filepath).load()
|
|
|
|
|
def get_text_file(text_docs):
    """Load an uploaded plain-text file into LangChain documents.

    Args:
        text_docs: Uploaded file object exposing ``name`` and ``getvalue()``
            (presumably a Streamlit ``UploadedFile`` -- confirm with caller).

    Returns:
        The result of ``TextLoader(path).load()`` for the uploaded file.
    """
    # Use the context manager so the scratch directory is removed
    # deterministically instead of whenever the object is finalized.
    with tempfile.TemporaryDirectory() as temp_dir:
        # basename() keeps the write inside temp_dir even when the upload
        # name contains directory components or is an absolute path.
        temp_filepath = os.path.join(temp_dir, os.path.basename(text_docs.name))

        with open(temp_filepath, "wb") as f:
            f.write(text_docs.getvalue())

        # Load while the file still exists on disk.
        return TextLoader(temp_filepath).load()
|
|
|
def get_csv_file(csv_docs):
    """Load an uploaded CSV file into LangChain documents.

    Args:
        csv_docs: Uploaded file object exposing ``name`` and ``getvalue()``
            (presumably a Streamlit ``UploadedFile`` -- confirm with caller).

    Returns:
        The result of ``CSVLoader(path).load()`` for the uploaded file.
    """
    # Use the context manager so the scratch directory is removed
    # deterministically instead of whenever the object is finalized.
    with tempfile.TemporaryDirectory() as temp_dir:
        # basename() keeps the write inside temp_dir even when the upload
        # name contains directory components or is an absolute path.
        temp_filepath = os.path.join(temp_dir, os.path.basename(csv_docs.name))

        with open(temp_filepath, "wb") as f:
            f.write(csv_docs.getvalue())

        # Load while the file still exists on disk.
        return CSVLoader(temp_filepath).load()
|
|
|
|
|
|
|
def get_text_chunks(documents, chunk_size=1000, chunk_overlap=200):
    """Split documents into smaller, overlapping chunks.

    Chunks that are too small lose context; chunks that are too large
    increase compute time. The defaults reproduce the values this function
    previously hard-coded.

    Args:
        documents: Documents to split, passed to ``split_documents``.
        chunk_size: Maximum chunk size handed to the splitter.
        chunk_overlap: Overlap between consecutive chunks.

    Returns:
        The chunks produced by
        ``RecursiveCharacterTextSplitter.split_documents``.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return text_splitter.split_documents(documents)
|
|
|
|
|
|
|
def get_vectorstore(text_chunks, persist_directory="./vectordb/"):
    """Embed the chunks and store them in a persistent Chroma collection.

    The embedding function is read from ``st.session_state.embeddings``, so
    that key must be set before this is called.

    Args:
        text_chunks: Document chunks to index.
        persist_directory: Directory Chroma persists to. The default is the
            path this function previously hard-coded.

    Returns:
        The ``Chroma`` vector store built from ``text_chunks``.
    """
    return Chroma.from_documents(
        documents=text_chunks,
        embedding=st.session_state.embeddings,
        persist_directory=persist_directory,
    )
|
|
|
|
|
|
|
def get_conversation_chain(vectorstore):
    """Build the RetrievalQA chain that answers questions over ``vectorstore``.

    The LLM is chosen from ``st.session_state.model``: ``"Google_PaLm"``,
    ``"Open_AIGPT-3.5-Turbo"``, or otherwise the value is treated as a path
    to a local model file for ``LlamaCpp``.

    API keys are read from the ``GOOGLE_API_KEY`` / ``OPENAI_API_KEY``
    environment variables; the original placeholder strings remain as the
    fallback so behaviour is unchanged when the variables are unset.

    Args:
        vectorstore: Vector store providing ``as_retriever()``.

    Returns:
        A ``RetrievalQA`` chain with a ``ConversationBufferMemory`` attached.
    """
    model_path = st.session_state.model
    callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])

    # NOTE(review): 4000 completion tokens may not leave room for the prompt
    # within the hosted models' context windows -- confirm these limits.
    if model_path == "Google_PaLm":
        llm = GooglePalm(
            google_api_key=os.environ.get("GOOGLE_API_KEY", "Add your google palm API"),
            max_output_tokens=4000,
            callback_manager=callback_manager,
        )
    elif model_path == "Open_AIGPT-3.5-Turbo":
        llm = OpenAI(
            api_key=os.environ.get("OPENAI_API_KEY", "add your openAI Key"),
            callback_manager=callback_manager,
            max_tokens=4000,
        )
    else:
        llm = LlamaCpp(
            model_path=model_path,
            n_ctx=4000,
            max_tokens=4000,
            f16_kv=True,
            callback_manager=callback_manager,
            verbose=True,
        )

    prompt_template = """You are a personal HR Bot assistant for answering any questions about Companies policies
You are given a question and a set of documents.
If the user's question requires you to provide specific information from the documents, give your answer based only on the examples provided below. DON'T generate an answer that is NOT written in the provided examples.
If you don't find the answer to the user's question with the examples provided to you below, answer that you didn't find the answer in the documentation and propose him to rephrase his query with more details.
Use bullet points if you have to make a list, only if necessary. Use 'DOCUMENTS' as a reference point, to understand and give a consciese output in 3 or 5 sentences.

QUESTION: {question}

DOCUMENTS:
=========
{context}
=========
Finish by proposing your help for anything else.
"""
    rag_prompt_custom = PromptTemplate.from_template(prompt_template)

    # The hub prompt stays the primary choice. Pulling it is a best-effort
    # fetch, so any failure falls back to the local HR prompt (previously
    # built but never used) instead of aborting chain construction.
    try:
        prompt = hub.pull("rlm/rag-prompt-mistral")
    except Exception as exc:  # deliberately broad: fetch failures vary by setup
        print(f"Could not pull hub prompt, using the built-in HR prompt: {exc}")
        prompt = rag_prompt_custom

    conversation_chain = RetrievalQA.from_chain_type(
        llm,
        retriever=vectorstore.as_retriever(),
        chain_type_kwargs={"prompt": prompt},
    )
    conversation_chain.callback_manager = callback_manager
    conversation_chain.memory = ConversationBufferMemory()

    return conversation_chain
|
|
|
|
|
def handle_userinput():
    """Render the chat transcript and answer a newly submitted question.

    Reads and updates ``st.session_state.messages`` (the visible transcript)
    and runs ``st.session_state.conversation`` on the submitted prompt. The
    "Clear Chat history" button empties both the transcript and the chain's
    memory.
    """
    clear = st.button("Clear Chat history")
    if clear:
        st.session_state.messages = []

    # Seed the greeting only on first use; a cleared transcript stays empty.
    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "How can I help you?"}]

    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if clear:
        st.session_state.conversation.memory.clear()

    if prompt := st.chat_input():
        with st.chat_message("user"):
            st.markdown(prompt)
        st.session_state.messages.append({"role": "user", "content": prompt})

        with st.chat_message("assistant"):
            st_callback = StreamlitCallbackHandler(st.container())
            message_holder = st.empty()
            full_response = ""

            # NOTE(review): this assigns a handler object to the chain's
            # `callback_manager` attribute; presumably
            # `run(prompt, callbacks=[st_callback])` was intended -- confirm.
            st.session_state.conversation.callback_manager = st_callback
            msg = st.session_state.conversation.run(prompt)

            # Replay the answer word by word to simulate typing.
            for chunk in msg.split():
                full_response += chunk + " "
                time.sleep(0.09)
                # NOTE(review): the original cursor glyph was mis-encoded in
                # the source and could not be recovered; "▌" is a stand-in.
                message_holder.markdown(full_response + "▌")

            # Replace the animated text with the final answer.
            message_holder.info(full_response)

        st.session_state.messages.append({"role": "assistant", "content": full_response})
|
|
|
|
|
|
|
def add_rounded_edges(image_path="./randstad_featuredimage.png", radius=30):
    """Emit a ``.rounded-img`` CSS rule and display the image.

    Args:
        image_path: Path of the image handed to ``st.image``.
        radius: Border radius, in pixels, written into the CSS rule.

    NOTE(review): nothing in this function attaches the ``rounded-img``
    class to the rendered image -- confirm the rule actually takes effect.
    """
    rounded_css = f'<style>.rounded-img{{border-radius: {radius}px; overflow: hidden;}}</style>'
    st.markdown(rounded_css, unsafe_allow_html=True)
    st.image(image_path, use_column_width=True, output_format='auto')
|
|
|
|
|
def main():
    """Streamlit entry point: page setup, sidebar controls, and the chat view.

    The sidebar is evaluated before the chat so that the selected model is
    known when the conversation chain is built. The chain is kept in
    ``st.session_state`` and rebuilt only when none exists yet or when the
    selected model changes, instead of on every script run.
    """
    st.set_page_config(page_title="RANDSTAD", page_icon=":books:")
    st.write(css, unsafe_allow_html=True)

    if "conversation" not in st.session_state:
        st.session_state.conversation = None
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = None

    # NOTE(review): the emoji below were mis-encoded in the source and have
    # been restored on a best-effort basis -- confirm the intended glyphs.
    st.title("💬 Randstad HR Chatbot")
    st.subheader("🚀 A HR powered by Generative AI")

    # Create the embedding model once per session rather than on every run.
    if "embeddings" not in st.session_state:
        st.session_state.embeddings = FastEmbedEmbeddings(
            model_name="BAAI/bge-base-en-v1.5", cache_dir="./embedding_model/"
        )

    with st.sidebar:
        add_rounded_edges()

        st.subheader("Select Your Embedding Model Model")
        llm_choices = glob.glob("./models/*.gguf")
        llm_choices.extend(["Open_AIGPT-3.5-Turbo", "Google_PaLm"])
        # Preselect Google_PaLm: it is the model the app effectively used
        # before the sidebar choice was honoured.
        st.session_state.model = st.selectbox(
            "Models", llm_choices, index=llm_choices.index("Google_PaLm")
        )

        st.subheader("Your documents")
        docs = st.file_uploader(
            "Upload File (pdf,text,csv...) and click 'Process'", accept_multiple_files=True)

        if st.button("Process"):
            with st.spinner("Processing"):
                doc_list = []
                for file in docs or []:
                    print('file - type : ', file.type)
                    if file.type == 'text/plain':
                        doc_list.extend(get_text_file(file))
                    elif file.type in ['application/octet-stream', 'application/pdf']:
                        doc_list.extend(get_pdf_text(file))
                    elif file.type == 'text/csv':
                        doc_list.extend(get_csv_file(file))

                if doc_list:
                    text_chunks = get_text_chunks(doc_list)
                    vectorstore = get_vectorstore(text_chunks)
                    st.session_state.conversation = get_conversation_chain(vectorstore)
                    st.session_state.conversation_model = st.session_state.model
                else:
                    # Nothing usable was uploaded; skip indexing so the
                    # vector store is not asked to index an empty list.
                    st.warning("No supported documents were uploaded.")

    # The chat is only offered once a persisted vector store exists on disk.
    if glob.glob("./vectordb/*.sqlite3"):
        chain_is_stale = (
            st.session_state.conversation is None
            or st.session_state.get("conversation_model") != st.session_state.model
        )
        if chain_is_stale:
            vectorstore = Chroma(
                persist_directory="./vectordb/",
                embedding_function=st.session_state.embeddings,
            )
            st.session_state.conversation = get_conversation_chain(vectorstore)
            st.session_state.conversation_model = st.session_state.model
        handle_userinput()
|
|
|
|
|
if __name__ == '__main__':
    import importlib.util
    import sys

    # Streamlit re-executes this script on each interaction, so only run the
    # (slow) install when llama-cpp-python is not importable yet.
    if importlib.util.find_spec("llama_cpp") is None:
        # Pass the build flags through the environment and the command as an
        # argument list: no shell is involved, and the POSIX-only
        # `VAR=value cmd` prefix is no longer needed.
        build_env = dict(os.environ, CMAKE_ARGS="-DLLAMA_CUBLAS=on", FORCE_CMAKE="1")
        command = [sys.executable, "-m", "pip", "install", "llama-cpp-python", "--no-cache-dir"]

        try:
            subprocess.run(command, env=build_env, check=True)
            print("Command executed successfully.")
        except subprocess.CalledProcessError as e:
            # Best effort: the app can still start with the hosted models.
            print(f"Error: {e}")

    main()