|
|
|
|
|
|
|
|
|
|
|
from llama_index.core import VectorStoreIndex, StorageContext, load_index_from_storage |
|
from llama_index.core.indices.query.query_transform import HyDEQueryTransform |
|
from llama_index.core.query_engine import TransformQueryEngine |
|
from langchain_nvidia_ai_endpoints import NVIDIARerank |
|
from langchain_core.documents import Document as LangDocument |
|
from llama_index.core.llms import ChatMessage, MessageRole |
|
from llama_index.llms.nvidia import NVIDIA |
|
from llama_index.embeddings.nvidia import NVIDIAEmbedding |
|
from llama_index.core import Document as LlamaDocument |
|
from llama_index.core import Settings |
|
from llama_parse import LlamaParse |
|
import streamlit as st |
|
import os |
|
|
|
# API credentials are read from the environment; each is None when its
# variable is not set.
nvidia_api_key = os.environ.get("NVIDIA_KEY")
llamaparse_api_key = os.environ.get("PARSE_KEY")
|
|
|
|
|
# Chat model used for the final answer. It is also registered as
# Settings.llm below, which LlamaIndex components (e.g. the query engine and
# HyDE transform) presumably fall back to when no LLM is passed explicitly.
client = NVIDIA(
    api_key=nvidia_api_key,
    model="meta/llama-3.1-8b-instruct",
    max_tokens=1024,
    temperature=0.2,
    top_p=0.7,
)

# Embedding model for document chunks and queries.
embed_model = NVIDIAEmbedding(
    api_key=nvidia_api_key,
    model="nvidia/nv-embedqa-e5-v5",
    truncate="NONE",
)

# Reranker that orders retrieved chunks against the user's question; only the
# top-ranked chunk is used as context downstream.
reranker = NVIDIARerank(
    api_key=nvidia_api_key,
    model="nvidia/nv-rerankqa-mistral-4b-v3",
)

# Register the models as LlamaIndex-wide defaults.
Settings.embed_model = embed_model
Settings.llm = client
|
|
|
|
|
# LlamaParse client that converts the PDF into markdown documents.
parser = LlamaParse(
    api_key=llamaparse_api_key,
    result_type="markdown",
    verbose=True,
)

# The dataset lives next to this script, independent of the working directory.
script_dir = os.path.dirname(os.path.abspath(__file__))
data_file = os.path.join(script_dir, "FreightsDataset.pdf")


@st.cache_resource(show_spinner="Parsing FreightsDataset.pdf...")
def _parse_documents(path):
    """Parse the PDF at *path* with LlamaParse, once per server process.

    Streamlit re-executes this whole script on every widget interaction.
    Without caching, each rerun uploaded and re-parsed the PDF through the
    LlamaParse API. The result is cached per *path*; restart the app to pick
    up a modified file.

    Args:
        path: Filesystem path of the PDF to parse.

    Returns:
        The list of documents returned by ``parser.load_data``.
    """
    parsed = parser.load_data(path)
    print("Document Parsed")
    return parsed


documents = _parse_documents(data_file)
|
|
|
|
|
def split_text(text, max_tokens=512):
    """Greedily pack whitespace-separated words into length-bounded chunks.

    Despite its name, ``max_tokens`` is a *character* budget, not a model
    token count. Each word costs ``len(word) + 1``, where the +1 accounts for
    the joining space. A new chunk starts whenever adding the next word would
    push the running cost past ``max_tokens``. A single word longer than the
    budget is never cut; it becomes a chunk of its own.

    Args:
        text: Input text; any run of whitespace counts as one separator.
        max_tokens: Per-chunk character budget (see above).

    Returns:
        A list of non-empty chunk strings (``[]`` for empty or
        whitespace-only text).
    """
    chunks = []
    current_chunk = []
    current_length = 0

    for word in text.split():
        cost = len(word) + 1
        # Flush only when there is something to flush. Previously an
        # over-long *first* word emitted an empty "" chunk, which was then
        # sent to the embedding model.
        if current_chunk and current_length + cost > max_tokens:
            chunks.append(" ".join(current_chunk))
            current_chunk = []
            current_length = 0
        current_chunk.append(word)
        current_length += cost

    if current_chunk:
        chunks.append(" ".join(current_chunk))

    return chunks
|
|
|
|
|
@st.cache_resource(show_spinner="Embedding document chunks...")
def _embed_documents(_documents):
    """Chunk every document and embed each chunk, once per server process.

    Each vector is also attached to its chunk's ``Document.embedding``. The
    original code built a separate ``all_embeddings`` list and passed it to
    ``VectorStoreIndex`` as ``embeddings=``, but that is not a parameter of
    the index, so the vectors were discarded and every chunk was embedded a
    second time. An index built from documents that already carry
    ``embedding`` reuses those vectors.

    The leading underscore on ``_documents`` tells Streamlit not to hash the
    argument, so the cached result lives for the process lifetime.

    Args:
        _documents: Parsed documents exposing a ``.text`` attribute.

    Returns:
        ``(embeddings, chunk_documents)``: parallel lists of vectors and of
        one ``LlamaDocument`` per chunk (with ``embedding`` set).
    """
    embeddings = []
    chunk_documents = []
    for doc in _documents:
        for chunk in split_text(doc.text):
            vector = embed_model.get_text_embedding(chunk)
            embeddings.append(vector)
            chunk_documents.append(LlamaDocument(text=chunk, embedding=vector))
    print("Embeddings generated")
    return embeddings, chunk_documents


all_embeddings, all_documents = _embed_documents(documents)
|
|
|
|
|
# Persist location and index id, shared by the persist and the reload below
# so they cannot drift apart. The directory is relative to the current working
# directory (unlike data_file, which is anchored to script_dir).
_STORAGE_DIR = "./storage"
_INDEX_ID = "vector_index"


@st.cache_resource(show_spinner="Building vector index...")
def _build_index(_documents):
    """Build, persist and reload the vector index, once per server process.

    Streamlit reruns this script on every interaction. Without caching, the
    index was rebuilt (including embedding calls) and re-persisted every
    time. The leading underscore on ``_documents`` tells Streamlit not to
    hash the argument.

    Args:
        _documents: Documents to index; an ``embedding`` already set on a
            document is reused instead of being recomputed.

    Returns:
        The index reloaded from ``_STORAGE_DIR`` under id ``_INDEX_ID``.
    """
    # No ``embeddings=`` kwarg here: VectorStoreIndex has no such parameter
    # and silently ignored it. Precomputed vectors must live on the documents.
    built = VectorStoreIndex.from_documents(_documents, embed_model=embed_model)
    built.set_index_id(_INDEX_ID)
    built.storage_context.persist(_STORAGE_DIR)
    print("Index created")

    # Reload from disk, as the original flow did, so the served index is
    # exactly what was persisted.
    reloaded = load_index_from_storage(
        StorageContext.from_defaults(persist_dir=_STORAGE_DIR),
        index_id=_INDEX_ID,
    )
    print("Index loaded")
    return reloaded


index = _build_index(all_documents)
# Module-level name kept from the original; it is the context the index was
# loaded with.
storage_context = index.storage_context
|
|
|
|
|
# Plain query engine over the vector index.
query_engine = index.as_query_engine()

# HyDE transform: the LLM drafts a hypothetical answer that is used for
# retrieval. With include_original=True the original query is kept alongside it.
hyde = HyDEQueryTransform(include_original=True)

# Wrap the plain engine so every query goes through the HyDE transform first.
hyde_query_engine = TransformQueryEngine(
    query_engine=query_engine,
    query_transform=hyde,
)
|
|
|
|
|
def query_model_with_context(question):
    """Answer *question* with HyDE retrieval, reranking and a final chat call.

    Steps:
      1. Query ``hyde_query_engine`` and use its answer text as the
         retrieval query. The raw question is used if that answer is empty
         or None.
      2. Retrieve the top 3 chunks from ``index`` for that retrieval query.
      3. Rerank those chunks against the original *question* with
         ``reranker`` and keep only the best one as context.
      4. Send the context as the system message and *question* as the user
         message to ``client.chat`` and return the reply text.

    Args:
        question: The user's question.

    Returns:
        ``"Final Response: <answer>"``, or a ``"Final Response: ..."`` notice
        when no context could be retrieved. Previously the empty case raised
        IndexError.
    """
    no_context_reply = (
        "Final Response: No relevant information was found in the "
        "FreightsDataset file for this question."
    )

    # NOTE(review): this runs a full retrieve + synthesize query, and its
    # *answer* becomes the retrieval query below. That costs an extra LLM
    # call per question; kept as designed.
    hyde_response = hyde_query_engine.query(question)
    print(f"HyDE Response: {hyde_response}")

    if isinstance(hyde_response, str):
        hyde_query = hyde_response
    else:
        hyde_query = getattr(hyde_response, "response", None)
    if not hyde_query:
        # A None/empty answer would break retrieval; fall back to the question.
        hyde_query = question

    retriever = index.as_retriever(similarity_top_k=3)
    nodes = retriever.retrieve(hyde_query)
    for node in nodes:
        print(node)
    if not nodes:
        return no_context_reply

    # Rerank against the user's original wording, not the HyDE text.
    ranked_documents = reranker.compress_documents(
        query=question,
        documents=[LangDocument(page_content=node.text) for node in nodes],
    )
    if not ranked_documents:
        return no_context_reply

    context = ranked_documents[0].page_content
    print(f"Most relevant node: {context}")

    messages = [
        ChatMessage(role=MessageRole.SYSTEM, content=context),
        ChatMessage(role=MessageRole.USER, content=str(question)),
    ]
    completion = client.chat(messages)

    # Read the reply text directly. The old approach stringified the response
    # ("assistant: ...") and then ran a global replace, which also rewrote
    # any "assistant:" appearing inside the model's answer.
    message = getattr(completion, "message", None)
    if message is not None:
        answer = message.content or ""
    else:
        # Unexpected return type: fall back to its string form.
        answer = str(completion)
    return f"Final Response: {answer.strip()}".strip()
|
|
|
|
|
|
|
st.title("Chat with HyDE and Rerank RAG Freights App")

question = st.text_input("Enter a relevant question to chat with the attached FreightsDataset file:")

if st.button("Submit"):
    # Treat whitespace-only input as empty so the paid retrieval/LLM pipeline
    # is not run on a blank question.
    cleaned_question = (question or "").strip()
    if cleaned_question:
        st.write("**RAG Response:**")
        # The pipeline makes several remote calls; show progress meanwhile.
        with st.spinner("Retrieving context and generating an answer..."):
            response = query_model_with_context(cleaned_question)
        st.write(response)
    else:
        st.warning("Please enter a question.")
|
|
|
|