# paper_reader / app.py
# parambharat's picture
# chore: improve rag pipeline
# 049ff35
# raw
# history blame
# 1.24 kB
import os
# Re-export OPENAI_API_KEY only when it is actually set. Assigning the result
# of os.getenv() unconditionally raises "TypeError: str expected, not NoneType"
# when the variable is missing, which crashed the app before any UI rendered.
_openai_api_key = os.getenv("OPENAI_API_KEY")
if _openai_api_key is not None:
    os.environ["OPENAI_API_KEY"] = _openai_api_key
import streamlit as st
import weave
from rag.rag import SimpleRAGPipeline
# Configure the browser tab (title, icon) and the overall page layout.
st.set_page_config(
    page_title="Chat with the Llama 3 paper!",
    # Llama emoji (U+1F999). Written as an escape so the icon survives any
    # re-encoding of this file; the previous literal was mojibake.
    page_icon="\U0001F999",
    layout="centered",
    initial_sidebar_state="auto",
    menu_items=None,
)
# Ask for the W&B API key in the sidebar (masked input). Until something at
# least 10 characters long has been entered, halt this script run so nothing
# below executes.
wandb_api_key = st.sidebar.text_input("WANDB_API_KEY", type="password")
if len(wandb_api_key) < 10:
    st.stop()
# Only reached with a plausible key: expose it through the environment.
os.environ["WANDB_API_KEY"] = wandb_api_key
# Project name handed to weave.init().
WANDB_PROJECT = "paper_reader"
# WANDB_PROJECT is already a str, so no f-string wrapper is needed.
weave.init(WANDB_PROJECT)

# Speech-balloon (U+1F4AC) and llama (U+1F999) emoji, written as escapes so the
# title survives any re-encoding of this file; the previous literal was mojibake.
st.title("Chat with the Llama 3 paper \U0001F4AC\U0001F999")
@st.cache_resource(show_spinner=False)
def load_rag_pipeline():
    """Create a SimpleRAGPipeline and build its query engine.

    Wrapped in ``st.cache_resource`` (spinner disabled) so the returned
    pipeline object is cached by Streamlit instead of being rebuilt on
    every call.

    Returns:
        The ``SimpleRAGPipeline`` instance after ``build_query_engine()``
        has been called on it.
    """
    pipeline = SimpleRAGPipeline()
    pipeline.build_query_engine()
    return pipeline
# Keep a handle to the pipeline in this session's state; load_rag_pipeline()
# is only called when the session does not have one yet.
# Membership is tested on st.session_state directly (no .keys() needed), and
# the same item-style access is used for both the write and the read.
if "rag_pipeline" not in st.session_state:
    st.session_state["rag_pipeline"] = load_rag_pipeline()
rag_pipeline = st.session_state["rag_pipeline"]
def generate_response(query):
    """Send *query* to the RAG pipeline and stream the answer onto the page.

    Args:
        query: The question text to pass to ``rag_pipeline.predict``.

    The object returned by ``predict`` exposes ``response_gen``, which is
    handed to ``st.write_stream`` for rendering (presumably a generator of
    text chunks -- confirm against ``SimpleRAGPipeline``).
    """
    st.write_stream(rag_pipeline.predict(query).response_gen)
# Question form: the text area and submit button are batched by st.form.
# NOTE(review): SOURCE has no indentation; the `if submitted:` branch is placed
# inside the form, following the usual Streamlit form pattern -- confirm.
with st.form("my_form"):
    query = st.text_area("Ask your question about the Llama 3 paper here:")
    submitted = st.form_submit_button("Submit")
    if submitted:
        if query.strip():
            generate_response(query)
        else:
            # Don't spend a pipeline/LLM call on an empty or whitespace-only
            # question; tell the user what is missing instead.
            st.warning("Please enter a question before submitting.")