# HyDE + ReRank RAG for Freight Rates

# Import modules and classes
from llama_index.core import VectorStoreIndex, StorageContext, load_index_from_storage
from llama_index.core.indices.query.query_transform import HyDEQueryTransform
from llama_index.core.query_engine import TransformQueryEngine
from langchain_nvidia_ai_endpoints import NVIDIARerank
from langchain_core.documents import Document as LangDocument
from llama_index.core.llms import ChatMessage, MessageRole
from llama_index.llms.nvidia import NVIDIA
from llama_index.embeddings.nvidia import NVIDIAEmbedding
from llama_index.core.schema import TextNode
from llama_index.core import Settings
from llama_parse import LlamaParse
import streamlit as st
import os
# Read API keys from environment variables
nvidia_api_key = os.getenv("NVIDIA_KEY")
llamaparse_api_key = os.getenv("PARSE_KEY")
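# Both keys must be set in the shell before launching, e.g.:
#   export NVIDIA_KEY=<your NVIDIA API key>
#   export PARSE_KEY=<your LlamaParse API key>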

# Initialize the NVIDIA LLM, embedding model, and reranker
client = NVIDIA(
    model="meta/llama-3.1-8b-instruct",
    api_key=nvidia_api_key,
    temperature=0.2,
    top_p=0.7,
    max_tokens=1024
)
embed_model = NVIDIAEmbedding(
    model="nvidia/nv-embedqa-e5-v5", 
    api_key=nvidia_api_key, 
    truncate="NONE"
)
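# truncate="NONE" makes over-length inputs fail loudly instead of being silently
# cut; the character-based chunker below keeps chunks within this model's input
# limit (512 tokens for this embedding model).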


reranker = NVIDIARerank(
    model="nvidia/nv-rerankqa-mistral-4b-v3",
    api_key=nvidia_api_key,
)
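# The reranker scores each (query, passage) pair jointly with a cross-encoder,
# which gives a sharper relevance ordering than embedding similarity alone.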

# Set the NVIDIA models globally
Settings.embed_model = embed_model
Settings.llm = client

# Configure LlamaParse to convert the PDF into markdown
parser = LlamaParse(
    api_key=llamaparse_api_key,
    result_type="markdown",
    verbose=True
)
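# result_type="markdown" keeps headings and tables from the PDF intact, which
# matters for tabular freight-rate data.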

# Get the absolute path of the script's directory
script_dir = os.path.dirname(os.path.abspath(__file__))
data_file = os.path.join(script_dir, "FreightsDataset.pdf")

# Parse the PDF document from the resolved path
documents = parser.load_data(data_file)
print("Document Parsed")

# Split parsed text into chunks small enough for the embedding model.
# The budget is counted in characters, a conservative proxy for tokens.
def split_text(text, max_chars=512):
    words = text.split()
    chunks = []
    current_chunk = []
    current_length = 0

    for word in words:
        word_length = len(word)
        if current_chunk and current_length + word_length + 1 > max_chars:
            chunks.append(" ".join(current_chunk))
            current_chunk = [word]
            current_length = word_length + 1
        else:
            current_chunk.append(word)
            current_length += word_length + 1

    if current_chunk:
        chunks.append(" ".join(current_chunk))

    return chunks
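
# Example: split_text("a b c", max_chars=4) -> ["a b", "c"]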

# Embed each chunk and store it on a TextNode so the index can reuse
# the precomputed embeddings instead of re-embedding
all_nodes = []

for doc in documents:
    text_chunks = split_text(doc.text)
    for chunk in text_chunks:
        embedding = embed_model.get_text_embedding(chunk)
        all_nodes.append(TextNode(text=chunk, embedding=embedding))
print("Embeddings generated")

# Create and persist the index from the pre-embedded nodes
index = VectorStoreIndex(nodes=all_nodes, embed_model=embed_model)
index.set_index_id("vector_index")
index.storage_context.persist("./storage")
print("Index created")

# Load index from storage
storage_context = StorageContext.from_defaults(persist_dir="storage")
index = load_index_from_storage(storage_context, index_id="vector_index")
print("Index loaded")

# Initialize HyDEQueryTransform and TransformQueryEngine
hyde = HyDEQueryTransform(include_original=True)
query_engine = index.as_query_engine()
hyde_query_engine = TransformQueryEngine(query_engine, hyde)
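# HyDE (Hypothetical Document Embeddings): the LLM drafts a hypothetical answer
# to the question, and that draft is embedded for retrieval alongside the
# original query (include_original=True), which tends to match answer-shaped
# passages better than the bare question does.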

# Answer a question: HyDE-expand it, retrieve, rerank, then prompt the LLM
def query_model_with_context(question):
    # Run the HyDE-transformed query engine; its response text serves as an
    # answer-shaped query for the retrieval step below
    hyde_response = hyde_query_engine.query(question)
    print(f"HyDE Response: {hyde_response}")

    # query() returns a Response object; tolerate a plain string just in case
    if isinstance(hyde_response, str):
        hyde_query = hyde_response
    else:
        hyde_query = hyde_response.response

    # Retrieve candidate chunks using the HyDE-expanded query
    retriever = index.as_retriever(similarity_top_k=3)
    nodes = retriever.retrieve(hyde_query)

    for node in nodes:
        print(node)

    # Rerank the retrieved nodes
    ranked_documents = reranker.compress_documents(
        query=question,
        documents=[LangDocument(page_content=node.text) for node in nodes]
    )
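    # compress_documents returns the documents ordered by relevance (best
    # first), so index 0 below is the reranker's top pick.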

    # Print the most relevant document after reranking
    print(f"Most relevant node: {ranked_documents[0].page_content}")

    # Use the most relevant node as context
    context = ranked_documents[0].page_content

    # Send context and question to the client (NVIDIA Llama 3.1 8B model)
    # Construct the messages using the ChatMessage class
    messages = [
        ChatMessage(role=MessageRole.SYSTEM, content=context),
        ChatMessage(role=MessageRole.USER, content=str(question))
    ]
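    # Passing the retrieved chunk as the bare system message is a minimal
    # prompting scheme; a fuller template would wrap it in explicit
    # instructions (e.g. "Answer using only the context below").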

    # Call the chat method to get the response
    completion = client.chat(messages)

    # client.chat returns a ChatResponse; the generated text is on message.content
    response_text = completion.message.content or ""

    return response_text.strip()


# Streamlit UI
st.title("Chat with the HyDE + Rerank RAG Freight App")
question = st.text_input("Ask a question about the FreightsDataset PDF:")

if st.button("Submit"):
    if question:
        st.write("**RAG Response:**")
        response = query_model_with_context(question)
        st.write(response)
    else:
        st.warning("Please enter a question.")
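
# To launch the app (assuming this file is saved as app.py):
#   streamlit run app.py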