import gradio as gr
import os, time, wandb

from langchain.chains import LLMChain, RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyPDFLoader, WebBaseLoader
from langchain.document_loaders.blob_loaders.youtube_audio import YoutubeAudioLoader
from langchain.document_loaders.generic import GenericLoader
from langchain.document_loaders.parsers import OpenAIWhisperParser

from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.vectorstores import MongoDBAtlasVectorSearch

from pymongo import MongoClient

from wandb.sdk.data_types.trace_tree import Trace

from dotenv import load_dotenv, find_dotenv
_ = load_dotenv(find_dotenv())

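# Fail fast if required secrets are missing; wandb reads WANDB_API_KEY from the environment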
WANDB_API_KEY = os.environ["WANDB_API_KEY"]

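# MongoDB Atlas connection plus the database, collection, and search index used for vector search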
MONGODB_URI = os.environ["MONGODB_ATLAS_CLUSTER_URI"]
client = MongoClient(MONGODB_URI)
MONGODB_DB_NAME = "langchain_db"
MONGODB_COLLECTION_NAME = "gpt-4"
MONGODB_COLLECTION = client[MONGODB_DB_NAME][MONGODB_COLLECTION_NAME]
MONGODB_INDEX_NAME = "default"

description = os.environ["DESCRIPTION"]

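# Shared settings for document splitting (chunk_size/chunk_overlap), retrieval depth (k), and the chat model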
config = {
    "chunk_overlap": 150,
    "chunk_size": 1500,
    "k": 3,
    "model": "gpt-4",
    "temperature": 0,
}

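# Prompt templates: the plain LLM chain sees only the question; the RAG chain also receives retrieved context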
template = """If you don't know the answer, just say that you don't know; don't try to make up an answer. Keep the answer as concise as possible. Always say "Thanks for using the 🧠 app - Bernd" at the end of the answer. """

llm_template = "Answer the question at the end. " + template + "Question: {question} Helpful Answer: "
rag_template = "Use the following pieces of context to answer the question at the end. " + template + "{context}. Question: {question} Helpful Answer: "

LLM_CHAIN_PROMPT = PromptTemplate(input_variables = ["question"], template = llm_template)
RAG_CHAIN_PROMPT = PromptTemplate(input_variables = ["context", "question"], template = rag_template)

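# Local persistence directories for the Chroma index and downloaded YouTube audio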
CHROMA_DIR  = "/data/chroma"
YOUTUBE_DIR = "/data/youtube"

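# Sources to index: the GPT-4 technical report, OpenAI's GPT-4 research page, and three YouTube videos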
PDF_URL       = "https://arxiv.org/pdf/2303.08774.pdf"
WEB_URL       = "https://openai.com/research/gpt-4"
YOUTUBE_URL_1 = "https://www.youtube.com/watch?v=--khbXchTeE"
YOUTUBE_URL_2 = "https://www.youtube.com/watch?v=hdhZwyf24mE"
YOUTUBE_URL_3 = "https://www.youtube.com/watch?v=vw-KWfKwvTQ"

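# One-time ingestion: load the PDF, web page, and Whisper-transcribed YouTube audio, then split into overlapping chunks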
def document_loading_splitting():
    # Document loading
    docs = []
    # Load PDF
    loader = PyPDFLoader(PDF_URL)
    docs.extend(loader.load())
    # Load Web
    loader = WebBaseLoader(WEB_URL)
    docs.extend(loader.load())
    # Load YouTube
    loader = GenericLoader(YoutubeAudioLoader([YOUTUBE_URL_1,
                                               YOUTUBE_URL_2,
                                               YOUTUBE_URL_3], YOUTUBE_DIR), 
                           OpenAIWhisperParser())
    docs.extend(loader.load())
    # Document splitting
    text_splitter = RecursiveCharacterTextSplitter(chunk_overlap = config["chunk_overlap"],
                                                   chunk_size = config["chunk_size"])
    splits = text_splitter.split_documents(docs)
    return splits

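# Embed the chunks and persist them to a local Chroma store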
def document_storage_chroma(splits):
    Chroma.from_documents(documents = splits, 
                          embedding = OpenAIEmbeddings(disallowed_special = ()), 
                          persist_directory = CHROMA_DIR)

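# Embed the chunks and store them in the MongoDB Atlas collection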
def document_storage_mongodb(splits):
    MongoDBAtlasVectorSearch.from_documents(documents = splits,
                                            embedding = OpenAIEmbeddings(disallowed_special = ()),
                                            collection = MONGODB_COLLECTION,
                                            index_name = MONGODB_INDEX_NAME)

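# Open the persisted Chroma store for retrieval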
def document_retrieval_chroma():
    db = Chroma(embedding_function = OpenAIEmbeddings(disallowed_special = ()),
                persist_directory = CHROMA_DIR)
    return db

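# Connect to the MongoDB Atlas vector store for retrieval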
def document_retrieval_mongodb():
    db = MongoDBAtlasVectorSearch.from_connection_string(MONGODB_URI,
                                                         MONGODB_DB_NAME + "." + MONGODB_COLLECTION_NAME,
                                                         OpenAIEmbeddings(disallowed_special = ()),
                                                         index_name = MONGODB_INDEX_NAME)
    return db

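# Plain LLM chain: answer directly from the model without retrieval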
def llm_chain(llm, prompt):
    llm_chain = LLMChain(llm = llm, prompt = LLM_CHAIN_PROMPT, verbose = False)
    completion = llm_chain.run({"question": prompt})
    return completion, llm_chain

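# RAG chain: retrieve the top-k chunks and answer from that context, returning the source documents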
def rag_chain(llm, prompt, db):
    rag_chain = RetrievalQA.from_chain_type(llm, 
                                            chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT}, 
                                            retriever = db.as_retriever(search_kwargs = {"k": config["k"]}), 
                                            return_source_documents = True,
                                            verbose = False)
    completion = rag_chain({"query": prompt})
    return completion, rag_chain

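# Log a single W&B trace per request, including prompts, retrieved documents, and timings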
def wandb_trace(rag_option, prompt, completion, chain, status_msg, start_time_ms, end_time_ms):
    wandb.init(project = "openai-llm-rag")
    if (rag_option == "Off" or str(status_msg) != ""):
        result = completion
    else:
        result = completion["result"]
        document_0 = completion["source_documents"][0]
        document_1 = completion["source_documents"][1]
        document_2 = completion["source_documents"][2]
    trace = Trace(
        kind = "chain",
        name = "LLMChain" if (rag_option == "Off") else "RetrievalQA",
        status_code = "SUCCESS" if (str(status_msg) == "") else "ERROR",
        status_message = str(status_msg),
        metadata={
            "chunk_overlap": "" if (rag_option == "Off") else config["chunk_overlap"],
            "chunk_size": "" if (rag_option == "Off") else config["chunk_size"],
            "k": "" if (rag_option == "Off") else config["k"],
            "model": config["model"],
            "temperature": config["temperature"],
        },
        inputs = {"rag_option": rag_option if (str(status_msg) == "") else "",
                  "prompt": str(prompt if (str(status_msg) == "") else ""), 
                  "prompt_template": str((llm_template if (rag_option == "Off") else rag_template) if (str(status_msg) == "") else ""),
                  "document_0": "" if (rag_option == "Off" or str(status_msg) != "") else str(document_0),
                  "document_1": "" if (rag_option == "Off" or str(status_msg) != "") else str(document_1),
                  "document_2": "" if (rag_option == "Off" or str(status_msg) != "") else str(document_2)},
        outputs = {"result": result},
        start_time_ms = start_time_ms,
        end_time_ms = end_time_ms
    )
    trace.log("test")
    wandb.finish()

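# Gradio callback: validate inputs, run the selected chain, and always log a W&B trace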
def invoke(openai_api_key, rag_option, prompt):
    if (openai_api_key == ""):
        raise gr.Error("OpenAI API Key is required.")
    if (rag_option is None):
        raise gr.Error("Retrieval Augmented Generation is required.")
    if (prompt == ""):
        raise gr.Error("Prompt is required.")
    completion = ""
    result = ""
    chain = ""
    status_msg = ""
    try:
        start_time_ms = round(time.time() * 1000)
        llm = ChatOpenAI(model_name = config["model"], 
                         openai_api_key = openai_api_key, 
                         temperature = config["temperature"])
        if (rag_option == "Chroma"):
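            # One-time ingestion (uncomment to rebuild the vector store):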
            #splits = document_loading_splitting()
            #document_storage_chroma(splits)
            db = document_retrieval_chroma()
            completion, chain = rag_chain(llm, prompt, db)
            result = completion["result"]
        elif (rag_option == "MongoDB"):
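            # One-time ingestion (uncomment to rebuild the vector store):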
            #splits = document_loading_splitting()
            #document_storage_mongodb(splits)
            db = document_retrieval_mongodb()
            completion, chain = rag_chain(llm, prompt, db)
            result = completion["result"]
        else:
            result, chain = llm_chain(llm, prompt)
            completion = result
    except Exception as e:
        status_msg = e
        raise gr.Error(str(e))
    finally:
        end_time_ms = round(time.time() * 1000)
        wandb_trace(rag_option, prompt, completion, chain, status_msg, start_time_ms, end_time_ms)
    return result

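# Build and launch the Gradio UI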
gr.close_all()
demo = gr.Interface(fn=invoke, 
                    inputs = [gr.Textbox(label = "OpenAI API Key", value = "sk-", lines = 1), 
                              gr.Radio(["Off", "Chroma", "MongoDB"], label="Retrieval Augmented Generation", value = "Off"),
                              gr.Textbox(label = "Prompt", value = "What is GPT-4?", lines = 1)],
                    outputs = [gr.Textbox(label = "Completion", lines = 1)],
                    title = "Generative AI - LLM & RAG",
                    description = description)
demo.launch()