|
import openai, os |
|
|
|
from langchain.chains import LLMChain, RetrievalQA |
|
from langchain.chat_models import ChatOpenAI |
|
from langchain.document_loaders import PyPDFLoader, WebBaseLoader |
|
from langchain.document_loaders.blob_loaders.youtube_audio import YoutubeAudioLoader |
|
from langchain.document_loaders.generic import GenericLoader |
|
from langchain.document_loaders.parsers import OpenAIWhisperParser |
|
from langchain.embeddings.openai import OpenAIEmbeddings |
|
from langchain.prompts import PromptTemplate |
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
from langchain.vectorstores import Chroma |
|
from langchain.vectorstores import MongoDBAtlasVectorSearch |
|
|
|
from pymongo import MongoClient |
|
|
|
RAG_CHROMA = "Chroma" |
|
RAG_MONGODB = "MongoDB" |
|
|
|
PDF_URL = "https://arxiv.org/pdf/2303.08774.pdf" |
|
WEB_URL = "https://openai.com/research/gpt-4" |
|
YOUTUBE_URL_1 = "https://www.youtube.com/watch?v=--khbXchTeE" |
|
YOUTUBE_URL_2 = "https://www.youtube.com/watch?v=hdhZwyf24mE" |
|
|
|
YOUTUBE_DIR = "/data/youtube" |
|
CHROMA_DIR = "/data/chroma" |
|
|
|
MONGODB_ATLAS_CLUSTER_URI = os.environ["MONGODB_ATLAS_CLUSTER_URI"] |
|
MONGODB_DB_NAME = "langchain_db" |
|
MONGODB_COLLECTION_NAME = "gpt-4" |
|
MONGODB_INDEX_NAME = "default" |
|
|
|
LLM_CHAIN_PROMPT = PromptTemplate(input_variables = ["question"], template = os.environ["LLM_TEMPLATE"]) |
|
RAG_CHAIN_PROMPT = PromptTemplate(input_variables = ["context", "question"], template = os.environ["RAG_TEMPLATE"]) |
|
|
|
client = MongoClient(MONGODB_ATLAS_CLUSTER_URI) |
|
collection = client[MONGODB_DB_NAME][MONGODB_COLLECTION_NAME] |
|
|
|
def document_loading(): |
|
docs = [] |
|
|
|
|
|
loader = PyPDFLoader(PDF_URL) |
|
docs.extend(loader.load()) |
|
|
|
|
|
loader = WebBaseLoader(WEB_URL) |
|
docs.extend(loader.load()) |
|
|
|
|
|
loader = GenericLoader(YoutubeAudioLoader([YOUTUBE_URL_1, YOUTUBE_URL_2], YOUTUBE_DIR), |
|
OpenAIWhisperParser()) |
|
docs.extend(loader.load()) |
|
|
|
return docs |
|
|
|
def document_splitting(config, docs): |
|
text_splitter = RecursiveCharacterTextSplitter(chunk_overlap = config["chunk_overlap"], |
|
chunk_size = config["chunk_size"]) |
|
|
|
return text_splitter.split_documents(docs) |
|
|
|
def document_storage_chroma(chunks): |
|
Chroma.from_documents(documents = chunks, |
|
embedding = OpenAIEmbeddings(disallowed_special = ()), |
|
persist_directory = CHROMA_DIR) |
|
|
|
def document_storage_mongodb(chunks): |
|
MongoDBAtlasVectorSearch.from_documents(documents = chunks, |
|
embedding = OpenAIEmbeddings(disallowed_special = ()), |
|
collection = collection, |
|
index_name = MONGODB_INDEX_NAME) |
|
|
|
def document_retrieval_chroma(): |
|
return Chroma(embedding_function = OpenAIEmbeddings(disallowed_special = ()), |
|
persist_directory = CHROMA_DIR) |
|
|
|
def document_retrieval_mongodb(): |
|
return MongoDBAtlasVectorSearch.from_connection_string(MONGODB_ATLAS_CLUSTER_URI, |
|
MONGODB_DB_NAME + "." + MONGODB_COLLECTION_NAME, |
|
OpenAIEmbeddings(disallowed_special = ()), |
|
index_name = MONGODB_INDEX_NAME) |
|
|
|
def rag_batch(config): |
|
docs = document_loading() |
|
|
|
chunks = document_splitting(config, docs) |
|
|
|
document_storage_chroma(chunks) |
|
document_storage_mongodb(chunks) |
|
|
|
def get_llm(config, openai_api_key): |
|
return ChatOpenAI(model_name = config["model_name"], |
|
openai_api_key = openai_api_key, |
|
temperature = config["temperature"]) |
|
|
|
def llm_chain(config, openai_api_key, prompt): |
|
llm_chain = LLMChain(llm = get_llm(config, openai_api_key), |
|
prompt = LLM_CHAIN_PROMPT, |
|
verbose = False) |
|
|
|
completion = llm_chain.generate([{"question": prompt}]) |
|
|
|
return completion, llm_chain |
|
|
|
def rag_chain(config, openai_api_key, rag_option, prompt): |
|
llm = get_llm(config, openai_api_key) |
|
|
|
if (rag_option == RAG_CHROMA): |
|
db = document_retrieval_chroma() |
|
elif (rag_option == RAG_MONGODB): |
|
db = document_retrieval_mongodb() |
|
|
|
rag_chain = RetrievalQA.from_chain_type(llm, |
|
chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT}, |
|
retriever = db.as_retriever(search_kwargs = {"k": config["k"]}), |
|
return_source_documents = True, |
|
verbose = False) |
|
|
|
completion = rag_chain({"query": prompt}, include_run_info = True) |
|
print("###" + str(completion)) |
|
|
|
return completion, rag_chain |