# Author: Saif Rehman Nasir
# Last commit: f81204d ("Revert to llama3 8b")
import functools
import logging
import os
from typing import Any, Dict, Optional

import numpy as np
import pandas as pd
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.graphs import Neo4jGraph
from langchain_community.vectorstores import Neo4jVector
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_huggingface import HuggingFaceEndpoint
from neo4j import GraphDatabase, Result
from tqdm import tqdm
from transformers import AutoTokenizer
# Neo4j connection settings and the name of the entity vector index, taken
# from the environment. Each value is None when its variable is unset.
NEO4J_URI = os.environ.get("NEO4J_URI")
NEO4J_USERNAME = os.environ.get("NEO4J_USERNAME")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD")
vector_index = os.environ.get("VECTOR_INDEX")

# Shared LLM client used by the map and reduce chains of the retrievers.
# Sampling is disabled (do_sample=False) and generations may be long.
chat_llm = HuggingFaceEndpoint(
    repo_id="meta-llama/Meta-Llama-3-8B-Instruct",
    task="text-generation",
    do_sample=False,
    max_new_tokens=4096,
)
def local_retriever(query: str):
topChunks = 3
topCommunities = 3
topOutsideRels = 10
topInsideRels = 10
topEntities = 10
driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
try:
lc_retrieval_query = """
WITH collect(node) as nodes
// Entity - Text Unit Mapping
WITH
collect {
UNWIND nodes as n
MATCH (n)<-[:HAS_ENTITY]->(c:__Chunk__)
WITH c, count(distinct n) as freq
RETURN c.text AS chunkText
ORDER BY freq DESC
LIMIT $topChunks
} AS text_mapping,
// Entity - Report Mapping
collect {
UNWIND nodes as n
MATCH (n)-[:IN_COMMUNITY]->(c:__Community__)
WITH c, c.rank as rank, c.weight AS weight
RETURN c.summary
ORDER BY rank, weight DESC
LIMIT $topCommunities
} AS report_mapping,
// Outside Relationships
collect {
UNWIND nodes as n
MATCH (n)-[r:RELATED]-(m)
WHERE NOT m IN nodes
RETURN r.description AS descriptionText
ORDER BY r.rank, r.weight DESC
LIMIT $topOutsideRels
} as outsideRels,
// Inside Relationships
collect {
UNWIND nodes as n
MATCH (n)-[r:RELATED]-(m)
WHERE m IN nodes
RETURN r.description AS descriptionText
ORDER BY r.rank, r.weight DESC
LIMIT $topInsideRels
} as insideRels,
// Entities description
collect {
UNWIND nodes as n
RETURN n.description AS descriptionText
} as entities
// We don't have covariates or claims here
RETURN {Chunks: text_mapping, Reports: report_mapping,
Relationships: outsideRels + insideRels,
Entities: entities} AS text, 1.0 AS score, {} AS metadata
"""
embedding_model_name = "nomic-ai/nomic-embed-text-v1"
embedding_model_kwargs = {"device": "cpu", "trust_remote_code": True}
encode_kwargs = {"normalize_embeddings": True}
embedding_model = HuggingFaceBgeEmbeddings(
model_name=embedding_model_name,
model_kwargs=embedding_model_kwargs,
encode_kwargs=encode_kwargs,
)
lc_vector = Neo4jVector.from_existing_index(
embedding_model,
url=NEO4J_URI,
username=NEO4J_USERNAME,
password=NEO4J_PASSWORD,
index_name=vector_index,
retrieval_query=lc_retrieval_query,
)
docs = lc_vector.similarity_search(
query,
k=topEntities,
params={
"topChunks": topChunks,
"topCommunities": topCommunities,
"topOutsideRels": topOutsideRels,
"topInsideRels": topInsideRels,
},
)
return docs[0]
except Exception as err:
return f"Error: {err}"
finally:
try:
driver.close()
except Exception as e:
print(f"Error closing driver: {e}")
def global_retriever(query: str, level: int, response_type: str):
MAP_SYSTEM_PROMPT = """
---Role---
You are a helpful assistant responding to questions about data in the tables provided.
---Goal---
Generate a response consisting of a list of key points that responds to the user's question, summarizing all relevant information in the input data tables.
You should use the data provided in the data tables below as the primary context for generating the response.
If you don't know the answer or if the input data tables do not contain sufficient information to provide an answer, just say so. Do not make anything up.
Each key point in the response should have the following element:
- Description: A comprehensive description of the point.
- Importance Score: An integer score between 0-100 that indicates how important the point is in answering the user's question. An 'I don't know' type of response should have a score of 0.
The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will".
Points supported by data should list the relevant reports as references as follows:
"This is an example sentence supported by data references [Data: Reports (report ids)]"
**Do not list more than 5 record ids in a single reference**. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.
For example:
"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (2, 7, 64, 46, 34, +more)]. He is also CEO of company X [Data: Reports (1, 3)]"
where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data report in the provided tables.
Do not include information where the supporting evidence for it is not provided. Always start with {{ and end with }}.
The response can only be JSON formatted. Do not add any text before or after the JSON-formatted string in the output.
The response should adhere to the following format:
{{
"points": [
{{"description": "Description of point 1 [Data: Reports (report ids)]", "score": score_value}},
{{"description": "Description of point 2 [Data: Reports (report ids)]", "score": score_value}}
]
}}
---Data tables---
"""
map_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
MAP_SYSTEM_PROMPT,
),
("system", "{context_data}"),
(
"human",
"{question}",
),
]
)
map_chain = map_prompt | chat_llm | StrOutputParser()
REDUCE_SYSTEM_PROMPT = """
---Role---
You are a helpful assistant responding to questions about a dataset by synthesizing perspectives from multiple analysts.
---Goal---
Generate a response of the target length and format that responds to the user's question, summarize all the reports from multiple analysts who focused on different parts of the dataset.
Note that the analysts' reports provided below are ranked in the **descending order of importance**.
If you don't know the answer or if the provided reports do not contain sufficient information to provide an answer, just say so. Do not make anything up.
The final response should remove all irrelevant information from the analysts' reports and merge the cleaned information into a comprehensive answer that provides explanations of all the key points and implications appropriate for the response length and format.
Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown.
The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will".
The response should also preserve all the data references previously included in the analysts' reports, but do not mention the roles of multiple analysts in the analysis process.
**Do not list more than 5 record ids in a single reference**. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.
For example:
"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (2, 7, 34, 46, 64, +more)]. He is also CEO of company X [Data: Reports (1, 3)]"
where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data record.
Do not include information where the supporting evidence for it is not provided. Style the response in markdown.
---Target response length and format---
{response_type}
---Analyst Reports---
{report_data}
Add sections and commentary to the response as appropriate for the length and format. Do not add references in your answer.
---Real Data---
"""
reduce_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
REDUCE_SYSTEM_PROMPT,
),
(
"human",
"{question}",
),
]
)
reduce_chain = reduce_prompt | chat_llm | StrOutputParser()
graph = Neo4jGraph(
url=NEO4J_URI,
username=NEO4J_USERNAME,
password=NEO4J_PASSWORD,
refresh_schema=False,
)
community_data = graph.query(
"""
MATCH (c:__Community__)
WHERE c.level = $level
RETURN c.full_content AS output
""",
params={"level": level},
)
# print(community_data)
intermediate_results = []
i = 0
for community in tqdm(community_data[:2], desc="Processing communities"):
intermediate_response = map_chain.invoke(
{"question": query, "context_data": community["output"]}
)
intermediate_results.append(intermediate_response)
i += 1
print(intermediate_results)
###Debug####
# tokens = global_tokenizer(intermediate_results)
# print(f"Number of input tokens: {len(tokens)}")
###Debug###
final_response = reduce_chain.invoke(
{
"report_data": intermediate_results,
"question": query,
"response_type": response_type,
}
)
return final_response