rag_citations_demo / ai_generate.py
minko186's picture
Create ai_generate.py
704db80 verified
import functools
import html
import os
import uuid
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_core.documents import Document
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings,
)
from langchain.schema import StrOutputParser
from langchain_community.vectorstores import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain import hub
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_anthropic import ChatAnthropic
from dotenv import load_dotenv
from langchain_core.output_parsers import XMLOutputParser
from langchain.prompts import ChatPromptTemplate
# Pull provider credentials and other settings from a local .env file.
load_dotenv()
# Quiet the gRPC / glog logging that the Gemini client stack would otherwise print.
os.environ.update({"GRPC_VERBOSITY": "ERROR", "GLOG_minloglevel": "2"})
# --- Retrieval-augmented generation tuning -------------------------------
# Chunk size handed to RecursiveCharacterTextSplitter, with 1/8 of a chunk
# (128) overlapping between neighbouring chunks.
CHUNK_SIZE = 1024
CHUNK_OVERLAP = CHUNK_SIZE // 8
# MMR retrieval: return K documents selected from FETCH_K candidates.
K, FETCH_K = 10, 20
# Single registry of supported models: (human-readable label passed to
# load_llm as `model`, provider model id, LangChain chat class serving it).
_MODEL_REGISTRY = (
    ("LLaMA 3", "llama3-70b-8192", ChatGroq),
    ("OpenAI GPT 4o Mini", "gpt-4o-mini", ChatOpenAI),
    ("OpenAI GPT 4o", "gpt-4o", ChatOpenAI),
    ("OpenAI GPT 4", "gpt-4-turbo", ChatOpenAI),
    ("Gemini 1.5 Pro", "gemini-1.5-pro", ChatGoogleGenerativeAI),
    ("Claude Sonnet 3.5", "claude-3-5-sonnet-20240620", ChatAnthropic),
)
# Label -> provider model id.
llm_model_translation = {label: model_id for label, model_id, _ in _MODEL_REGISTRY}
# Provider model id -> chat model class.
llm_classes = {model_id: chat_cls for _, model_id, chat_cls in _MODEL_REGISTRY}
# System prompt for cited RAG answers. It instructs the model to reply in the
# <cited_text> XML shape spelled out below; generate_rag pipes the reply
# through XMLOutputParser and reads the "cited_text" -> "text" / "citations"
# entries from the parsed result. "{context}" is a template variable that is
# filled with the formatted sources via RunnablePassthrough.assign(context=...).
xml_system = """You're a helpful AI assistant. Given a user prompt and some related sources, \
fulfill all the requirements of the prompt and provide citations. If a part of the generated text does \
not use any of the sources, don't put a citation for that part. Otherwise, list all sources used for that part of the text.
At the end of each relevant part, add a citation in square brackets, numbered sequentially starting from [0], regardless of the source's original ID.
Remember, you must return both the requested text and citations. A citation consists of a VERBATIM quote that \
justifies the text and a sequential number (starting from 0) for the quote's article. Return a citation for every quote across all articles \
that justify the text. Use the following format for your final output:
<cited_text>
<text></text>
<citations>
<citation><source_id></source_id><source></source><quote></quote></citation>
<citation><source_id></source_id><source></source><quote></quote></citation>
...
</citations>
</cited_text>
Here are the sources:{context}"""
# Two-message chat template: the system prompt above plus the user's request,
# supplied through the "{input}" template variable.
xml_prompt = ChatPromptTemplate.from_messages(
[("system", xml_system), ("human", "{input}")]
)
def format_docs_xml(docs: "list[Document]") -> str:
    """Render retrieved chunks as an XML ``<sources>`` block for the prompt.

    Each document becomes a ``<source id="i">`` element, where ``i`` is its
    position in *docs*. The element holds the document's origin (its
    ``source`` metadata), title and text, so the model has a stable
    identifier to cite.

    Args:
        docs: Retrieved documents exposing ``page_content`` and a
            ``metadata`` mapping.

    Returns:
        The XML snippet, prefixed with a blank line so it can be appended
        directly after "Here are the sources:" in the system prompt.
    """
    formatted = []
    for i, doc in enumerate(docs):
        # Chunks built from URL contents only carry a "source" key (no
        # "title"), so missing metadata must not raise KeyError.
        source = doc.metadata.get("source", "")
        title = doc.metadata.get("title", "")
        formatted.append(
            f'<source id="{i}">\n'
            f"    <source>{source}</source>\n"
            f"    <title>{title}</title>\n"
            f"    <article_snippet>{doc.page_content}</article_snippet>\n"
            "</source>"
        )
    return "\n\n<sources>" + "\n".join(formatted) + "</sources>"
def citations_to_html(citations_data):
    """Render parsed citation elements as an HTML bullet list.

    Args:
        citations_data: The ``citations`` list from the parsed model output,
            where each item looks like
            ``{"citation": [{"source_id": ...}, {"source": ...}, {"quote": ...}]}``.
            The child order is not relied upon.

    Returns:
        A ``<ul>`` string with one ``<li>`` per citation, numbered by list
        position, or ``""`` when there are no citations.
    """
    if not citations_data:
        return ""

    def _escaped(fields, name):
        # Model output is ultimately copied from scraped pages/PDFs, so it must
        # be escaped before being embedded in HTML. Anything that isn't a plain
        # string (e.g. an empty element) is rendered as empty.
        value = fields.get(name)
        return html.escape(value) if isinstance(value, str) else ""

    items = []
    for index, citation in enumerate(citations_data):
        # Each child element arrives as a one-key dict; merge them so a
        # reordered or missing field does not raise IndexError/KeyError.
        fields = {}
        for child in citation.get("citation") or []:
            if isinstance(child, dict):
                fields.update(child)
        source = _escaped(fields, "source")
        quote = _escaped(fields, "quote")
        items.append(f'<li>\n[{index}] - "{source}" <br>\n"{quote}"\n</li>')
    return "<ul>" + "\n".join(items) + "</ul>"
def load_llm(model: str, api_key: str, temperature: float = 1.0, max_length: int = 2048):
    """Instantiate the LangChain chat model for a model label.

    Args:
        model: A label from ``llm_model_translation`` (e.g. "OpenAI GPT 4o").
            A raw provider model id such as "gpt-4o" is accepted as well.
        api_key: Provider API key. When empty, no key is passed and the
            client uses its own default (presumably the provider's
            environment variable, e.g. loaded from .env).
        temperature: Sampling temperature.
        max_length: Maximum number of tokens to generate.

    Returns:
        The chat model instance, or None if constructing it raised.

    Raises:
        ValueError: If *model* does not map to a supported model class.
    """
    model_name = llm_model_translation.get(model, model)
    llm_class = llm_classes.get(model_name)
    if not llm_class:
        raise ValueError(f"Model {model} not supported.")
    kwargs = {"model_name": model_name, "temperature": temperature, "max_tokens": max_length}
    # Only forward a non-empty key so that an empty string does not override
    # a credential the client would otherwise pick up from the environment.
    if api_key:
        kwargs["api_key"] = api_key
    try:
        return llm_class(**kwargs)
    except Exception as e:
        # Best-effort: callers check for None and abort gracefully.
        print(f"An error occurred: {e}")
        return None
@functools.lru_cache(maxsize=1)
def _get_embedding_function():
    """Return the shared sentence-transformer embedder, loading it on first use."""
    return SentenceTransformerEmbeddings(model_name="all-mpnet-base-v2")


def create_db_with_langchain(path: list[str], url_content: dict):
    """Split the given PDFs and URL contents into chunks and index them in Chroma.

    Args:
        path: PDF file paths to load with PyMuPDFLoader (may be empty/None).
        url_content: Mapping of URL -> page text (may be empty/None).

    Returns:
        A Chroma vector store holding every chunk, stored in a freshly named
        collection so separate calls do not mix their documents.

    Raises:
        ValueError: If the inputs produced no chunks at all.
    """
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
    all_docs = []
    for file in path or []:
        # Load the PDF and split it into chunks.
        all_docs.extend(text_splitter.split_documents(PyMuPDFLoader(file).load()))
    for url, content in (url_content or {}).items():
        # URL chunks carry only their origin as metadata.
        doc = Document(page_content=content, metadata={"source": url})
        all_docs.extend(text_splitter.split_documents([doc]))

    for idx, doc in enumerate(all_docs):
        print(f"Doc: {idx} | Length = {len(doc.page_content)}")

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not all_docs:
        raise ValueError("No PDFs or scraped data provided")

    # A unique collection name keeps this call's chunks apart from collections
    # created by earlier calls in the same process.
    return Chroma.from_documents(
        all_docs,
        _get_embedding_function(),
        collection_name=f"rag-{uuid.uuid4().hex}",
    )
def _merge_xml_children(children):
    """Collapse a parsed XML element's list of one-key dicts into one dict."""
    merged = {}
    if isinstance(children, list):
        for child in children:
            if isinstance(child, dict):
                merged.update(child)
    return merged


def generate_rag(
    prompt: str,
    topic: str,
    model: str,
    url_content: dict,
    path: list[str],
    temperature: float = 1.0,
    max_length: int = 2048,
    api_key: str = "",
    sys_message="",
):
    """Answer *prompt* with citations grounded in the supplied PDFs / URL contents.

    Chunks relevant to *topic* are retrieved with MMR, formatted as XML
    sources, and sent with *prompt* to the chosen model. The model's XML
    reply is then parsed into text and an HTML citation list.
    ``sys_message`` is accepted for interface parity and is not used.

    Returns:
        ``(text, citations_html)`` on success, or None if the LLM could not
        be loaded.

    Raises:
        Whatever building the vector store or running the chain raises,
        including a KeyError if the reply has no ``cited_text`` element.
    """
    llm = load_llm(model, api_key, temperature, max_length)
    if llm is None:
        print("Failed to load LLM. Aborting operation.")
        return None

    db = create_db_with_langchain(path, url_content)
    try:
        retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": K, "fetch_k": FETCH_K})
        docs = retriever.invoke(topic)
    finally:
        # The vector store is only needed for this single retrieval. The
        # default in-memory Chroma client may be shared process-wide, so drop
        # the collection to keep chunks from accumulating across calls.
        db.delete_collection()

    formatted_docs = format_docs_xml(docs)
    rag_chain = (
        RunnablePassthrough.assign(context=lambda _: formatted_docs)
        | xml_prompt
        | llm
        | XMLOutputParser()
    )
    result = rag_chain.invoke({"input": prompt})

    # Look children up by tag rather than position so the model's element
    # order does not matter. Empty elements fall back to "" / no citations.
    cited = _merge_xml_children(result["cited_text"])
    text = cited.get("text")
    citations = cited.get("citations")
    return (
        text if isinstance(text, str) else "",
        citations_to_html(citations if isinstance(citations, list) else None),
    )
def generate_base(
    prompt: str, topic: str, model: str, temperature: float, max_length: int, api_key: str, sys_message=""
):
    """Send *prompt* straight to the selected model, with no retrieval step.

    ``topic`` and ``sys_message`` are accepted for interface parity and are
    not used by this function.

    Returns:
        The ``content`` of the model's reply, or None if the model could not
        be loaded or the call raised.
    """
    chat_model = load_llm(model, api_key, temperature, max_length)
    if chat_model is None:
        print("Failed to load LLM. Aborting operation.")
        return None
    try:
        reply = chat_model.invoke(prompt).content
    except Exception as err:
        print(f"An error occurred while running the model: {err}")
        return None
    else:
        return reply
def generate(
    prompt: str,
    topic: str,
    model: str,
    url_content: dict,
    path: list[str],
    temperature: float = 1.0,
    max_length: int = 2048,
    api_key: str = "",
    sys_message="",
):
    """Route to cited RAG generation when source material is supplied, else a plain LLM call.

    Returns whatever the selected backend (generate_rag or generate_base) returns.
    """
    has_sources = bool(path or url_content)
    if not has_sources:
        return generate_base(
            prompt=prompt,
            topic=topic,
            model=model,
            temperature=temperature,
            max_length=max_length,
            api_key=api_key,
            sys_message=sys_message,
        )
    return generate_rag(
        prompt=prompt,
        topic=topic,
        model=model,
        url_content=url_content,
        path=path,
        temperature=temperature,
        max_length=max_length,
        api_key=api_key,
        sys_message=sys_message,
    )