|
import subprocess |
|
|
|
import streamlit as st |
|
from dotenv import load_dotenv |
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
from langchain.vectorstores import Chroma |
|
from langchain.embeddings import FastEmbedEmbeddings |
|
from langchain.memory import ConversationBufferMemory |
|
from langchain.callbacks.manager import CallbackManager |
|
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler |
|
from htmlTemplates import css, bot_template, user_template |
|
from langchain.llms import LlamaCpp |
|
from langchain.document_loaders import PyPDFLoader, TextLoader, JSONLoader, CSVLoader |
|
import tempfile |
|
from langchain.chains import RetrievalQA |
|
from langchain.prompts import PromptTemplate |
|
from langchain import hub |
|
import os |
|
import glob |
|
import gc |
|
|
|
|
|
|
|
def get_pdf_text(pdf_docs):
    """Extract the content of an uploaded PDF file as LangChain documents.

    Args:
        pdf_docs: A Streamlit ``UploadedFile`` holding the PDF bytes
            (only ``.name`` and ``.getvalue()`` are used).

    Returns:
        list: ``Document`` objects (one per page) from ``PyPDFLoader.load()``.
    """
    # PyPDFLoader only accepts filesystem paths, so spill the upload to a
    # temporary file first.  Fix: the original created a bare
    # TemporaryDirectory object whose cleanup relied on GC finalization;
    # the context manager guarantees deterministic removal.
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_filepath = os.path.join(temp_dir, pdf_docs.name)
        with open(temp_filepath, "wb") as f:
            f.write(pdf_docs.getvalue())

        pdf_loader = PyPDFLoader(temp_filepath)
        # Load while the temp directory still exists.
        pdf_doc = pdf_loader.load()
    return pdf_doc
|
|
|
|
|
def get_text_file(text_docs):
    """Extract the content of an uploaded plain-text file as LangChain documents.

    Args:
        text_docs: A Streamlit ``UploadedFile`` holding the text bytes
            (only ``.name`` and ``.getvalue()`` are used).

    Returns:
        list: ``Document`` objects from ``TextLoader.load()``.
    """
    # TextLoader needs a real path, so write the upload out first.
    # Fix: use the TemporaryDirectory context manager for deterministic
    # cleanup instead of relying on GC finalization.
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_filepath = os.path.join(temp_dir, text_docs.name)
        with open(temp_filepath, "wb") as f:
            f.write(text_docs.getvalue())

        text_loader = TextLoader(temp_filepath)
        # Load while the temp directory still exists.
        text_doc = text_loader.load()
    return text_doc
|
|
|
def get_csv_file(csv_docs):
    """Extract the content of an uploaded CSV file as LangChain documents.

    Args:
        csv_docs: A Streamlit ``UploadedFile`` holding the CSV bytes
            (only ``.name`` and ``.getvalue()`` are used).

    Returns:
        list: ``Document`` objects (one per row) from ``CSVLoader.load()``.
    """
    # CSVLoader needs a real path, so write the upload out first.
    # Fix: use the TemporaryDirectory context manager for deterministic
    # cleanup instead of relying on GC finalization.
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_filepath = os.path.join(temp_dir, csv_docs.name)
        with open(temp_filepath, "wb") as f:
            f.write(csv_docs.getvalue())

        csv_loader = CSVLoader(temp_filepath)
        # Load while the temp directory still exists.
        csv_doc = csv_loader.load()
    return csv_doc
|
|
|
|
|
def get_json_file(json_docs):
    """Extract message contents from an uploaded JSON file as LangChain documents.

    Args:
        json_docs: A Streamlit ``UploadedFile`` holding the JSON bytes
            (only ``.name`` and ``.getvalue()`` are used).

    Returns:
        list: ``Document`` objects from ``JSONLoader.load()``; the jq schema
        ``.messages[].content`` pulls one document per message content field.
    """
    # JSONLoader needs a real path, so write the upload out first.
    # Fix: use the TemporaryDirectory context manager for deterministic
    # cleanup instead of relying on GC finalization.
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_filepath = os.path.join(temp_dir, json_docs.name)
        with open(temp_filepath, "wb") as f:
            f.write(json_docs.getvalue())

        json_loader = JSONLoader(
            file_path=temp_filepath,
            jq_schema='.messages[].content',
            text_content=False
        )
        # Load while the temp directory still exists.
        json_doc = json_loader.load()
    return json_doc
|
|
|
|
|
def get_text_chunks(documents):
    """Split LangChain documents into overlapping character chunks.

    Chunks are at most 512 characters long with a 50-character overlap so
    context spanning a boundary is not lost; plain ``len`` measures size.

    Args:
        documents: Iterable of LangChain ``Document`` objects.

    Returns:
        list: The split ``Document`` chunks.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=512,
        chunk_overlap=50,
        length_function=len,
    )
    return splitter.split_documents(documents)
|
|
|
|
|
|
|
def get_vectorstore(text_chunks, embeddings):
    """Build and persist a Chroma vector store from document chunks.

    Args:
        text_chunks: List of LangChain ``Document`` chunks to index.
        embeddings: Embedding function used to vectorise the chunks.

    Returns:
        Chroma: Vector store persisted under ``./vectordb/``.
    """
    # Fix: the original ignored the ``embeddings`` parameter and read
    # ``st.session_state.embeddings`` directly, hiding a Streamlit
    # dependency inside an otherwise pure helper.  Callers already pass
    # the same object, so behavior for existing callers is unchanged.
    vectorstore = Chroma.from_documents(
        documents=text_chunks,
        embedding=embeddings,
        persist_directory="./vectordb/",
    )

    return vectorstore
|
|
|
def get_conversation_chain(vectorstore):
    """Create a RetrievalQA chain backed by a local Llama-2 GGUF model.

    Args:
        vectorstore: Vector store whose retriever supplies context
            documents for each question.

    Returns:
        RetrievalQA: Chain that answers questions with the custom HR
        prompt, streaming generated tokens to stdout.
    """
    model_path = "models/llama-2-13b-chat.Q4_K_S.gguf"
    # Stream tokens to stdout as they are generated so long answers show
    # visible progress in the server console.
    callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])

    # Fix: the original was missing the comma between
    # ``callback_manager=callback_manager`` and ``verbose=True``, which
    # made this module fail to load with a SyntaxError.
    llm = LlamaCpp(
        model_path=model_path,
        n_ctx=4000,
        max_tokens=500,
        n_gpu_layers=50,
        n_batch=512,
        callback_manager=callback_manager,
        verbose=True,
    )

    template = """
    You are a Experience human Resource Manager. When the employee asks you a question, you will have to refer the company policy and respond in a professional way. Make sure to sound Empethetic while being professional and sound like a Human!
    Try to summarise the content and keep the answer to the point.
    If you don't know the answer, just say that you don't know, don't try to make up an answer.

    Followe the template below
    Example:
    Question : how many paid leaves do i have ?
    Answer : The number of paid leaves varies depending on the type of leave, like privilege leave you're entitled to a maximum of 21 days in a calendar year. Other leaves might have different entitlements. thanks for asking!
    make sure to add "thanks for asking!" after every answer

    {context}

    Question: {question}
    Answer:

    Just answer to the point!
    """

    rag_prompt_custom = PromptTemplate.from_template(template)

    conversation_chain = RetrievalQA.from_chain_type(
        llm,
        retriever=vectorstore.as_retriever(),
        chain_type_kwargs={"prompt": rag_prompt_custom},
    )
    conversation_chain.callback_manager = callback_manager
    # NOTE(review): the original also built a ConversationBufferMemory
    # with memory_key='chat_history' that was never attached to anything;
    # that dead local has been removed.  The chain still gets a fresh
    # default-keyed memory, exactly as before.
    conversation_chain.memory = ConversationBufferMemory()

    return conversation_chain
|
|
|
|
|
def handle_userinput():
    """Render the chat UI and route user prompts to the QA chain.

    Keeps the transcript in ``st.session_state.messages`` and, when the
    user presses "Clear Chat history", resets both the visible history
    and the chain's conversation memory before the next answer.
    """
    clear = False

    # st.button returns True only on the rerun triggered by the click.
    if st.button("Clear Chat history"):
        clear = True
        st.session_state.messages = []

    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "How can I help you?"}]

    # Replay the stored transcript on every Streamlit rerun.
    for msg in st.session_state.messages:
        st.chat_message(msg["role"]).write(msg["content"])

    if prompt := st.chat_input():
        st.session_state.messages.append({"role": "user", "content": prompt})
        st.chat_message("user").write(prompt)
        if clear:
            # Fix: RetrievalQA has no ``clean()`` method — the original
            # raised AttributeError here.  Reset the chain's attached
            # ConversationBufferMemory (which provides ``clear()``) so
            # the model forgets the previous exchanges.
            st.session_state.conversation.memory.clear()
        msg = st.session_state.conversation.run(prompt)
        print(msg)
        st.session_state.messages.append({"role": "assistant", "content": msg})
        st.chat_message("assistant").write(msg)
|
|
|
|
|
|
|
|
|
def add_rounded_edges(image_path="./randstad_featuredimage.png", radius=30):
    """Inject a rounded-corner CSS class and render the sidebar image.

    Args:
        image_path: Path of the image file to display.
        radius: Corner radius in pixels used by the injected CSS rule.
    """
    rounded_css = f'<style>.rounded-img{{border-radius: {radius}px; overflow: hidden;}}</style>'
    st.markdown(
        rounded_css,
        unsafe_allow_html=True,
    )
    st.image(image_path, use_column_width=True, output_format='auto')
|
|
|
|
|
def main():
    """Entry point: configure the Streamlit page, restore any persisted
    vector store, and wire up document upload + chat interaction.

    Runs top-to-bottom on every Streamlit rerun; all cross-rerun state
    lives in ``st.session_state``.
    """
    load_dotenv()
    # Encourage release of objects from the previous rerun (LLM/vector
    # store instances can be large).
    gc.collect()
    st.set_page_config(page_title="Chat with multiple Files",
                       page_icon=":books:")
    st.write(css, unsafe_allow_html=True)

    # Initialise session-state slots on first run so later code can
    # assign/read them unconditionally.
    if "conversation" not in st.session_state:
        st.session_state.conversation = None
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = None

    st.title("π¬ Randstad HR Chatbot")
    st.subheader("π A HR powered by Generative AI")

    # NOTE(review): rebuilt on every rerun; FastEmbed caches the model
    # files under ./embedding_model/ so repeat construction is cheap-ish,
    # but this could be guarded with a session_state check.
    st.session_state.embeddings = FastEmbedEmbeddings( model_name= "BAAI/bge-small-en-v1.5",
                                                       cache_dir="./embedding_model/")

    # If a Chroma database was persisted by a previous "Process" run,
    # reopen it and enable the chat UI immediately.
    if len(glob.glob("./vectordb/*.sqlite3")) > 0:

        vectorstore = Chroma(persist_directory="./vectordb/", embedding_function=st.session_state.embeddings)
        st.session_state.conversation = get_conversation_chain(vectorstore)
        handle_userinput()

    with st.sidebar:
        add_rounded_edges()

        st.subheader("Your documents")
        docs = st.file_uploader(
            "Upload File (pdf,text,csv...) and click 'Process'", accept_multiple_files=True)
        if st.button("Process"):
            with st.spinner("Processing"):

                doc_list = []

                # Dispatch each upload to the loader matching its MIME
                # type; unknown types are silently skipped.
                for file in docs:
                    print('file - type : ', file.type)
                    if file.type == 'text/plain':

                        doc_list.extend(get_text_file(file))
                    elif file.type in ['application/octet-stream', 'application/pdf']:

                        doc_list.extend(get_pdf_text(file))
                    elif file.type == 'text/csv':

                        doc_list.extend(get_csv_file(file))
                    elif file.type == 'application/json':

                        doc_list.extend(get_json_file(file))

                # Split, embed, persist, then build the QA chain for the
                # freshly indexed documents.
                text_chunks = get_text_chunks(doc_list)

                vectorstore = get_vectorstore(text_chunks, st.session_state.embeddings)

                st.session_state.conversation = get_conversation_chain(vectorstore)
|
|
|
|
|
if __name__ == '__main__':
    import sys

    # Build llama-cpp-python with CUDA (cuBLAS) support before starting
    # the app.  Fix: the original ran a shell string with shell=True just
    # to set environment variables; passing ``env=`` with an argument
    # list avoids the shell entirely, and ``sys.executable -m pip``
    # guarantees the install targets the interpreter running this app.
    # NOTE(review): this reinstalls the package on every launch, which is
    # slow — consider guarding with an import check.
    env = os.environ.copy()
    env["CMAKE_ARGS"] = "-DLLAMA_CUBLAS=on"
    env["FORCE_CMAKE"] = "1"
    command = [sys.executable, "-m", "pip", "install", "llama-cpp-python", "--no-cache-dir"]

    try:
        subprocess.run(command, env=env, check=True)
        print("Command executed successfully.")
    except subprocess.CalledProcessError as e:
        print(f"Error: {e}")

    main()
|
|