Spaces:
Sleeping
Sleeping
import os | |
import os.path | |
import serpapi | |
import requests | |
import streamlit as st | |
from typing import List | |
from docx import Document | |
from bs4 import BeautifulSoup | |
import huggingface_hub as hfh | |
import feedparser | |
from datasets import load_dataset | |
from urllib.parse import quote | |
from llama_index.llms.openai import OpenAI | |
from llama_index.core.schema import MetadataMode, NodeWithScore | |
from langchain_community.document_loaders import WebBaseLoader | |
from llama_index.embeddings.openai import OpenAIEmbedding | |
from langchain_community.document_loaders import PyPDFLoader | |
from llama_index.embeddings.huggingface import HuggingFaceEmbedding | |
from llama_index.postprocessor.cohere_rerank import CohereRerank | |
from llama_index.core.query_engine import RetrieverQueryEngine | |
from llama_index.core.query_engine.multistep_query_engine import MultiStepQueryEngine | |
from llama_index.core.indices.query.query_transform.base import StepDecomposeQueryTransform | |
from llama_index.core.node_parser import SemanticSplitterNodeParser | |
from llama_index.core.retrievers import VectorIndexRetriever, KeywordTableSimpleRetriever, BaseRetriever | |
from llama_index.core.postprocessor import MetadataReplacementPostProcessor, SimilarityPostprocessor | |
from llama_index.core import (VectorStoreIndex, SimpleDirectoryReader, ServiceContext, load_index_from_storage, | |
StorageContext, Document, Settings, SimpleKeywordTableIndex, | |
QueryBundle, get_response_synthesizer) | |
import warnings

# Third-party libraries used here emit many warnings; they are silenced globally.
warnings.filterwarnings("ignore")

# Configure the page before touching any other Streamlit API.
# NOTE(review): Streamlit documents set_page_config as the first command of a page.
st.set_page_config(
    page_title="My Streamlit App",
    page_icon=":rocket:",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Streamlit re-executes this script on every widget interaction. Assigning None
# unconditionally would erase the keys stored by setting_api_key() on the very
# next rerun, so only seed the entries that do not exist yet.
for _state_key in ("cohere_api_key", "serp_api_key"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None
def setting_api_key(openai_api_key, serp_api_key):
    """Store the API credentials needed by the rest of the app.

    The OpenAI key is exported through the ``OPENAI_API_KEY`` environment
    variable; the SERP key, the Cohere key (read from the ``cohere_api_key``
    environment variable) and the Hugging Face token (read from ``hf_token``)
    are kept in ``st.session_state``.

    Args:
        openai_api_key: OpenAI API key entered by the user.
        serp_api_key: SerpApi key entered by the user.

    Any exception is reported with ``st.warning`` instead of being raised.
    """
    try:
        os.environ['OPENAI_API_KEY'] = openai_api_key
        # Store the plain values first so that a failing Hugging Face login
        # cannot leave the search/rerank keys unset.
        st.session_state.serp_api_key = serp_api_key
        st.session_state.cohere_api_key = os.getenv("cohere_api_key")
        st.session_state.hf_token = os.getenv("hf_token")
        if st.session_state.hf_token:
            hfh.login(token=st.session_state.hf_token)
        else:
            # login(token=None) would fall back to an interactive prompt,
            # which cannot be answered from a Streamlit server process.
            st.warning("hf_token is not set; skipping Hugging Face login.")
    except Exception as e:
        st.warning(e)
def setup_llm_embed():
    """Create the chat model and the embedding model used by the pipeline.

    Returns:
        tuple: ``(llm, embed_model)`` where ``llm`` is an ``OpenAI`` instance
        configured with the system prompt below and ``embed_model`` is a
        ``HuggingFaceEmbedding`` for ``BAAI/bge-base-en-v1.5``.
    """
    system_prompt = """<|system|>
Mention Clearly Before response " RAG Output"
Please check if the following pieces of context has any mention of the keywords provided
in the question.Response as much as you could with context you get.
you are Question answering system based AI, Machine Learning , Deep Learning , Generative AI, Data
science and Data Analytics.if the following pieces of Context does not relate to Question,
You must not answer on your own,you don't know the answer.
</s>
<|user|>
Question:{query_str}</s>
<|assistant|> """
    llm_options = {
        "model": "gpt-3.5-turbo-0125",
        "temperature": 0.1,
        "model_kwargs": {'trust_remote_code': True},
        "max_tokens": 512,
        "system_prompt": system_prompt,
    }
    chat_llm = OpenAI(**llm_options)
    # Earlier experiments used OpenAIEmbedding (default model or
    # "text-embedding-3-small") in place of the Hugging Face embedding.
    embedding = HuggingFaceEmbedding(model_name="BAAI/bge-base-en-v1.5")
    return chat_llm, embedding
def semantic_split(embed_model, documents):
    """Split ``documents`` into nodes with a ``SemanticSplitterNodeParser``.

    Args:
        embed_model: Embedding model handed to the splitter.
        documents: Documents to split.

    Returns:
        The nodes produced by ``get_nodes_from_documents``.
    """
    splitter = SemanticSplitterNodeParser(
        embed_model=embed_model,
        breakpoint_percentile_threshold=90,
        buffer_size=1,
    )
    return splitter.get_nodes_from_documents(documents)
def ctx_vector_func(llm, embed_model, nodes):
    """Bundle the LLM and embedding model into a ``ServiceContext``.

    Args:
        llm: Language model used for synthesis.
        embed_model: Embedding model used for indexing and retrieval.
        nodes: Already-split nodes. Kept for backward compatibility; it is no
            longer forwarded because ``node_parser`` expects a parser object,
            not a list of nodes.

    Returns:
        The ``ServiceContext`` built by ``ServiceContext.from_defaults``.
    """
    # The original passed the node list itself as ``node_parser``. Supply a real
    # parser configured like semantic_split() so that documents inserted later
    # are split the same way.
    # NOTE(review): ServiceContext is deprecated in newer llama-index releases in
    # favour of ``Settings`` - confirm the pinned version still accepts it.
    node_parser = SemanticSplitterNodeParser(
        buffer_size=1,
        breakpoint_percentile_threshold=90,
        embed_model=embed_model,
    )
    ctx_vector = ServiceContext.from_defaults(
        llm=llm,
        embed_model=embed_model,
        node_parser=node_parser,
    )
    return ctx_vector
def saving_vectors(vector_index, keyword_index):
    """Persist both indexes to their directories under ``vectors/``.

    Args:
        vector_index: Index saved to ``vectors/vector_index/``.
        keyword_index: Index saved to ``vectors/keyword_index/``.
    """
    targets = (
        (vector_index, "vectors/vector_index/"),
        (keyword_index, "vectors/keyword_index/"),
    )
    for index, directory in targets:
        index.storage_context.persist(persist_dir=directory)
def create_vector_and_keyword_index(nodes, ctx_vector):
    """Build a vector index and a keyword index over ``nodes`` and persist them.

    Args:
        nodes: Nodes to index.
        ctx_vector: Service context passed to both index constructors.

    Returns:
        tuple: ``(vector_index, keyword_index)``.
    """
    indexes = (
        VectorStoreIndex(nodes, service_context=ctx_vector),
        SimpleKeywordTableIndex(nodes, service_context=ctx_vector),
    )
    # Write both indexes to disk right away via the shared helper.
    saving_vectors(*indexes)
    return indexes
class CustomRetriever(BaseRetriever):
    """Hybrid retriever combining a vector retriever and a keyword retriever.

    In ``"AND"`` mode only nodes returned by both retrievers are kept; in
    ``"OR"`` mode nodes returned by either retriever are kept.
    """

    def __init__(
        self,
        vector_retriever: VectorIndexRetriever,
        keyword_retriever: KeywordTableSimpleRetriever,
        mode: str = "AND",
    ) -> None:
        """Store the two retrievers and the combination mode.

        Raises:
            ValueError: If ``mode`` is neither ``"AND"`` nor ``"OR"``.
        """
        self._vector_retriever = vector_retriever
        self._keyword_retriever = keyword_retriever
        if mode not in ("AND", "OR"):
            raise ValueError("Invalid mode.")
        self._mode = mode
        super().__init__()

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Run both retrievers and merge their results according to the mode."""
        vector_nodes = self._vector_retriever.retrieve(query_bundle)
        keyword_nodes = self._keyword_retriever.retrieve(query_bundle)

        vector_ids = {n.node.node_id for n in vector_nodes}
        keyword_ids = {n.node.node_id for n in keyword_nodes}

        # For an id returned by both retrievers the keyword entry replaces the
        # vector entry (same precedence as before), but keeps the position the
        # id first appeared at.
        combined_dict = {n.node.node_id: n for n in vector_nodes}
        combined_dict.update({n.node.node_id: n for n in keyword_nodes})

        if self._mode == "AND":
            retrieve_ids = vector_ids & keyword_ids
        else:
            retrieve_ids = vector_ids | keyword_ids

        # Iterate the insertion-ordered dict rather than the id set: set order
        # over string ids varies between processes, which made the result
        # order non-deterministic.
        return [
            node for node_id, node in combined_dict.items()
            if node_id in retrieve_ids
        ]
def search_arxiv(query, max_results=8):
    """Query the arXiv export API and list the matching papers.

    Args:
        query: Free-text search string; it is URL-quoted before use.
        max_results: Value sent as the ``max_results`` request parameter.

    Returns:
        list[dict]: One ``{'Title': ..., 'URL': ...}`` dict per feed entry.
    """
    endpoint = 'http://export.arxiv.org/api/query?'
    request_url = f'{endpoint}search_query={quote(query)}&start=0&max_results={max_results}'
    parsed_feed = feedparser.parse(request_url)
    return [{'Title': item.title, 'URL': item.link} for item in parsed_feed.entries]
def remove_empty_lines(lines):
    """Join the non-blank entries of ``lines`` with single spaces.

    Entries that are empty or whitespace-only are dropped; the remaining
    entries are kept unmodified (they are not stripped).
    """
    return ' '.join(segment for segment in lines if segment.strip())
# Limits used by get_article_and_arxiv_content(). The original loop kept three
# web articles and two arXiv PDFs; those counts are preserved here.
_MAX_WEB_ARTICLES = 3
_MAX_ARXIV_PAPERS = 2
_HTTP_TIMEOUT_SECONDS = 15


def _collect_search_links(search_results):
    """Return the links of organic results and related questions, in order."""
    links = []
    for result_type in ("organic_results", "related_questions"):
        for result in search_results.get(result_type, []):
            if "title" in result and "link" in result:
                links.append(result["link"])
    return links


def _fetch_article_text(link):
    """Download one page and return its paragraph/heading text as one line."""
    response = requests.get(link, timeout=_HTTP_TIMEOUT_SECONDS)
    soup = BeautifulSoup(response.content, "html.parser")
    content_tags = soup.find_all(['p', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6'])
    document = "".join(tag.text + "\n" for tag in content_tags)
    if not document:
        # No paragraph/heading tags were found: fall back to the generic loader.
        loaded = WebBaseLoader(link).load()
        document = loaded[0].page_content if loaded else ""
    return remove_empty_lines(document.split('\n'))


def _find_arxiv_pdf_urls(query, limit):
    """Return up to ``limit`` PDF URLs scraped from arXiv abstract pages."""
    pdf_urls = []
    for paper in search_arxiv(query):
        if len(pdf_urls) >= limit:
            # Only the first ``limit`` PDFs are used, so stop scraping early.
            break
        try:
            response = requests.get(paper['URL'], timeout=_HTTP_TIMEOUT_SECONDS)
        except requests.RequestException as err:
            st.warning(f"Skipping arXiv page {paper['URL']}: {err}")
            continue
        soup = BeautifulSoup(response.content, "html.parser")
        download_link = soup.find("a", class_="abs-button download-pdf")
        pdf_url = download_link.get('href') if download_link else None
        if not pdf_url:
            continue
        if not pdf_url.startswith("http"):
            pdf_url = "https://arxiv.org" + pdf_url
        pdf_urls.append(pdf_url)
    return pdf_urls


def _load_pdf_text(pdf_url):
    """Load a PDF and return the text of all its pages as one string."""
    pages = PyPDFLoader(pdf_url).load_and_split()
    return ''.join(remove_empty_lines(page.page_content.split('\n')) for page in pages)


def get_article_and_arxiv_content(query):
    """Collect text about ``query`` from web search results and arXiv papers.

    Runs a Google search through SerpApi with the key stored in
    ``st.session_state.serp_api_key``, extracts the text of the first reachable
    result pages, then adds the text of the first arXiv PDFs found for the same
    query. Sources that cannot be downloaded are reported with ``st.warning``
    and skipped instead of aborting the whole lookup.

    Args:
        query: The user's question.

    Returns:
        list[str]: Web article texts followed by arXiv paper texts.
    """
    params = {
        "engine": "google",
        "gl": "us",
        "hl": "en",
        "api_key": st.session_state.serp_api_key,
        "q": query,
    }
    search_results = serpapi.GoogleSearch(params).get_dict()

    contents = []
    for link in _collect_search_links(search_results):
        if len(contents) >= _MAX_WEB_ARTICLES:
            # Stop before fetching a page whose text would be thrown away.
            break
        try:
            article = _fetch_article_text(link)
        except requests.RequestException as err:
            st.warning(f"Skipping article {link}: {err}")
            continue
        if article:
            contents.append(article)

    paper_content = []
    for pdf_url in _find_arxiv_pdf_urls(query, _MAX_ARXIV_PAPERS):
        try:
            paper_text = _load_pdf_text(pdf_url)
        except Exception as err:
            # Best-effort boundary: one unreadable PDF must not discard the
            # content that was already collected.
            st.warning(f"Skipping paper {pdf_url}: {err}")
            continue
        if paper_text:
            paper_content.append(paper_text)

    return contents + paper_content
# Loading the locally generated (persisted) indexes
def creating_vector_path():
    """Ensure the two index directories exist and return their paths.

    Returns:
        tuple[str, str]: ``("vectors/vector_index", "vectors/keyword_index")``.
    """
    PERSIST_DIR_vector = "vectors/vector_index"
    PERSIST_DIR_keyword = "vectors/keyword_index"
    for directory in (PERSIST_DIR_vector, PERSIST_DIR_keyword):
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(directory, exist_ok=True)
    return PERSIST_DIR_vector, PERSIST_DIR_keyword
def load_vector_index(PERSIST_DIR_vector, PERSIST_DIR_keyword):
    """Load the persisted vector and keyword indexes from disk.

    Args:
        PERSIST_DIR_vector: Directory holding the persisted vector index.
        PERSIST_DIR_keyword: Directory holding the persisted keyword index.

    Returns:
        tuple: ``(vector_index, keyword_index)``.
    """
    loaded = []
    for persist_dir in (PERSIST_DIR_vector, PERSIST_DIR_keyword):
        storage = StorageContext.from_defaults(persist_dir=persist_dir)
        loaded.append(load_index_from_storage(storage))
    vector_index, keyword_index = loaded
    return vector_index, keyword_index
def response_generation(query, cohere_api_key, vector_index, keyword_index):
    """Answer ``query`` with a hybrid (vector + keyword) retrieval engine.

    Args:
        query: The user's question.
        cohere_api_key: API key passed to the Cohere reranker.
        vector_index: Index searched by the vector retriever.
        keyword_index: Index searched by the keyword retriever.

    Returns:
        Whatever ``RetrieverQueryEngine.query`` returns for ``query``.
    """
    reranker = CohereRerank(api_key=cohere_api_key, top_n=4)
    similarity_filter = SimilarityPostprocessor(similarity_cutoff=0.85)
    hybrid_retriever = CustomRetriever(
        VectorIndexRetriever(index=vector_index, similarity_top_k=8),
        KeywordTableSimpleRetriever(index=keyword_index, similarity_top_k=8),
    )
    synthesizer = get_response_synthesizer()
    # Postprocessors are listed in this order: metadata replacement, Cohere
    # rerank, then the similarity cutoff.
    postprocessors = [
        MetadataReplacementPostProcessor(target_metadata_key="window"),
        reranker,
        similarity_filter,
    ]
    engine = RetrieverQueryEngine(
        retriever=hybrid_retriever,
        response_synthesizer=synthesizer,
        node_postprocessors=postprocessors,
    )
    # A multi-step variant (StepDecomposeQueryTransform + MultiStepQueryEngine)
    # was left disabled in the original and is not used here either.
    return engine.query(query)
def stream_output(response):
    """Render the answer under an "Output From RAG" heading.

    Args:
        response: The answer to show; it is converted with ``str()`` so both
            plain strings and query-engine response objects are accepted.
    """
    st.write("""<h1 style="font-size: 20px;">Output From RAG </h1>""", unsafe_allow_html=True)
    # The original looped over ``response`` and emitted one ``st.text`` element
    # per item: that fails for a non-iterable response object and prints a
    # string one character per line. Render the whole text in one element.
    st.write(str(response))
def func_add_new_article_content(content_):
    """Index freshly fetched article texts.

    Wraps every text in ``content_`` in a ``Document``, registers the LLM and
    embedding model on ``Settings``, splits the documents semantically and
    builds (and persists) a vector index and a keyword index over the nodes.

    Args:
        content_: Iterable of article/paper texts.

    Returns:
        tuple: ``(new_vector_index, new_keyword_index, new_nodes)``.
    """
    documents = [Document(text=text) for text in content_]

    # Model setup, also registered as the global defaults.
    llm, embed_model = setup_llm_embed()
    Settings.llm, Settings.embed_model = llm, embed_model

    # Split first, then index the resulting nodes.
    new_nodes = semantic_split(embed_model, documents)
    service_ctx = ctx_vector_func(llm, embed_model, new_nodes)
    new_vector_index, new_keyword_index = create_vector_and_keyword_index(new_nodes, service_ctx)
    return new_vector_index, new_keyword_index, new_nodes
def updating_vector(new_nodes, vector_index, keyword_index):
    """Insert ``new_nodes`` into both indexes and persist the result.

    Args:
        new_nodes: Nodes to add.
        vector_index: Vector index that receives the nodes first.
        keyword_index: Keyword index that receives the nodes second.
    """
    for index in (vector_index, keyword_index):
        index.insert_nodes(new_nodes)
    saving_vectors(vector_index, keyword_index)
def _is_empty_answer(response):
    """Return True when ``response`` carries no usable answer text."""
    if response is None:
        return True
    # Compare the text form: a response object never equals a plain string.
    # NOTE(review): "None" is presumably what str() yields for a response with
    # no text - confirm against the installed llama-index version.
    text = str(response).strip()
    return text in ("", "None", "Empty Response", "RAG Output")


def _answer_query(query):
    """Answer ``query`` from the stored indexes, falling back to web content.

    Returns:
        tuple: ``(answer, pending_update)``. ``answer`` is the answer text or
        ``None`` when nothing could be generated. ``pending_update`` is the
        argument tuple for ``updating_vector`` when new nodes may be merged
        into the previously stored indexes, otherwise ``None``.
    """
    cohere_api_key = st.session_state.cohere_api_key
    vector_path, keyword_path = creating_vector_path()

    vector_index = keyword_index = response = None
    try:
        vector_index, keyword_index = load_vector_index(vector_path, keyword_path)
    except (FileNotFoundError, ValueError):
        # Nothing has been persisted yet (e.g. first run with freshly created
        # directories); go straight to the web fallback below.
        pass
    else:
        response = response_generation(query, cohere_api_key, vector_index, keyword_index)

    pending_update = None
    if _is_empty_answer(response):
        with st.spinner("Getting Information from Articles, It will take some time."):
            content_ = get_article_and_arxiv_content(query)
            if content_:
                new_vector_index, new_keyword_index, new_nodes = func_add_new_article_content(content_)
                response = response_generation(query, cohere_api_key, new_vector_index, new_keyword_index)
                if vector_index is not None:
                    pending_update = (new_nodes, vector_index, keyword_index)

    answer = None if response is None else str(response)
    return answer, pending_update


def _render_answer_with_feedback():
    """Show the stored answer and handle the thumbs-up / thumbs-down buttons."""
    answer = st.session_state.get("rag_answer")
    if not answer:
        return
    stream_output(answer)
    col1, col2 = st.columns([1, 10])
    # Distinct labels and explicit keys keep the two buttons from colliding.
    if col1.button("👍", key="feedback_up"):
        st.write("Thank you for your positive feedback!")
        pending_update = st.session_state.get("rag_pending_update")
        if pending_update:
            updating_vector(*pending_update)
            # Merge the new nodes only once per answer.
            st.session_state.rag_pending_update = None
    if col2.button("👎", key="feedback_down"):
        st.write("""We're sorry , We will improve it.""")


def main():
    """Render the app: API-key form on the left, question/answer on the right.

    The answer and any pending index update are kept in ``st.session_state``
    so that they survive the script rerun triggered by a feedback button.
    """
    st.write("""<h1 style="font-size: 30px;">GenAI Question-Answer System Utilizing Advanced Retrieval-Augmented
Generation 🧠</h1>""", unsafe_allow_html=True)
    st.markdown("""This application operates on a paid source model and framework to ensure high accuracy and minimize
hallucination. Prior to running the application, it's necessary to configure two keys. Learn more about
these keys and how to generate them below.""")
    if 'key_flag' not in st.session_state:
        st.session_state.key_flag = False

    col_left, col_right = st.columns([1, 2])
    with col_left:
        st.write("""<h1 style="font-size: 15px;">Enter your OpenAI API key </h1>""", unsafe_allow_html=True)
        openai_api_key = st.text_input(placeholder="OpenAI api key ", label=" ", type="password")
        st.write("""<h1 style="font-size: 15px;">Enter your SERP API key </h1>""", unsafe_allow_html=True)
        serp_api_key = st.text_input(placeholder="Serp api key ", label=" ", type="password")
        set_keys_button = st.button("Set Keys ", type="primary")
        try:
            if set_keys_button and openai_api_key and serp_api_key:
                setting_api_key(openai_api_key, serp_api_key)
                st.success("Successful 🎉")
                st.session_state.key_flag = True
            elif set_keys_button:
                st.warning("Please set the necessary API keys !")
        except Exception as e:
            st.warning(e)

    with col_right:
        st.write("""<h1 style="font-size: 15px;">Enter your Question </h1>""", unsafe_allow_html=True)
        query = st.text_input(placeholder="Query ", label=" ", max_chars=192)
        generate_response_button = st.button("Generate response", type="primary")
        has_query = bool(str(query))
        keys_ready = st.session_state.key_flag

        if generate_response_button:
            # A new request invalidates whatever answer was shown before.
            st.session_state.rag_answer = None
            st.session_state.rag_pending_update = None
            if not has_query and not keys_ready:
                st.warning("Please set the necessary API keys and Enter the query")
            elif not keys_ready:
                st.warning("Please set the necessary API keys----")
            elif not has_query:
                st.warning("Please Enter the query !")
            else:
                answer, pending_update = _answer_query(query)
                st.session_state.rag_answer = answer
                st.session_state.rag_pending_update = pending_update
                if answer is None:
                    st.warning("No answer could be generated for this query.")

        # Rendered outside the button branch: clicking a feedback button reruns
        # the script with generate_response_button == False, so widgets nested
        # under that branch would never get to handle the click.
        _render_answer_with_feedback()


if __name__ == "__main__":
    main()