# HyDE + ReRank RAG for Freight Rates
# Import modules and classes
from llama_index.core import VectorStoreIndex, StorageContext, load_index_from_storage
from llama_index.core.indices.query.query_transform import HyDEQueryTransform
from llama_index.core.query_engine import TransformQueryEngine
from langchain_nvidia_ai_endpoints import NVIDIARerank
from langchain_core.documents import Document as LangDocument
from llama_index.core.llms import ChatMessage, MessageRole
from llama_index.llms.nvidia import NVIDIA
from llama_index.embeddings.nvidia import NVIDIAEmbedding
from llama_index.core import Document as LlamaDocument
from llama_index.core import Settings
from llama_parse import LlamaParse
import streamlit as st
import os
# Read API keys from environment variables
nvidia_api_key = os.getenv("NVIDIA_KEY")
llamaparse_api_key = os.getenv("PARSE_KEY")
# Initialize the NVIDIA LLM, NVIDIARerank, and NVIDIAEmbedding clients
client = NVIDIA(
    model="meta/llama-3.1-8b-instruct",
    api_key=nvidia_api_key,
    temperature=0.2,
    top_p=0.7,
    max_tokens=1024
)
embed_model = NVIDIAEmbedding(
    model="nvidia/nv-embedqa-e5-v5",
    api_key=nvidia_api_key,
    truncate="NONE"
)
reranker = NVIDIARerank(
    model="nvidia/nv-rerankqa-mistral-4b-v3",
    api_key=nvidia_api_key,
)
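# NVIDIARerank also accepts a top_n argument to cap how many reranked passages
# it returns; this script relies on the library default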
# Set the NVIDIA models globally
Settings.embed_model = embed_model
Settings.llm = client
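# With Settings populated, the LlamaIndex components built later (index,
# retriever, query engine) fall back to these models automatically whenever
# a model is not passed explicitly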
# Initialize the LlamaParse parser for the local PDF document
parser = LlamaParse(
    api_key=llamaparse_api_key,
    result_type="markdown",
    verbose=True
)
# Get the absolute path of the script's directory
script_dir = os.path.dirname(os.path.abspath(__file__))
data_file = os.path.join(script_dir, "FreightsDataset.pdf")
# Load and parse the PDF document from the resolved absolute path
documents = parser.load_data(data_file)
print("Document Parsed")
# Split parsed text into chunks small enough for the embedding model.
# Chunk size is measured in characters as a conservative proxy for tokens,
# keeping each chunk comfortably under the embedding model's 512-token limit.
def split_text(text, max_tokens=512):
    words = text.split()
    chunks = []
    current_chunk = []
    current_length = 0
    for word in words:
        word_length = len(word)
        # Start a new chunk once this word (plus a separating space) would exceed the budget
        if current_length + word_length + 1 > max_tokens:
            chunks.append(" ".join(current_chunk))
            current_chunk = [word]
            current_length = word_length + 1
        else:
            current_chunk.append(word)
            current_length += word_length + 1
    if current_chunk:
        chunks.append(" ".join(current_chunk))
    return chunks
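# Illustrative example of the chunking behavior (toy values, not from the dataset):
#   split_text("alpha beta gamma", max_tokens=12)
#   -> ["alpha beta", "gamma"]   # adding "gamma" would push the first chunk past 12 characters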
# Chunk each parsed document for indexing; VectorStoreIndex embeds the chunks
# itself via the configured embed_model, so no separate embedding pass is
# needed (from_documents takes no precomputed-embeddings argument)
all_documents = []
for doc in documents:
    for chunk in split_text(doc.text):
        all_documents.append(LlamaDocument(text=chunk))
print("Documents chunked")
# Create and persist index with NVIDIAEmbedding
index = VectorStoreIndex.from_documents(all_documents, embed_model=embed_model)
index.set_index_id("vector_index")
index.storage_context.persist("./storage")
print("Index created")
# Load index from storage
storage_context = StorageContext.from_defaults(persist_dir="storage")
index = load_index_from_storage(storage_context, index_id="vector_index")
print("Index loaded")
# Initialize HyDEQueryTransform and TransformQueryEngine
hyde = HyDEQueryTransform(include_original=True)
query_engine = index.as_query_engine()
hyde_query_engine = TransformQueryEngine(query_engine, hyde)
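# HyDE (Hypothetical Document Embeddings) asks the LLM to draft a hypothetical
# answer to the query and retrieves against that draft, which often matches
# answer-shaped passages better than a short question would. The transform can
# also be inspected on its own (a sketch; the query string is a made-up example):
#   bundle = hyde("What factors drive container freight rates?")
#   print(bundle.embedding_strs)  # the generated hypothetical document(s)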
# Query the index with HyDE and use output as LLM context
def query_model_with_context(question):
    # Generate a hypothetical document using HyDE
    hyde_response = hyde_query_engine.query(question)
    print(f"HyDE Response: {hyde_response}")
    # The query engine normally returns a Response object; handle a plain string defensively
    if isinstance(hyde_response, str):
        hyde_query = hyde_response
    else:
        hyde_query = hyde_response.response
    # Use the hypothetical document to retrieve relevant documents
    retriever = index.as_retriever(similarity_top_k=3)
    nodes = retriever.retrieve(hyde_query)
    for node in nodes:
        print(node)
    # Rerank the retrieved nodes against the original question
    ranked_documents = reranker.compress_documents(
        query=question,
        documents=[LangDocument(page_content=node.text) for node in nodes]
    )
    # Print the most relevant node
    print(f"Most relevant node: {ranked_documents[0].page_content}")
    # Use the most relevant node as context
    context = ranked_documents[0].page_content
    # Send context and question to the client (NVIDIA Llama 3.1 8B model)
    # Construct the messages using the ChatMessage class
    messages = [
        ChatMessage(role=MessageRole.SYSTEM, content=context),
        ChatMessage(role=MessageRole.USER, content=str(question))
    ]
    # Call the chat method to get the response
    completion = client.chat(messages)
    # client.chat returns a ChatResponse; extract the message content and fall
    # back to str() for any unexpected return type
    if hasattr(completion, "message") and completion.message is not None:
        response_text = completion.message.content or ""
    else:
        response_text = str(completion)
    response_text = response_text.replace("assistant:", "Final Response:").strip()
    return response_text
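# The list returned by compress_documents is ordered by relevance (each
# LangDocument carries a relevance score in its metadata), which is why
# query_model_with_context takes ranked_documents[0] as its context.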
# Streamlit UI
st.title("Chat with HyDE and Rerank RAG Freights App")
question = st.text_input("Ask a question about the FreightsDataset PDF:")
if st.button("Submit"):
if question:
st.write("**RAG Response:**")
response = query_model_with_context(question)
st.write(response)
else:
st.warning("Please enter a question.")