Ley_Fill7 committed
Commit · 66ce967 · 1 Parent(s): d522a07
Updated to include llamaindex nvidia integrations

app.py CHANGED
@@ -1,7 +1,10 @@
 # Import modules and classes
 from llama_index.core import VectorStoreIndex, StorageContext, load_index_from_storage
-from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIARerank, NVIDIAEmbeddings
+from langchain_nvidia_ai_endpoints import NVIDIARerank
 from llama_index.core.indices.query.query_transform import HyDEQueryTransform
+from llama_index.core.llms import ChatMessage, MessageRole
+from llama_index.llms.nvidia import NVIDIA
+from llama_index.embeddings.nvidia import NVIDIAEmbedding
 from llama_index.core.embeddings.utils import resolve_embed_model
 from llama_index.core.query_engine import TransformQueryEngine
 from langchain_core.documents import Document as LangDocument
@@ -16,7 +19,7 @@ nvidia_api_key = os.getenv("NVIDIA_KEY")
 llamaparse_api_key = os.getenv("PARSE_KEY")
 
 # Initialize ChatNVIDIA, NVIDIARerank, and NVIDIAEmbeddings
-client = ChatNVIDIA(
+client = NVIDIA(
     model="meta/llama-3.1-8b-instruct",
     api_key=nvidia_api_key,
     temperature=0.2,
@@ -24,16 +27,7 @@ client = ChatNVIDIA(
     max_tokens=1024
 )
 
-
-def custom_resolve_embed_model(embed_model):
-    if isinstance(embed_model, NVIDIAEmbeddings):
-        return embed_model
-    embed_model = resolve_embed_model(embed_model)
-    if hasattr(embed_model, 'callback_manager'):
-        embed_model.callback_manager = Settings.callback_manager
-    return embed_model
-
-embed_model = NVIDIAEmbeddings(
+embed_model = NVIDIAEmbedding(
     model="nvidia/nv-embedqa-e5-v5",
     api_key=nvidia_api_key,
     truncate="NONE"
@@ -45,7 +39,7 @@ reranker = NVIDIARerank(
 )
 
 # Set the NVIDIA models globally
-Settings.embed_model = custom_resolve_embed_model(embed_model)
+Settings.embed_model = embed_model
 Settings.llm = client
 
 # Parse the local PDF document
@@ -55,7 +49,12 @@ parser = LlamaParse(
     verbose=True
 )
 
-documents = parser.load_data("PhilDataset.pdf")
+# Get the absolute path of the script's directory
+script_dir = os.path.dirname(os.path.abspath(__file__))
+data_file = os.path.join(script_dir, "PhilDataset.pdf")
+
+# Load the PDF document using the relative path
+documents = parser.load_data(data_file)
 print("Document Parsed")
 
 # Split parsed text into chunks for embedding model
@@ -87,7 +86,7 @@ all_documents = []
 for doc in documents:
     text_chunks = split_text(doc.text)
     for chunk in text_chunks:
-        embedding = embed_model.embed_query(chunk)
+        embedding = embed_model.get_text_embedding(chunk)
         all_embeddings.append(embedding)
         all_documents.append(LlamaDocument(text=chunk))
 print("Embeddings generated")
@@ -139,18 +138,32 @@ def query_model_with_context(question):
     context = ranked_documents[0].page_content
 
     # Send context and question to the client (NVIDIA Llama 3.1 8B model)
+    # Construct the messages using the ChatMessage class
     messages = [
-        {"role": "system", "content": context},
-        {"role": "user", "content": str(question)}
+        ChatMessage(role=MessageRole.SYSTEM, content=context),
+        ChatMessage(role=MessageRole.USER, content=str(question))
     ]
-
-    completion = client.stream(messages)
-
+
+    # Call the chat method to get the response
+    completion = client.chat(messages)
+
+    print(completion)
+
+    # Process response - assuming completion is a single string or a tuple containing a string
     response_text = ""
-    for chunk in completion:
-        response_text += chunk.content
-
+
+    if isinstance(completion, (list, tuple)):
+        # Join elements of tuple/list if it's in such format
+        response_text = ' '.join(completion)
+    elif isinstance(completion, str):
+        # Directly assign if it's a string
+        response_text = completion
+    else:
+        # Fallback for unexpected types, convert to string
+        response_text = str(completion)
 
+    response_text = response_text.replace("assistant:", "Final Response:").strip()
+
     return response_text
 
 
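The new response handling guards against completion being a string or tuple, but the LlamaIndex NVIDIA client's chat() method returns a ChatResponse object, so in practice the str(completion) fallback is the branch that fires. A minimal sketch of the migrated call, assuming the same model name and NVIDIA_KEY environment variable as app.py (the prompt strings here are placeholders):

import os

from llama_index.core.llms import ChatMessage, MessageRole
from llama_index.llms.nvidia import NVIDIA

# Same client configuration as in the commit
client = NVIDIA(
    model="meta/llama-3.1-8b-instruct",
    api_key=os.getenv("NVIDIA_KEY"),
    temperature=0.2,
    max_tokens=1024,
)

# Placeholder context and question, standing in for the reranked document text
messages = [
    ChatMessage(role=MessageRole.SYSTEM, content="Answer using only the given context."),
    ChatMessage(role=MessageRole.USER, content="What does the document say about X?"),
]

# chat() returns a ChatResponse; str() renders it as "assistant: <text>",
# which is why the commit rewrites the "assistant:" prefix afterwards
completion = client.chat(messages)
print(str(completion))

# The reply text is also available directly, with no prefix to strip
print(completion.message.content)

Reading completion.message.content would make the isinstance branches and the "assistant:" rewrite unnecessary, though the string-based handling in the commit produces the same final text.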
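Likewise, the chunk-embedding loop now calls get_text_embedding() on the LlamaIndex NVIDIAEmbedding model instead of the LangChain embeddings class. A minimal sketch under the same assumptions (model name and NVIDIA_KEY as in app.py; the sample string is a placeholder):

import os

from llama_index.embeddings.nvidia import NVIDIAEmbedding

# Same embedding configuration as in the commit
embed_model = NVIDIAEmbedding(
    model="nvidia/nv-embedqa-e5-v5",
    api_key=os.getenv("NVIDIA_KEY"),
    truncate="NONE",
)

# get_text_embedding() embeds one chunk of text and returns a plain list of floats
vector = embed_model.get_text_embedding("A sample chunk of parsed PDF text.")
print(len(vector))  # embedding dimensionality

Retrieval models like nv-embedqa-e5-v5 embed passages and queries differently, so index chunks with get_text_embedding() and embed questions with get_query_embedding() to keep the two sides of the index consistent.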