import os

from dotenv import load_dotenv
from langchain import hub
from langchain.prompts import ChatPromptTemplate
from langchain_anthropic import ChatAnthropic
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings,
)
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser, XMLOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from langchain_text_splitters import RecursiveCharacterTextSplitter

load_dotenv()

# suppress grpc and glog logs for Gemini
os.environ["GRPC_VERBOSITY"] = "ERROR"
os.environ["GLOG_minloglevel"] = "2"

# RAG parameters
CHUNK_SIZE = 1024
CHUNK_OVERLAP = CHUNK_SIZE // 8
K = 10
FETCH_K = 20

# Map user-facing model names to provider model identifiers
llm_model_translation = {
    "LLaMA 3": "llama3-70b-8192",
    "OpenAI GPT 4o Mini": "gpt-4o-mini",
    "OpenAI GPT 4o": "gpt-4o",
    "OpenAI GPT 4": "gpt-4-turbo",
    "Gemini 1.5 Pro": "gemini-1.5-pro",
    "Claude Sonnet 3.5": "claude-3-5-sonnet-20240620",
}

# Map model identifiers to their LangChain chat classes
llm_classes = {
    "llama3-70b-8192": ChatGroq,
    "gpt-4o-mini": ChatOpenAI,
    "gpt-4o": ChatOpenAI,
    "gpt-4-turbo": ChatOpenAI,
    "gemini-1.5-pro": ChatGoogleGenerativeAI,
    "claude-3-5-sonnet-20240620": ChatAnthropic,
}
xml_system = """You're a helpful AI assistant. Given a user prompt and some related sources, \ | |
fulfill all the requirements of the prompt and provide citations. If a part of the generated text does \ | |
not use any of the sources, don't put a citation for that part. Otherwise, list all sources used for that part of the text. | |
At the end of each relevant part, add a citation in square brackets, numbered sequentially starting from [0], regardless of the source's original ID. | |
Remember, you must return both the requested text and citations. A citation consists of a VERBATIM quote that \ | |
justifies the text and a sequential number (starting from 0) for the quote's article. Return a citation for every quote across all articles \ | |
that justify the text. Use the following format for your final output: | |
<cited_text> | |
<text></text> | |
<citations> | |
<citation><source_id></source_id><source></source><quote></quote></citation> | |
<citation><source_id></source_id><source></source><quote></quote></citation> | |
... | |
</citations> | |
</cited_text> | |
Here are the sources:{context}""" | |
xml_prompt = ChatPromptTemplate.from_messages( | |
[("system", xml_system), ("human", "{input}")] | |
) | |

def format_docs_xml(docs: list[Document]) -> str:
    """Format retrieved documents as an XML block for the citation prompt."""
    formatted = []
    for i, doc in enumerate(docs):
        # URL-scraped documents only carry "source" in their metadata, so fall back gracefully;
        # the positional id gives the model a stable reference for each source.
        doc_str = f"""\
    <source id="{i}">
        <source>{doc.metadata.get('source', '')}</source>
        <title>{doc.metadata.get('title', '')}</title>
        <article_snippet>{doc.page_content}</article_snippet>
    </source>"""
        formatted.append(doc_str)
    return "\n\n<sources>" + "\n".join(formatted) + "</sources>"

def citations_to_html(citations_data):
    """Render the parsed <citations> block as an HTML list."""
    if citations_data:
        html_output = "<ul>"
        for index, citation in enumerate(citations_data):
            # XMLOutputParser returns each <citation> as a list of single-key dicts
            source_id = citation["citation"][0]["source_id"]
            source = citation["citation"][1]["source"]
            quote = citation["citation"][2]["quote"]
            html_output += f"""
            <li>
                [{index}] - "{source}" <br>
                "{quote}"
            </li>
            """
        html_output += "</ul>"
        return html_output
    return ""

def load_llm(model: str, api_key: str, temperature: float = 1.0, max_length: int = 2048):
    """Instantiate the chat model that corresponds to a user-facing model name."""
    model_name = llm_model_translation.get(model)
    llm_class = llm_classes.get(model_name)
    if not llm_class:
        raise ValueError(f"Model {model} not supported.")
    try:
        # Provider API keys are read from the environment (populated by load_dotenv);
        # the api_key argument is currently unused.
        llm = llm_class(model_name=model_name, temperature=temperature, max_tokens=max_length)
    except Exception as e:
        print(f"An error occurred: {e}")
        llm = None
    return llm
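
# Example usage (illustrative; the model name must be a key of llm_model_translation and the
# matching provider key, e.g. OPENAI_API_KEY, must be present in the environment / .env):
#
#   llm = load_llm("OpenAI GPT 4o Mini", api_key="", temperature=0.3, max_length=1024)
#   if llm is not None:
#       print(llm.invoke("Say hello in one word.").content)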

def create_db_with_langchain(path: list[str], url_content: dict):
    """Build an in-memory Chroma vector store from PDF files and scraped URL content."""
    all_docs = []
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
    embedding_function = SentenceTransformerEmbeddings(model_name="all-mpnet-base-v2")
    if path:
        for file in path:
            loader = PyMuPDFLoader(file)
            data = loader.load()
            # split it into chunks
            docs = text_splitter.split_documents(data)
            all_docs.extend(docs)

    if url_content:
        for url, content in url_content.items():
            doc = Document(page_content=content, metadata={"source": url})
            # split it into chunks
            docs = text_splitter.split_documents([doc])
            all_docs.extend(docs)

    # print docs
    for idx, doc in enumerate(all_docs):
        print(f"Doc: {idx} | Length = {len(doc.page_content)}")

    assert len(all_docs) > 0, "No PDFs or scraped data provided"
    db = Chroma.from_documents(all_docs, embedding_function)
    return db
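
# Standalone retrieval check (illustrative; "docs/sample.pdf" is a hypothetical local path):
#
#   db = create_db_with_langchain(["docs/sample.pdf"], url_content={})
#   retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": K, "fetch_k": FETCH_K})
#   for doc in retriever.invoke("main topic of the document"):
#       print(doc.metadata.get("source"), len(doc.page_content))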

def generate_rag(
    prompt: str,
    topic: str,
    model: str,
    url_content: dict,
    path: list[str],
    temperature: float = 1.0,
    max_length: int = 2048,
    api_key: str = "",
    sys_message="",
):
    """Answer the prompt with retrieval-augmented generation and return (text, citations_html)."""
    llm = load_llm(model, api_key, temperature, max_length)
    if llm is None:
        print("Failed to load LLM. Aborting operation.")
        return None

    db = create_db_with_langchain(path, url_content)
    retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": K, "fetch_k": FETCH_K})
    rag_prompt = hub.pull("rlm/rag-prompt")

    def format_docs(docs):
        if all(isinstance(doc, Document) for doc in docs):
            return "\n\n".join(doc.page_content for doc in docs)
        else:
            raise TypeError("All items in docs must be instances of Document.")

    docs = retriever.get_relevant_documents(topic)

    # Plain (uncited) RAG variant kept for reference; rag_prompt and format_docs above
    # are only used by this commented-out chain.
    # formatted_docs = format_docs(docs)
    # rag_chain = (
    #     {"context": lambda _: formatted_docs, "question": RunnablePassthrough()} | rag_prompt | llm | StrOutputParser()
    # )
    # return rag_chain.invoke(prompt)

    formatted_docs = format_docs_xml(docs)
    rag_chain = (
        RunnablePassthrough.assign(context=lambda _: formatted_docs)
        | xml_prompt
        | llm
        | XMLOutputParser()
    )
    result = rag_chain.invoke({"input": prompt})

    from pprint import pprint

    pprint(result)
    return result["cited_text"][0]["text"], citations_to_html(result["cited_text"][1]["citations"])

def generate_base(
    prompt: str, topic: str, model: str, temperature: float, max_length: int, api_key: str, sys_message=""
):
    """Answer the prompt directly with the selected model (no retrieval)."""
    llm = load_llm(model, api_key, temperature, max_length)
    if llm is None:
        print("Failed to load LLM. Aborting operation.")
        return None
    try:
        output = llm.invoke(prompt).content
        return output
    except Exception as e:
        print(f"An error occurred while running the model: {e}")
        return None

def generate(
    prompt: str,
    topic: str,
    model: str,
    url_content: dict,
    path: list[str],
    temperature: float = 1.0,
    max_length: int = 2048,
    api_key: str = "",
    sys_message="",
):
    """Route to RAG generation when sources are provided, otherwise to plain generation."""
    if path or url_content:
        return generate_rag(prompt, topic, model, url_content, path, temperature, max_length, api_key, sys_message)
    else:
        return generate_base(prompt, topic, model, temperature, max_length, api_key, sys_message)
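
# Minimal end-to-end sketch (illustrative only: the PDF path, prompt, and topic are hypothetical,
# and the matching provider key must be present in .env / the environment):
if __name__ == "__main__":
    result = generate(
        prompt="Summarize the key findings and cite your sources.",
        topic="key findings",
        model="OpenAI GPT 4o Mini",
        url_content={},
        path=["docs/sample.pdf"],  # hypothetical local PDF
        temperature=0.3,
        max_length=1024,
    )
    if result is not None:
        text, citations_html = result  # the RAG path returns (text, citations_html)
        print(text)
        print(citations_html)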