Ronoh4
committed on
Commit
·
83861a8
1
Parent(s):
3f7f1c2
Changed title
Browse files
app.py
CHANGED
@@ -1,18 +1,21 @@
|
|
|
|
|
|
|
|
|
|
1 |
# Import modules and classes
|
2 |
from llama_index.core import VectorStoreIndex, StorageContext, load_index_from_storage
|
3 |
-
from langchain_nvidia_ai_endpoints import NVIDIARerank
|
4 |
from llama_index.core.indices.query.query_transform import HyDEQueryTransform
|
|
|
|
|
|
|
5 |
from llama_index.core.llms import ChatMessage, MessageRole
|
6 |
from llama_index.llms.nvidia import NVIDIA
|
7 |
from llama_index.embeddings.nvidia import NVIDIAEmbedding
|
8 |
-
from llama_index.core.query_engine import TransformQueryEngine
|
9 |
-
from langchain_core.documents import Document as LangDocument
|
10 |
from llama_index.core import Document as LlamaDocument
|
11 |
from llama_index.core import Settings
|
12 |
from llama_parse import LlamaParse
|
13 |
import streamlit as st
|
14 |
import os
|
15 |
-
|
16 |
# Read API keys from environment variables
|
17 |
nvidia_api_key = os.getenv("NVIDIA_KEY")
|
18 |
llamaparse_api_key = os.getenv("PARSE_KEY")
|
@@ -33,10 +36,12 @@ embed_model = NVIDIAEmbedding(
|
|
33 |
)
|
34 |
|
35 |
reranker = NVIDIARerank(
|
36 |
-
|
37 |
-
|
|
|
38 |
)
|
39 |
|
|
|
40 |
# Set the NVIDIA models globally
|
41 |
Settings.embed_model = embed_model
|
42 |
Settings.llm = client
|
@@ -121,10 +126,7 @@ def query_model_with_context(question):
|
|
121 |
retriever = index.as_retriever(similarity_top_k=3)
|
122 |
nodes = retriever.retrieve(hyde_query)
|
123 |
|
124 |
-
|
125 |
-
print(node)
|
126 |
-
|
127 |
-
# Rerank the retrieved documents
|
128 |
ranked_documents = reranker.compress_documents(
|
129 |
query=question,
|
130 |
documents=[LangDocument(page_content=node.text) for node in nodes]
|
@@ -146,8 +148,6 @@ def query_model_with_context(question):
|
|
146 |
# Call the chat method to get the response
|
147 |
completion = client.chat(messages)
|
148 |
|
149 |
-
print(completion)
|
150 |
-
|
151 |
# Process response - assuming completion is a single string or a tuple containing a string
|
152 |
response_text = ""
|
153 |
|
@@ -167,7 +167,7 @@ def query_model_with_context(question):
|
|
167 |
|
168 |
|
169 |
# Streamlit UI
|
170 |
-
st.title("Chat with HyDE + Rerank Freights
|
171 |
question = st.text_input("Enter a relevant question to chat with the attached FreightsDataset file:")
|
172 |
|
173 |
if st.button("Submit"):
|
@@ -176,4 +176,5 @@ if st.button("Submit"):
|
|
176 |
response = query_model_with_context(question)
|
177 |
st.write(response)
|
178 |
else:
|
179 |
-
st.warning("Please enter a question.")
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
# HyDE + ReRank RAG for Freights Rates
|
4 |
+
|
5 |
# Import modules and classes
|
6 |
from llama_index.core import VectorStoreIndex, StorageContext, load_index_from_storage
|
|
|
7 |
from llama_index.core.indices.query.query_transform import HyDEQueryTransform
|
8 |
+
from llama_index.core.query_engine import TransformQueryEngine
|
9 |
+
from langchain_nvidia_ai_endpoints import NVIDIARerank
|
10 |
+
from langchain_core.documents import Document as LangDocument
|
11 |
from llama_index.core.llms import ChatMessage, MessageRole
|
12 |
from llama_index.llms.nvidia import NVIDIA
|
13 |
from llama_index.embeddings.nvidia import NVIDIAEmbedding
|
|
|
|
|
14 |
from llama_index.core import Document as LlamaDocument
|
15 |
from llama_index.core import Settings
|
16 |
from llama_parse import LlamaParse
|
17 |
import streamlit as st
|
18 |
import os
|
|
|
19 |
# Read API keys from environment variables
|
20 |
nvidia_api_key = os.getenv("NVIDIA_KEY")
|
21 |
llamaparse_api_key = os.getenv("PARSE_KEY")
|
|
|
36 |
)
|
37 |
|
38 |
reranker = NVIDIARerank(
|
39 |
+
model="nvidia/nv_embedqa_e5-v5",
|
40 |
+
api_key=nvidia_api_key,
|
41 |
+
truncate="NONE"
|
42 |
)
|
43 |
|
44 |
+
|
45 |
# Set the NVIDIA models globally
|
46 |
Settings.embed_model = embed_model
|
47 |
Settings.llm = client
|
|
|
126 |
retriever = index.as_retriever(similarity_top_k=3)
|
127 |
nodes = retriever.retrieve(hyde_query)
|
128 |
|
129 |
+
# Rerank the retrieved nodes
|
|
|
|
|
|
|
130 |
ranked_documents = reranker.compress_documents(
|
131 |
query=question,
|
132 |
documents=[LangDocument(page_content=node.text) for node in nodes]
|
|
|
148 |
# Call the chat method to get the response
|
149 |
completion = client.chat(messages)
|
150 |
|
|
|
|
|
151 |
# Process response - assuming completion is a single string or a tuple containing a string
|
152 |
response_text = ""
|
153 |
|
|
|
167 |
|
168 |
|
169 |
# Streamlit UI
|
170 |
+
st.title("Chat with HyDE + Rerank Freights App")
|
171 |
question = st.text_input("Enter a relevant question to chat with the attached FreightsDataset file:")
|
172 |
|
173 |
if st.button("Submit"):
|
|
|
176 |
response = query_model_with_context(question)
|
177 |
st.write(response)
|
178 |
else:
|
179 |
+
st.warning("Please enter a question.")
|
180 |
+
|