# hyderag / app.py
# Ronoh4
# Add app 1 file1
# 74cff61
# HyDE + ReRank RAG for Freights Rates
# Import modules and classes
from llama_index.core import VectorStoreIndex, StorageContext, load_index_from_storage
from llama_index.core.indices.query.query_transform import HyDEQueryTransform
from llama_index.core.query_engine import TransformQueryEngine
from langchain_nvidia_ai_endpoints import NVIDIARerank
from langchain_core.documents import Document as LangDocument
from llama_index.core.llms import ChatMessage, MessageRole
from llama_index.llms.nvidia import NVIDIA
from llama_index.embeddings.nvidia import NVIDIAEmbedding
from llama_index.core import Document as LlamaDocument
from llama_index.core import Settings
from llama_parse import LlamaParse
import streamlit as st
import os
# Read the API keys from the environment (None when a variable is unset)
nvidia_api_key = os.environ.get("NVIDIA_KEY")
llamaparse_api_key = os.environ.get("PARSE_KEY")

# Chat LLM used both by LlamaIndex (via Settings) and for the final answer
client = NVIDIA(
    model="meta/llama-3.1-8b-instruct",
    api_key=nvidia_api_key,
    temperature=0.2,
    top_p=0.7,
    max_tokens=1024,
)

# Embedding model; truncate="NONE" is passed through to the NVIDIA endpoint
embed_model = NVIDIAEmbedding(
    model="nvidia/nv-embedqa-e5-v5",
    api_key=nvidia_api_key,
    truncate="NONE",
)

# Reranker applied to retrieved passages before answering
reranker = NVIDIARerank(
    model="nvidia/nv-rerankqa-mistral-4b-v3",
    api_key=nvidia_api_key,
)

# Register the NVIDIA models as the LlamaIndex global defaults
Settings.llm = client
Settings.embed_model = embed_model

# LlamaParse client that returns the PDF content as markdown
parser = LlamaParse(
    api_key=llamaparse_api_key,
    result_type="markdown",
    verbose=True,
)

# The dataset sits next to this script; build an absolute path to it so the
# app works regardless of the current working directory
script_dir = os.path.dirname(os.path.abspath(__file__))
data_file = os.path.join(script_dir, "FreightsDataset.pdf")

# Parse the PDF into documents
documents = parser.load_data(data_file)
print("Document Parsed")
# Split parsed text into chunks for embedding model
def split_text(text, max_tokens=512):
    """Split *text* into whitespace-normalised chunks of bounded size.

    Words are packed greedily, in order, into chunks joined by single spaces.

    Note: despite its name, ``max_tokens`` is a *character* budget, not a
    model-token count. Each word is charged ``len(word) + 1`` (the word plus
    one separator), and a chunk is closed as soon as adding the next word
    would push that running total above ``max_tokens``.

    Args:
        text: The text to split. Any run of whitespace separates words.
        max_tokens: Character budget per chunk (see note above).

    Returns:
        A list of non-empty chunk strings. The list is empty when ``text``
        contains no words. A single word longer than the budget is never
        split; it is emitted as a chunk of its own.
    """
    chunks = []
    current_chunk = []
    current_length = 0
    for word in text.split():
        cost = len(word) + 1
        # Only close a chunk that actually holds words; otherwise an
        # oversized first word would produce a spurious empty chunk.
        if current_chunk and current_length + cost > max_tokens:
            chunks.append(" ".join(current_chunk))
            current_chunk = []
            current_length = 0
        current_chunk.append(word)
        current_length += cost
    if current_chunk:
        chunks.append(" ".join(current_chunk))
    return chunks
# Generate embeddings for document chunks
# NOTE(review): Streamlit re-executes the script on every interaction, so this
# parse/embed/index work presumably repeats on each rerun — consider caching.
all_embeddings = []
all_documents = []
for doc in documents:
    for chunk in split_text(doc.text):
        if not chunk:
            # Never send an empty string to the embedding endpoint
            continue
        embedding = embed_model.get_text_embedding(chunk)
        all_embeddings.append(embedding)
        # Attach the vector to the document itself. The original passed
        # `embeddings=` to from_documents, which is not one of its parameters,
        # so the precomputed vectors were presumably ignored and every chunk
        # embedded a second time — TODO confirm the index reuses this field.
        all_documents.append(LlamaDocument(text=chunk, embedding=embedding))
print("Embeddings generated")
# Create and persist index with NVIDIAEmbeddings
index = VectorStoreIndex.from_documents(all_documents, embed_model=embed_model)
index.set_index_id("vector_index")
index.storage_context.persist("./storage")
print("Index created")
# Reload the index from the directory it was just persisted to
storage_context = StorageContext.from_defaults(persist_dir="storage")
index = load_index_from_storage(storage_context, index_id="vector_index")
print("Index loaded")
# Wrap the index's default query engine with a HyDE query transform.
# include_original=True is passed so the transform keeps the original
# question in addition to the generated hypothetical document.
query_engine = index.as_query_engine()
hyde = HyDEQueryTransform(include_original=True)
hyde_query_engine = TransformQueryEngine(query_engine, query_transform=hyde)
# Query the index with HyDE and use output as LLM context
def query_model_with_context(question):
    """Answer *question* using HyDE retrieval, reranking and the chat model.

    Steps:
        1. Run the question through ``hyde_query_engine`` and take the text
           of its response.
        2. Use that text to retrieve the top 3 nodes from ``index``.
        3. Rerank those nodes against the original question with
           ``reranker`` and keep the top-ranked passage.
        4. Send that passage as the system message and the question as the
           user message to ``client``.

    Args:
        question: The user's question. It is passed through ``str()`` before
            being sent as the user message.

    Returns:
        The model's reply as a string, with a leading ``assistant:`` role
        label rewritten to ``Final Response:``. If retrieval or reranking
        yields no passages, a short notice is returned instead of raising
        ``IndexError``.
    """
    no_context_message = "No relevant context was found for this question."

    # Run the HyDE-transformed query and take its response text
    hyde_response = hyde_query_engine.query(question)
    print(f"HyDE Response: {hyde_response}")
    if isinstance(hyde_response, str):
        hyde_query = hyde_response
    else:
        hyde_query = hyde_response.response

    # Use the HyDE output to retrieve relevant documents
    retriever = index.as_retriever(similarity_top_k=3)
    nodes = retriever.retrieve(hyde_query)
    for node in nodes:
        print(node)
    if not nodes:
        # Nothing to rerank; indexing an empty result below would crash
        return no_context_message

    # Rerank the retrieved nodes against the user's original question
    ranked_documents = reranker.compress_documents(
        query=question,
        documents=[LangDocument(page_content=node.text) for node in nodes],
    )
    if not ranked_documents:
        return no_context_message

    # Use the most relevant passage as context
    context = ranked_documents[0].page_content
    print(f"Most relevant node: {context}")

    # Send context and question to the client (NVIDIA Llama 3.1 8B model)
    messages = [
        ChatMessage(role=MessageRole.SYSTEM, content=context),
        ChatMessage(role=MessageRole.USER, content=str(question)),
    ]
    completion = client.chat(messages)

    # Normalise the completion to a string, whatever shape it arrives in
    if isinstance(completion, (list, tuple)):
        # str() each element so non-string items do not break the join
        response_text = " ".join(str(part) for part in completion)
    elif isinstance(completion, str):
        response_text = completion
    else:
        response_text = str(completion)

    # Relabel only the leading role prefix. A blanket str.replace would also
    # rewrite any "assistant:" that appears inside the answer itself.
    response_text = response_text.strip()
    role_prefix = "assistant:"
    if response_text.startswith(role_prefix):
        response_text = "Final Response:" + response_text[len(role_prefix):]
    return response_text.strip()
# Streamlit UI: one text box plus a submit button
st.title("Chat with HyDE and Rerank RAG Freights App")
question = st.text_input("Enter a relevant question to chat with the attached FreightsDataset file:")
submitted = st.button("Submit")
if submitted and not question:
    # Button pressed with an empty text box
    st.warning("Please enter a question.")
elif submitted:
    # Show the heading first, then the answer once the pipeline returns
    st.write("**RAG Response:**")
    response = query_model_with_context(question)
    st.write(response)