import os

from langchain.chains import LLMChain, RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyPDFLoader, WebBaseLoader
from langchain.document_loaders.blob_loaders.youtube_audio import YoutubeAudioLoader
from langchain.document_loaders.generic import GenericLoader
from langchain.document_loaders.parsers import OpenAIWhisperParser
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.vectorstores import MongoDBAtlasVectorSearch

from pymongo import MongoClient

PDF_URL       = "https://arxiv.org/pdf/2303.08774.pdf"
WEB_URL       = "https://openai.com/research/gpt-4"
YOUTUBE_URL_1 = "https://www.youtube.com/watch?v=--khbXchTeE"
YOUTUBE_URL_2 = "https://www.youtube.com/watch?v=hdhZwyf24mE"
YOUTUBE_URL_3 = "https://www.youtube.com/watch?v=vw-KWfKwvTQ"

YOUTUBE_DIR = "/data/youtube"
CHROMA_DIR  = "/data/chroma"

MONGODB_ATLAS_CLUSTER_URI = os.environ["MONGODB_ATLAS_CLUSTER_URI"]
MONGODB_DB_NAME           = "langchain_db"
MONGODB_COLLECTION_NAME   = "gpt-4"
MONGODB_INDEX_NAME        = "default"

# Prompt templates are supplied through the LLM_TEMPLATE and RAG_TEMPLATE environment variables
LLM_CHAIN_PROMPT = PromptTemplate(input_variables = ["question"],
                                  template = os.environ["LLM_TEMPLATE"])
RAG_CHAIN_PROMPT = PromptTemplate(input_variables = ["context", "question"],
                                  template = os.environ["RAG_TEMPLATE"])

client = MongoClient(MONGODB_ATLAS_CLUSTER_URI)
collection = client[MONGODB_DB_NAME][MONGODB_COLLECTION_NAME]

# Shared settings: splitter chunking, retriever top-k, and the chat model
config = {
    "chunk_overlap": 150,
    "chunk_size": 1500,
    "k": 3,
    "model_name": "gpt-4-0613",
    "temperature": 0,
}

def document_loading_splitting():
    # Document loading
    docs = []
    
    # Load PDF
    loader = PyPDFLoader(PDF_URL)
    docs.extend(loader.load())
    
    # Load Web
    loader = WebBaseLoader(WEB_URL)
    docs.extend(loader.load())
    
    # Load YouTube
    loader = GenericLoader(YoutubeAudioLoader([YOUTUBE_URL_1,
                                               YOUTUBE_URL_2,
                                               YOUTUBE_URL_3], YOUTUBE_DIR), 
                           OpenAIWhisperParser())
    docs.extend(loader.load())

    # Document splitting
    text_splitter = RecursiveCharacterTextSplitter(chunk_overlap = config["chunk_overlap"],
                                                   chunk_size = config["chunk_size"])
    split_documents = text_splitter.split_documents(docs)
    
    return split_documents

def document_storage_chroma(documents):
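    # Embed the document chunks and persist them to the local Chroma vector store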
    Chroma.from_documents(documents = documents, 
                          embedding = OpenAIEmbeddings(disallowed_special = ()), 
                          persist_directory = CHROMA_DIR)

def document_storage_mongodb(documents):
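    # Embed the document chunks and store them in the MongoDB Atlas collection for vector search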
    MongoDBAtlasVectorSearch.from_documents(documents = documents,
                                            embedding = OpenAIEmbeddings(disallowed_special = ()),
                                            collection = collection,
                                            index_name = MONGODB_INDEX_NAME)

def document_retrieval_chroma(llm, prompt):
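    # Return a handle to the persisted Chroma vector store (llm and prompt are not used here)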
    return Chroma(embedding_function = OpenAIEmbeddings(),
                  persist_directory = CHROMA_DIR)

def document_retrieval_mongodb(llm, prompt):
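    # Connect to the MongoDB Atlas vector search index (llm and prompt are not used here)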
    return MongoDBAtlasVectorSearch.from_connection_string(MONGODB_ATLAS_CLUSTER_URI,
                                                           MONGODB_DB_NAME + "." + MONGODB_COLLECTION_NAME,
                                                           OpenAIEmbeddings(disallowed_special = ()),
                                                           index_name = MONGODB_INDEX_NAME)

def get_llm(openai_api_key):
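    # Chat model configured with the model name and temperature from config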
    return ChatOpenAI(model_name = config["model_name"], 
                      openai_api_key = openai_api_key, 
                      temperature = config["temperature"])

def llm_chain(openai_api_key, prompt):
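    # Plain LLM chain (no retrieval); the question is formatted with LLM_CHAIN_PROMPT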
    llm_chain = LLMChain(llm = get_llm(openai_api_key), 
                         prompt = LLM_CHAIN_PROMPT, 
                         verbose = False)
    completion = llm_chain.generate([{"question": prompt}])
    return completion, llm_chain

def rag_chain(openai_api_key, prompt, db):
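    # RetrievalQA chain: retrieve the top-k chunks from db and answer with RAG_CHAIN_PROMPT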
    rag_chain = RetrievalQA.from_chain_type(get_llm(openai_api_key), 
                                            chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT}, 
                                            retriever = db.as_retriever(search_kwargs = {"k": config["k"]}), 
                                            return_source_documents = True,
                                            verbose = False)
    completion = rag_chain({"query": prompt})
    return completion, rag_chain
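
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): one possible way to
# wire these helpers together for a one-off ingestion run followed by a single
# RAG query against the Chroma backend. The OPENAI_API_KEY environment
# variable and the example question are assumptions for illustration only.
if __name__ == "__main__":
    openai_api_key = os.environ["OPENAI_API_KEY"]

    # Ingest once: load the PDF, web page, and YouTube transcripts, split them,
    # and persist the chunks to the local Chroma store.
    splits = document_loading_splitting()
    document_storage_chroma(splits)

    # Query: reopen the persisted store and run the retrieval-augmented chain.
    db = document_retrieval_chroma(llm = None, prompt = None)
    completion, _ = rag_chain(openai_api_key, "What is GPT-4?", db)
    print(completion["result"])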