import gradio as gr import openai, os, time, wandb from dotenv import load_dotenv, find_dotenv from langchain.chains import LLMChain, RetrievalQA from langchain.chat_models import ChatOpenAI from langchain.document_loaders import PyPDFLoader, WebBaseLoader from langchain.document_loaders.blob_loaders.youtube_audio import YoutubeAudioLoader from langchain.document_loaders.generic import GenericLoader from langchain.document_loaders.parsers import OpenAIWhisperParser from langchain.embeddings.openai import OpenAIEmbeddings from langchain.prompts import PromptTemplate from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain.vectorstores import Chroma from langchain.vectorstores import MongoDBAtlasVectorSearch from pymongo import MongoClient from trace import wandb_trace #from wandb.sdk.data_types.trace_tree import Trace _ = load_dotenv(find_dotenv()) PDF_URL = "https://arxiv.org/pdf/2303.08774.pdf" WEB_URL = "https://openai.com/research/gpt-4" YOUTUBE_URL_1 = "https://www.youtube.com/watch?v=--khbXchTeE" YOUTUBE_URL_2 = "https://www.youtube.com/watch?v=hdhZwyf24mE" YOUTUBE_URL_3 = "https://www.youtube.com/watch?v=vw-KWfKwvTQ" YOUTUBE_DIR = "/data/youtube" CHROMA_DIR = "/data/chroma" MONGODB_ATLAS_CLUSTER_URI = os.environ["MONGODB_ATLAS_CLUSTER_URI"] MONGODB_DB_NAME = "langchain_db" MONGODB_COLLECTION_NAME = "gpt-4" MONGODB_INDEX_NAME = "default" LLM_CHAIN_PROMPT = PromptTemplate(input_variables = ["question"], template = os.environ["LLM_TEMPLATE"]) RAG_CHAIN_PROMPT = PromptTemplate(input_variables = ["context", "question"], template = os.environ["RAG_TEMPLATE"]) #WANDB_API_KEY = os.environ["WANDB_API_KEY"] RAG_OFF = "Off" RAG_CHROMA = "Chroma" RAG_MONGODB = "MongoDB" client = MongoClient(MONGODB_ATLAS_CLUSTER_URI) collection = client[MONGODB_DB_NAME][MONGODB_COLLECTION_NAME] config = { "chunk_overlap": 150, "chunk_size": 1500, "k": 3, "model_name": "gpt-4-0613", "temperature": 0, } def document_loading_splitting(): # Document loading docs = [] # Load PDF loader = PyPDFLoader(PDF_URL) docs.extend(loader.load()) # Load Web loader = WebBaseLoader(WEB_URL) docs.extend(loader.load()) # Load YouTube loader = GenericLoader(YoutubeAudioLoader([YOUTUBE_URL_1, YOUTUBE_URL_2, YOUTUBE_URL_3], YOUTUBE_DIR), OpenAIWhisperParser()) docs.extend(loader.load()) # Document splitting text_splitter = RecursiveCharacterTextSplitter(chunk_overlap = config["chunk_overlap"], chunk_size = config["chunk_size"]) split_documents = text_splitter.split_documents(docs) return split_documents def document_storage_chroma(documents): Chroma.from_documents(documents = documents, embedding = OpenAIEmbeddings(disallowed_special = ()), persist_directory = CHROMA_DIR) def document_storage_mongodb(documents): MongoDBAtlasVectorSearch.from_documents(documents = documents, embedding = OpenAIEmbeddings(disallowed_special = ()), collection = collection, index_name = MONGODB_INDEX_NAME) def document_retrieval_chroma(llm, prompt): return Chroma(embedding_function = OpenAIEmbeddings(), persist_directory = CHROMA_DIR) def document_retrieval_mongodb(llm, prompt): return MongoDBAtlasVectorSearch.from_connection_string(MONGODB_ATLAS_CLUSTER_URI, MONGODB_DB_NAME + "." + MONGODB_COLLECTION_NAME, OpenAIEmbeddings(disallowed_special = ()), index_name = MONGODB_INDEX_NAME) def llm_chain(llm, prompt): llm_chain = LLMChain(llm = llm, prompt = LLM_CHAIN_PROMPT, verbose = False) completion = llm_chain.generate([{"question": prompt}]) return completion, llm_chain def rag_chain(llm, prompt, db): rag_chain = RetrievalQA.from_chain_type(llm, chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT}, retriever = db.as_retriever(search_kwargs = {"k": config["k"]}), return_source_documents = True, verbose = False) completion = rag_chain({"query": prompt}) return completion, rag_chain #def wandb_trace(rag_option, prompt, completion, result, generation_info, llm_output, chain, err_msg, start_time_ms, end_time_ms): # wandb.init(project = "openai-llm-rag") # # trace = Trace( # kind = "chain", # name = "" if (chain == None) else type(chain).__name__, # status_code = "success" if (str(err_msg) == "") else "error", # status_message = str(err_msg), # metadata = {"chunk_overlap": "" if (rag_option == RAG_OFF) else config["chunk_overlap"], # "chunk_size": "" if (rag_option == RAG_OFF) else config["chunk_size"], # } if (str(err_msg) == "") else {}, # inputs = {"rag_option": rag_option, # "prompt": prompt, # "chain_prompt": (str(chain.prompt) if (rag_option == RAG_OFF) else # str(chain.combine_documents_chain.llm_chain.prompt)), # "source_documents": "" if (rag_option == RAG_OFF) else str([doc.metadata["source"] for doc in completion["source_documents"]]), # } if (str(err_msg) == "") else {}, # outputs = {"result": result, # "generation_info": str(generation_info), # "llm_output": str(llm_output), # "completion": str(completion), # } if (str(err_msg) == "") else {}, # model_dict = {"client": (str(chain.llm.client) if (rag_option == RAG_OFF) else # str(chain.combine_documents_chain.llm_chain.llm.client)), # "model_name": (str(chain.llm.model_name) if (rag_option == RAG_OFF) else # str(chain.combine_documents_chain.llm_chain.llm.model_name)), # "temperature": (str(chain.llm.temperature) if (rag_option == RAG_OFF) else # str(chain.combine_documents_chain.llm_chain.llm.temperature)), # "retriever": ("" if (rag_option == RAG_OFF) else str(chain.retriever)), # } if (str(err_msg) == "") else {}, # start_time_ms = start_time_ms, # end_time_ms = end_time_ms # ) # # trace.log("evaluation") # wandb.finish() def invoke(openai_api_key, rag_option, prompt): if (openai_api_key == ""): raise gr.Error("OpenAI API Key is required.") if (rag_option is None): raise gr.Error("Retrieval Augmented Generation is required.") if (prompt == ""): raise gr.Error("Prompt is required.") chain = None completion = "" result = "" generation_info = "" llm_output = "" err_msg = "" try: start_time_ms = round(time.time() * 1000) llm = ChatOpenAI(model_name = config["model_name"], openai_api_key = openai_api_key, temperature = config["temperature"]) if (rag_option == RAG_CHROMA): #splits = document_loading_splitting() #document_storage_chroma(splits) db = document_retrieval_chroma(llm, prompt) completion, chain = rag_chain(llm, prompt, db) result = completion["result"] elif (rag_option == RAG_MONGODB): #splits = document_loading_splitting() #document_storage_mongodb(splits) db = document_retrieval_mongodb(llm, prompt) completion, chain = rag_chain(llm, prompt, db) result = completion["result"] else: completion, chain = llm_chain(llm, prompt) if (completion.generations[0] != None and completion.generations[0][0] != None): result = completion.generations[0][0].text generation_info = completion.generations[0][0].generation_info llm_output = completion.llm_output except Exception as e: err_msg = e raise gr.Error(e) finally: end_time_ms = round(time.time() * 1000) wandb_trace(rag_option, prompt, completion, result, generation_info, llm_output, chain, err_msg, start_time_ms, end_time_ms) return result gr.close_all() demo = gr.Interface(fn=invoke, inputs = [gr.Textbox(label = "OpenAI API Key", type = "password", lines = 1), gr.Radio([RAG_OFF, RAG_CHROMA, RAG_MONGODB], label = "Retrieval Augmented Generation", value = RAG_OFF), gr.Textbox(label = "Prompt", value = "What are GPT-4's media capabilities in 5 emojis and 1 sentence?", lines = 1), ], outputs = [gr.Textbox(label = "Completion", lines = 1)], title = "Generative AI - LLM & RAG", description = os.environ["DESCRIPTION"]) demo.launch()