# Import modules and classes
from llama_index.core import VectorStoreIndex, StorageContext, load_index_from_storage
from llama_index.llms.nvidia import NVIDIA
from llama_index.embeddings.nvidia import NVIDIAEmbedding
from llama_index.core.llms import ChatMessage, MessageRole
from langchain_nvidia_ai_endpoints import NVIDIARerank
from langchain_core.documents import Document as LangDocument
from llama_index.core.schema import TextNode
from llama_index.core import Settings
from llama_parse import LlamaParse
import streamlit as st
import os

# Read API keys from the environment
nvidia_api_key = os.getenv("NVIDIA_KEY")
llamaparse_api_key = os.getenv("PARSE_KEY")
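
# Optional guard (not in the original script): fail fast if either key is
# missing. The variable names NVIDIA_KEY and PARSE_KEY must match whatever
# is exported in your environment.
if not nvidia_api_key or not llamaparse_api_key:
    raise EnvironmentError("Set NVIDIA_KEY and PARSE_KEY before running this app.")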

# Initialize the NVIDIA LLM, embedding, and reranking clients
client = NVIDIA(
    model="meta/llama-3.1-8b-instruct",
    api_key=nvidia_api_key,
    temperature=0.2,
    top_p=0.7,
    max_tokens=1024
)

embed_model = NVIDIAEmbedding(
    model="nvidia/nv-embedqa-e5-v5", 
    api_key=nvidia_api_key, 
    truncate="NONE"
)

reranker = NVIDIARerank(
    model="nvidia/nv-rerankqa-mistral-4b-v3",
    api_key=nvidia_api_key,
)
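
# Note: NVIDIARerank also accepts a top_n argument to control how many
# documents it returns; the library default is used here.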

# Set the NVIDIA models globally
Settings.embed_model = embed_model
Settings.llm = client

# Parse the local PDF document
parser = LlamaParse(
    api_key=llamaparse_api_key,
    result_type="markdown",
    verbose=True
)

# Get the absolute path of the script's directory
script_dir = os.path.dirname(os.path.abspath(__file__))
data_file = os.path.join(script_dir, "PhilDataset.pdf")

# Load the PDF document from the script's directory
documents = parser.load_data(data_file)
print("Document Parsed")

# Split parsed text into chunks for the embedding model; character count is
# used as a rough proxy for the model's 512-token input limit
def split_text(text, max_chars=512):
    words = text.split()
    chunks = []
    current_chunk = []
    current_length = 0

    for word in words:
        word_length = len(word)
        if current_length + word_length + 1 > max_chars:
            chunks.append(" ".join(current_chunk))
            current_chunk = [word]
            current_length = word_length + 1
        else:
            current_chunk.append(word)
            current_length += word_length + 1

    if current_chunk:
        chunks.append(" ".join(current_chunk))

    return chunks
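
# Example (illustrative): split_text("hello world foo", max_chars=12)
# returns ["hello world", "foo"]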

# Generate embeddings for document chunks and attach them to the nodes so
# the index can reuse them instead of embedding the same text twice
all_nodes = []

for doc in documents:
    text_chunks = split_text(doc.text)
    for chunk in text_chunks:
        embedding = embed_model.get_text_embedding(chunk)
        all_nodes.append(TextNode(text=chunk, embedding=embedding))
print("Embeddings generated")

# Create and persist the index; nodes that already carry embeddings are
# indexed as-is, so the vectors computed above are used directly
index = VectorStoreIndex(nodes=all_nodes, embed_model=embed_model)
index.set_index_id("vector_index")
index.storage_context.persist("./storage")
print("Index created")

# Load index from storage
storage_context = StorageContext.from_defaults(persist_dir="storage")
index = load_index_from_storage(storage_context, index_id="vector_index")
print("Index loaded")

# Query the index and use output as LLM context
def query_model_with_context(question):
    # Retrieve the top three most similar chunks from the index
    retriever = index.as_retriever(similarity_top_k=3)
    nodes = retriever.retrieve(question)

    # Log the retrieved nodes for debugging
    for node in nodes:
        print(node)

    # Rerank the retrieved nodes by relevance to the query
    ranked_documents = reranker.compress_documents(
        query=question,
        documents=[LangDocument(page_content=node.text) for node in nodes]
    )

    # Print the most relevant node
    print(f"Most relevant node: {ranked_documents[0].page_content}")

    # Use the most relevant node as context
    context = ranked_documents[0].page_content

    # Construct the messages using the ChatMessage class
    messages = [
        ChatMessage(role=MessageRole.SYSTEM, content=context),
        ChatMessage(role=MessageRole.USER, content=str(question))
    ]

    completion = client.chat(messages)

    # client.chat returns a ChatResponse; the reply text lives on
    # completion.message.content, so no string parsing is required
    response_text = (completion.message.content or "").strip()

    return response_text
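
# Sketch (not in the original flow): for incremental output, LlamaIndex LLMs
# also expose stream_chat, which yields partial responses:
#
#   for part in client.stream_chat(messages):
#       print(part.delta, end="", flush=True)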


# Streamlit UI
st.title("Chat with this Rerank RAG App")
question = st.text_input("Enter a relevant question to chat with the attached PhilDataset PDF file:")

if st.button("Submit"):
    if question:
        st.write("**RAG Response:**")
        response = query_model_with_context(question)
        st.write(response)
    else:
        st.warning("Please enter a question.")
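
# To launch the app (assuming this file is saved as app.py):
#   streamlit run app.py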