File size: 3,881 Bytes
5c38fee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import gradio as gr
import logging, os, sys, threading, time

from dotenv import load_dotenv, find_dotenv

from rag_langchain import LangChainRAG
from rag_llamaindex import LlamaIndexRAG
from trace import trace_wandb

# Guards invoke(), which temporarily places the caller's API key in os.environ.
lock = threading.Lock()

# Load variables from the nearest .env file (if any) into the environment.
_ = load_dotenv(find_dotenv())

# When True, invoke() first runs the ingestion step
# (load, split, embed, and store documents) for the selected framework.
RAG_INGESTION = False

# Values offered by the "Retrieval-Augmented Generation" radio button.
RAG_OFF = "Off"
RAG_LANGCHAIN = "LangChain"
RAG_LLAMAINDEX = "LlamaIndex"

# Settings handed to the RAG helpers and to the tracing call.
config = dict(
    chunk_overlap=100,        # split documents
    chunk_size=2000,          # split documents
    k=2,                      # retrieve documents
    model_name="gpt-4-0314",  # llm
    temperature=0,            # llm
)

# Send INFO-level log records to stdout.
# NOTE(review): if basicConfig() installs its own stdout handler on the root
# logger, the extra StreamHandler below makes every record appear twice --
# confirm this duplication is intended.
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))

def invoke(openai_api_key, prompt, rag_option):
    """Run *prompt* through the LLM, optionally with retrieval, and trace the call.

    Args:
        openai_api_key: Key exported as ``OPENAI_API_KEY`` for the duration of
            the call and removed from the environment afterwards.
        prompt: The user prompt.
        rag_option: ``RAG_LANGCHAIN`` or ``RAG_LLAMAINDEX`` selects the
            matching retrieval pipeline; any other non-empty value
            (e.g. ``RAG_OFF``) calls the plain LangChain LLM chain.

    Returns:
        The completion text.

    Raises:
        gr.Error: If an argument is empty, or if the LLM / retrieval call
            fails (the original exception is chained as ``__cause__``).
    """
    if not openai_api_key:
        raise gr.Error("OpenAI API Key is required.")
    if not prompt:
        raise gr.Error("Prompt is required.")
    if not rag_option:
        raise gr.Error("Retrieval-Augmented Generation is required.")

    # The key is placed in the process-wide environment, so requests are
    # serialized under the lock while it is set.
    with lock:
        os.environ["OPENAI_API_KEY"] = openai_api_key

        try:
            if RAG_INGESTION:
                if rag_option == RAG_LANGCHAIN:
                    rag = LangChainRAG()
                    rag.ingestion(config)
                elif rag_option == RAG_LLAMAINDEX:
                    rag = LlamaIndexRAG()
                    rag.ingestion(config)

            completion = ""
            result = ""
            callback = ""
            err_msg = ""

            start_time_ms = round(time.time() * 1000)

            try:
                if rag_option == RAG_LANGCHAIN:
                    rag = LangChainRAG()
                    completion, callback = rag.rag_chain(config, prompt)
                    result = completion["result"]
                elif rag_option == RAG_LLAMAINDEX:
                    rag = LlamaIndexRAG()
                    result, callback = rag.retrieval(config, prompt)
                else:
                    rag = LangChainRAG()
                    completion, callback = rag.llm_chain(config, prompt)
                    result = completion.generations[0][0].text
            except Exception as e:
                # Keep the exception for the trace, then surface it in the UI.
                err_msg = e

                # gr.Error takes a message string, so convert explicitly.
                raise gr.Error(str(e)) from e
            finally:
                end_time_ms = round(time.time() * 1000)

                # Traced on success and on failure alike.
                trace_wandb(
                    config,
                    rag_option,
                    prompt,
                    completion,
                    result,
                    callback,
                    err_msg,
                    start_time_ms,
                    end_time_ms,
                )
        finally:
            # Always drop the caller's key, even if ingestion or tracing
            # raised; pop() tolerates the key already being absent.
            os.environ.pop("OPENAI_API_KEY", None)

        return result

# Close any Gradio interfaces that are already running before creating a new one.
gr.close_all()

# Web UI: API key, prompt and RAG mode in; completion text out.
demo = gr.Interface(
    fn = invoke,
    inputs = [
        gr.Textbox(label = "OpenAI API Key", type = "password", lines = 1),
        gr.Textbox(label = "Prompt", value = "List GPT-4's exam scores and benchmark results.", lines = 1),
        gr.Radio([RAG_OFF, RAG_LANGCHAIN, RAG_LLAMAINDEX], label = "Retrieval-Augmented Generation", value = RAG_LANGCHAIN),
    ],
    outputs = [gr.Textbox(label = "Completion")],
    title = "Context-Aware Reasoning Application",
    # The description is optional: use .get() so a missing DESCRIPTION
    # variable no longer aborts startup with a KeyError.
    description = os.environ.get("DESCRIPTION"),
    # Each example row is [api key placeholder, prompt, RAG mode].
    examples = [
        ["sk-<BringYourOwn>", "What are GPT-4's media capabilities in 5 emojis and 1 sentence?", RAG_LLAMAINDEX],
        ["sk-<BringYourOwn>", "List GPT-4's exam scores and benchmark results.", RAG_LANGCHAIN],
        ["sk-<BringYourOwn>", "Compare GPT-4 to GPT-3.5 in markdown table format.", RAG_LLAMAINDEX],
        ["sk-<BringYourOwn>", "Write a Python program that calls the GPT-4 API.", RAG_LANGCHAIN],
        ["sk-<BringYourOwn>", "What is the GPT-4 API's cost and rate limit? Answer in English, Arabic, Chinese, Hindi, and Russian in JSON format.", RAG_LLAMAINDEX],
    ],
    cache_examples = False,
)

demo.launch()