File size: 3,164 Bytes
188a44f
 
3abd1c3
188a44f
 
 
 
 
 
 
 
3abd1c3
188a44f
 
 
3abd1c3
 
 
188a44f
3abd1c3
188a44f
 
 
 
3abd1c3
b3f99e2
188a44f
 
3abd1c3
188a44f
 
b3f99e2
188a44f
 
 
 
 
 
 
 
 
 
 
 
 
 
3abd1c3
188a44f
3abd1c3
188a44f
3abd1c3
188a44f
3abd1c3
188a44f
b3f99e2
3abd1c3
188a44f
 
 
 
 
 
 
 
3abd1c3
188a44f
 
 
 
 
 
 
 
 
3abd1c3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import streamlit as st
import random
from app_config import SYSTEM_PROMPT, NLP_MODEL_NAME, NUMBER_OF_VECTORS_FOR_RAG, NLP_MODEL_TEMPERATURE, NLP_MODEL_MAX_TOKENS, VECTOR_MAX_TOKENS, my_vector_store, chat, tiktoken_len
from langchain.memory import ConversationSummaryBufferMemory
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain.chains.summarize import load_summarize_chain
from langchain.prompts import PromptTemplate
from langchain_groq import ChatGroq
from dotenv import load_dotenv
from pathlib import Path
import os

# Load environment variables (read below via os.getenv) from a .env file in
# the current working directory.
env_path = Path('.') / '.env'
load_dotenv(dotenv_path=env_path)

# Initialize vector store and LLM outside session state
# The number of documents to retrieve is passed through ``search_kwargs``:
# LangChain's ``as_retriever`` forwards that dict to the vector-store search,
# whereas a bare ``k=`` keyword is not one of the retriever's fields.
# NOTE(review): assumes my_vector_store is a LangChain VectorStore - confirm in app_config.
retriever = my_vector_store.as_retriever(
    search_kwargs={"k": NUMBER_OF_VECTORS_FOR_RAG},
)
# NOTE: str(os.getenv(...)) yields the literal string "None" when GROQ_API_KEY
# is unset, so a missing key only surfaces once a request is made.
llm = ChatGroq(
    temperature=NLP_MODEL_TEMPERATURE,
    groq_api_key=str(os.getenv('GROQ_API_KEY')),
    model_name=NLP_MODEL_NAME,
)

def response_generator(prompt: str) -> str:
    """Answer a user query using retrieved context and the chat memory.

    Retrieves documents for ``prompt``, fills ``SYSTEM_PROMPT`` with the joined
    document text and the memory's running summary, then sends the system
    message, the buffered chat history and the new user message to the LLM.

    Args:
        prompt: The user's query text.

    Returns:
        The content of the model's reply, or a generic apology string when
        any step of the pipeline raises.
    """
    try:
        docs = retriever.invoke(prompt)
        my_context = '\n\n'.join(doc.page_content for doc in docs)
        memory = st.session_state.rag_memory
        system_message = SystemMessage(
            content=SYSTEM_PROMPT.format(
                context=my_context,
                previous_message_summary=memory.moving_summary_buffer,
            )
        )
        print(system_message)
        # Build the message list directly instead of adding messages together
        # and reading ``.messages`` back from the resulting prompt template.
        chat_messages = [
            system_message,
            *memory.chat_memory.messages,
            HumanMessage(content=prompt),
        ]
        print("total tokens: ", tiktoken_len(str(chat_messages)))
        response = llm.invoke(chat_messages)
        return response.content
    except Exception as error:
        # UI boundary: report the failure on stdout and give the user a
        # friendly message instead of letting the exception propagate.
        print(error, "ERROR")
        return "Oops! something went wrong, please try again."

# Custom CSS injected into the page: elements carrying the targeted class get
# a reversed row direction and right-aligned text.
_CHAT_ALIGNMENT_CSS = """
<style>
    .st-emotion-cache-janbn0 {
        flex-direction: row-reverse;
        text-align: right;
    }
</style>
"""
st.markdown(_CHAT_ALIGNMENT_CSS, unsafe_allow_html=True)

# Initialize session state
# Each entry is created lazily, and only when its key is not already present,
# so existing values are left untouched.
_session_defaults = {
    "messages": lambda: [{"role": "system", "content": SYSTEM_PROMPT}],
    "rag_memory": lambda: ConversationSummaryBufferMemory(llm=llm, max_token_limit=5000),
    "retriever": lambda: retriever,
}
for _state_key, _make_default in _session_defaults.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()

st.title("Insurance Bot")
container = st.container(height=600)

# Replay the stored conversation inside the fixed-height container; the
# system entry stays in session state but is never rendered.
for past_message in st.session_state.messages:
    role = past_message["role"]
    if role == "system":
        continue
    with container.chat_message(role):
        st.write(past_message["content"])

prompt = st.chat_input("Enter your query here... ")
if prompt:
    # Show the user's turn and record it in the displayed history.
    user_turn = {"role": "user", "content": prompt}
    with container.chat_message("user"):
        st.write(prompt)
    st.session_state.messages.append(user_turn)

    # Generate the reply inside the assistant bubble and log it to stdout.
    with container.chat_message("assistant"):
        response = response_generator(prompt=prompt)
        print("******************************************************** Response ********************************************************")
        print("MY RESPONSE IS:", response)
        st.write(response)

    print("Response is:", response)
    # Save the exchange to the RAG memory first, then to the displayed history.
    st.session_state.rag_memory.save_context({'input': prompt}, {'output': response})
    assistant_turn = {"role": "assistant", "content": response}
    st.session_state.messages.append(assistant_turn)