# Import modules and classes
from llama_index.core import VectorStoreIndex, StorageContext, load_index_from_storage
from llama_index.llms.nvidia import NVIDIA
from llama_index.embeddings.nvidia import NVIDIAEmbedding
from llama_index.core.llms import ChatMessage, MessageRole
from langchain_nvidia_ai_endpoints import NVIDIARerank
from langchain_core.documents import Document as LangDocument
from llama_index.core import Document as LlamaDocument
from llama_index.core import Settings
from llama_parse import LlamaParse
import streamlit as st
import os
# Read API keys from environment variables
nvidia_api_key = os.getenv("NVIDIA_KEY")
llamaparse_api_key = os.getenv("PARSE_KEY")
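# Both keys are assumed to be exported in the shell before launch, e.g.:
#   export NVIDIA_KEY="nvapi-..."
#   export PARSE_KEY="llx-..."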
# Initialize the NVIDIA LLM, embedding model, and reranker
client = NVIDIA(
    model="meta/llama-3.1-8b-instruct",
    api_key=nvidia_api_key,
    temperature=0.2,
    top_p=0.7,
    max_tokens=1024
)
embed_model = NVIDIAEmbedding(
    model="nvidia/nv-embedqa-e5-v5",
    api_key=nvidia_api_key,
    truncate="NONE"
)
reranker = NVIDIARerank(
    model="nvidia/nv-rerankqa-mistral-4b-v3",
    api_key=nvidia_api_key,
)
# Set the NVIDIA models globally
Settings.embed_model = embed_model
Settings.llm = client
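# Settings is LlamaIndex's global configuration: the indexing and query calls
# below pick these models up automatically without passing them explicitly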
# Parse the local PDF document
parser = LlamaParse(
    api_key=llamaparse_api_key,
    result_type="markdown",
    verbose=True
)
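# With result_type="markdown", LlamaParse returns Document objects whose .text
# holds the parsed page content as markdown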
# Get the absolute path of the script's directory
script_dir = os.path.dirname(os.path.abspath(__file__))
data_file = os.path.join(script_dir, "PhilDataset.pdf")
# Load the PDF document using the constructed absolute path
documents = parser.load_data(data_file)
print("Document Parsed")
# Split parsed text into chunks for embedding model
def split_text(text, max_tokens=512):
    """Greedily pack words into chunks; note the budget counts characters
    (word lengths plus separating spaces), not true model tokens."""
    words = text.split()
    chunks = []
    current_chunk = []
    current_length = 0
    for word in words:
        word_length = len(word)
        if current_length + word_length + 1 > max_tokens:
            chunks.append(" ".join(current_chunk))
            current_chunk = [word]
            current_length = word_length + 1
        else:
            current_chunk.append(word)
            current_length += word_length + 1
    if current_chunk:
        chunks.append(" ".join(current_chunk))
    return chunks
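# Example: split_text("the quick brown fox jumps", max_tokens=12) yields
# ["the quick", "brown fox", "jumps"], since each chunk's character count
# (words plus separators) must stay within the limit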
# Generate embeddings for the document chunks (illustrative; the index built
# below re-embeds each chunk through the configured embedding model)
all_embeddings = []
all_documents = []
for doc in documents:
    text_chunks = split_text(doc.text)
    for chunk in text_chunks:
        embedding = embed_model.get_text_embedding(chunk)
        all_embeddings.append(embedding)
        all_documents.append(LlamaDocument(text=chunk))
print("Embeddings generated")
# Create and persist the index; from_documents embeds the documents with the
# configured NVIDIA embedding model
index = VectorStoreIndex.from_documents(all_documents, embed_model=embed_model)
index.set_index_id("vector_index")
index.storage_context.persist("./storage")
print("Index created")
# Load index from storage
storage_context = StorageContext.from_defaults(persist_dir="storage")
index = load_index_from_storage(storage_context, index_id="vector_index")
print("Index loaded")
# Query the index and use output as LLM context
def query_model_with_context(question):
    # Retrieve the top 3 most similar chunks from the index
    retriever = index.as_retriever(similarity_top_k=3)
    nodes = retriever.retrieve(question)
    for node in nodes:
        print(node)
    # Rerank the retrieved nodes
    ranked_documents = reranker.compress_documents(
        query=question,
        documents=[LangDocument(page_content=node.text) for node in nodes]
    )
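    # compress_documents returns the documents ordered by relevance score,
    # most relevant first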
    # Print the most relevant node
    print(f"Most relevant node: {ranked_documents[0].page_content}")
    # Use the most relevant node as context
    context = ranked_documents[0].page_content
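    # Passing the context as the system message grounds the model's answer in
    # the retrieved document text rather than its pretraining alone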
    # Construct the messages using the ChatMessage class
    messages = [
        ChatMessage(role=MessageRole.SYSTEM, content=context),
        ChatMessage(role=MessageRole.USER, content=str(question))
    ]
    completion = client.chat(messages)
    # Normalize the response to a string; depending on the client, chat() may
    # return a ChatResponse object, a plain string, or a sequence
    response_text = ""
    if isinstance(completion, (list, tuple)):
        # Join the elements if the response arrives as a list or tuple
        response_text = ' '.join(completion)
    elif isinstance(completion, str):
        # Use the string directly
        response_text = completion
    else:
        # Fall back to string conversion (e.g. for a ChatResponse object)
        response_text = str(completion)
    response_text = response_text.replace("assistant:", "Final Response:").strip()
    return response_text
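# Example usage outside the Streamlit UI (hypothetical question for the
# PhilDataset PDF):
#   answer = query_model_with_context("What topics does the dataset cover?")
#   print(answer)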
# Streamlit UI
st.title("Chat with this Rerank RAG App")
question = st.text_input("Enter a relevant question to chat with the attached PhilDataset PDF file:")
if st.button("Submit"):
if question:
st.write("**RAG Response:**")
response = query_model_with_context(question)
st.write(response)
else:
st.warning("Please enter a question.")
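# Launch the app with Streamlit (assuming this file is saved as app.py):
#   streamlit run app.py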