|
|
|
|
|
import streamlit as st |
|
from streamlit_lottie import st_lottie, st_lottie_spinner |
|
import os |
|
from pathlib import Path |
|
from dotenv import load_dotenv |
|
# Load environment variables (e.g. LLAMA3MODEL8B used further below) from a
# local .env file into os.environ.
load_dotenv()

# NOTE: st.set_page_config must be the first Streamlit command executed in the
# script — keep it ahead of any other st.* call.
st.set_page_config(layout="wide",
                   page_title="RAG With Llama 3",
                   page_icon="Lottie Animations/LlamaIcon.jpeg")
|
|
|
|
|
from _helper_functions import loadLottieFile |
|
from _helper_functions import initialRampUp |
|
from _helper_functions import navigationBar |
|
from _helper_functions import removedOrAdded |
|
from _helper_functions import buildVectorDatabase |
|
from _helper_functions import RetrievalChainGenerator |
|
from _helper_functions import loadItOnce |
|
|
|
|
|
from _helper_functions import display_main_title |
|
from _helper_functions import display_alert_note |
|
from _helper_functions import display_attention_text |
|
from _helper_functions import display_custom_arrow |
|
from _helper_functions import display_heading_box |
|
from _helper_functions import display_error_message |
|
from _helper_functions import display_small_text |
|
from _helper_functions import display_response_message |
|
from _helper_functions import display_question_box |
|
from _helper_functions import display_allCitations |
|
|
|
|
|
# Resolve the folder holding the Lottie JSON animation assets, relative to
# the process working directory.
cwd = Path.cwd()
filePath = cwd / "Lottie Animations"


def _animation(stem):
    """Load one Lottie animation JSON (by file stem) from the assets folder."""
    return loadLottieFile(filePath / f"{stem}.json")


# Animations used across the pages below.
llama3 = _animation("llama3")
finetuning = _animation("finetuning")
forecasting = _animation("forecasting")
buildingDatabase = _animation("buildingDatabase")
fancyload = _animation("fancyloading")
citations = _animation("citations")
vdbList = _animation("knowledgeBase1")

# First-run flag: the intro ramp-up animation plays only once per session.
if 'initialRampUp' not in st.session_state:
    st.session_state.initialRampUp = True
|
|
|
|
|
|
|
# Page title banner, rendered on every script run.
display_main_title("Let's Chat With Llama 3!!!", st)

# Horizontal navigation bar; returns the label of the currently selected
# page, which drives the if/elif dispatch below.
selected_option = navigationBar()

st.divider()
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Page dispatch.  The RAG page lays out three columns: citations (left),
# upload/controls (center, `upld`), knowledge-base contents (right).
#
# st.session_state.vdbBuilt is the knowledge-base bookkeeping dict mapping
# either an uploaded file's file_id -> its dumped path ("Dump/<id>---<name>"),
# or 'Keyword ; <word>' -> None for Wikipedia-keyword entries.
# NOTE(review): vdbBuilt is read before it is ever assigned here — presumably
# initialised inside initialRampUp() / _helper_functions; confirm.
# ---------------------------------------------------------------------------
if selected_option == "Retrieval Augmented Generation":

    # Play the intro animation only on the first run of the session.
    if st.session_state.initialRampUp:
        initialRampUp(llamaAnimation=llama3)
        st.session_state.initialRampUp = False

    with st.container():
        leftCol, upld, rightCol = st.columns((3,4,3))

        # Right column: knowledge-base listing header + animation.
        display_heading_box(message = "Knowledge Base Contents", container= rightCol)
        loadItOnce(container=rightCol, animation=vdbList, height=200, quality='low')

        # Left column: citations panel header + animation.
        display_heading_box(message = "Citations for responses", container= leftCol)
        loadItOnce(container=leftCol, animation=citations, height=200, quality='low')

        # "#"/"###" markdown headings are used purely as vertical spacers.
        upld.markdown("#")
        upld.markdown("#")

        display_alert_note(message="Note: \
        Multiple Files with same names will be considered unique while constructing Vector Embeddings \
        It takes a little bit of time for Vector Embeddings to be built, BE PATIENT!", container= upld)
        upld.markdown("#")
        upld.markdown("#")

        display_attention_text(text="Build your Knowledge Base (Vector DB)!", container=upld)

        with upld.container():

            upldLeft, upldcenter, upldRight = st.columns((1,5,1))

            upldLeft.markdown("###")
            upldLeft.markdown("###")
            display_custom_arrow(direction="right", container=upldLeft)

            # Multi-file uploader (PDF / plain text).  Each upload gets a
            # unique file_id from Streamlit, used as the vdbBuilt key.
            uploadedFiles = upldcenter.file_uploader(label= "Upload or Add Documents",
                                                     type=['pdf', 'txt'],
                                                     accept_multiple_files=True,
                                                     key="fileUpload"
                                                     )

            upldRight.markdown("###")
            upldRight.markdown("###")
            display_custom_arrow(direction="left", container=upldRight)

            # Semicolon-separated keywords; used below to pull Wikipedia
            # content into the knowledge base.
            queryInputs = upldcenter.text_input(label="Type keywords to enhance augmented search & generation! Separate each keyword with ';'",
                                                placeholder="Domain Information will be scrapped from wikipedia based on key words you enter",
                                                )

    # Diff current uploader contents against the previous run's snapshot.
    # NOTE(review): rem presumably maps file_id -> file for files the user
    # just removed from the uploader — verify in _helper_functions.
    rem = removedOrAdded(uploadedFiles)
    st.session_state.filesUploadedRecords = uploadedFiles

    if len(rem) > 0:
        # Drop the (first) removed file from the bookkeeping dict; fileName
        # is the Dump/ path it was written to when it was added.
        fileName = st.session_state.vdbBuilt.pop(list(rem.keys())[0], None)
        if fileName is not None:

            # addOrRemove=False => remove this file's content from the DB.
            buildVectorDatabase(files=str(fileName), addOrRemove=False, query= [])

    with upld.container():

        _, crtOrAd, _ = st.columns((2,2,2))

        createOrAdd = crtOrAd.button("Create/Add to the knowledge base")

        if createOrAdd:
            # Lottie spinner shown for the duration of the (slow) build.
            with st_lottie_spinner(buildingDatabase , height=700):
                fileNames = []      # new Dump/ paths to index this run
                wikiQueries = []    # new keywords to scrape this run

                if queryInputs:

                    # Register each not-yet-seen, non-empty keyword; the
                    # None value marks a keyword entry (vs. a file entry).
                    for w in queryInputs.split(';'):
                        if 'Keyword ; ' + w.strip() not in st.session_state.vdbBuilt.keys() and len(w.strip()) > 0:
                            st.session_state.vdbBuilt['Keyword ; ' + w.strip()] = None
                            wikiQueries.append(w.strip())

                # Persist each new upload to Dump/ so the indexer can read
                # it from disk; the file_id prefix keeps same-named uploads
                # distinct (as the alert note above warns).
                for file in uploadedFiles:
                    if file.file_id not in st.session_state.vdbBuilt.keys():
                        fileName = "Dump/" + file.file_id + "---" + file.name
                        with open(fileName, "wb") as f:
                            f.write(file.getvalue())
                        fileNames.append(fileName)
                        st.session_state.vdbBuilt[file.file_id] = fileName

                # addOrRemove=True => extend the vector DB with new content.
                if len(fileNames) > 0 or len(wikiQueries) > 0:
                    buildVectorDatabase(files= fileNames, addOrRemove=True, query = wikiQueries)

                else:

                    display_error_message(message= "You have no new files or Keywords to create/add vector databases!", container=upld)

    # Keywords currently in the knowledge base, offered as a multiselect so
    # the user can cross entries off to remove their scraped content.
    keyWordOptions = [key.split(';')[-1].strip() for key, _ in st.session_state.vdbBuilt.items() if key.startswith("Keyword ; ")]
    optionsMessage = "You can cross off certain keywords from below if you'd prefer to remove contents relevant to your entered keywords, be removed from knowledge base" if len(keyWordOptions)>0 \
        else "Your Knowledge Base does not have any keywords based contents scrapped from internet"
    upldcenter.multiselect(
        label= optionsMessage,
        options=keyWordOptions,
        default=keyWordOptions,
        key='keywordsDBOptions')

    # Remove at most one deselected keyword per run; st.rerun() restarts the
    # script so remaining deselections are handled on subsequent passes.
    popKeys = None
    for key, value in st.session_state.vdbBuilt.items():
        val = key.split(";")[-1].strip() if key.startswith('Keyword ; ') else None
        if val is not None and val not in st.session_state.keywordsDBOptions:
            popKeys = key
            break
    if popKeys:
        # NOTE(review): for keyword entries vdbBuilt[popKeys] is None (set
        # above), so query=None is passed here — confirm buildVectorDatabase
        # handles that as intended for keyword removal.
        buildVectorDatabase(files=None, addOrRemove=False, query=st.session_state.vdbBuilt[popKeys])
        _ = st.session_state.vdbBuilt.pop(popKeys, None)
        st.rerun()

    # Right-column listing: the keyword text for keyword entries, or the
    # original file name (text after the '---' separator) for uploads.
    for key, value in st.session_state.vdbBuilt.items():
        val = key.split(";")[-1].strip() if key.startswith('Keyword ; ') else value.split("---")[-1].strip()
        display_small_text(val, rightCol)

    with upld.container():
        st.markdown("###")
        st.divider()
        st.markdown("###")
        _, c1,c2, _ = st.columns((2,1,2,2))
        loadItOnce(container=c1, animation=llama3, height=150, quality='low')
        display_question_box(c2)

    # Chat input stays disabled until the knowledge base has at least one
    # entry; only then is the retrieval chain constructed.
    if len(st.session_state.vdbBuilt) == 0:
        st.chat_input(placeholder="Type your query here once you have built vector databases!",
                      disabled=True)
    else:
        query = st.chat_input(placeholder="Type your query here once you have built vector databases!",
                              disabled=False)
        # Retrieval chain over the session's vector store; the model name
        # comes from the LLAMA3MODEL8B env var loaded by load_dotenv().
        # NOTE(review): st.session_state.vectorDatabase is assumed to be set
        # by buildVectorDatabase — confirm.
        generatorLlama3_8b = RetrievalChainGenerator(model_name=os.environ['LLAMA3MODEL8B'], vectorStore=st.session_state.vectorDatabase)

        if query:
            with st.container():
                with st_lottie_spinner(fancyload, height=400):
                    # Run the RAG chain; response is expected to contain an
                    # 'answer' key plus citation metadata (used below).
                    response = generatorLlama3_8b.chain.invoke({"input": query})

                display_response_message(response['answer'], upld)

                display_allCitations(response, leftCol)


# Placeholder pages: animation only, features not implemented yet.
elif selected_option == "Fine Tuning LLMs (Coming Soon)":
    st_lottie(finetuning, quality='medium', height=700)


elif selected_option == "Forecasting LLMs (Coming Soon)":
    st_lottie(forecasting, quality='high', height=700)


else:
    pass
|
|