bstraehle committed on
Commit
5c38fee
·
verified ·
1 Parent(s): 90bb0ba

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -0
app.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
import logging, os, sys, threading, time

from dotenv import load_dotenv, find_dotenv

from rag_langchain import LangChainRAG
from rag_llamaindex import LlamaIndexRAG
# NOTE(review): the local module name "trace" shadows the stdlib `trace`
# module — consider renaming it (e.g. tracing.py) to avoid import surprises.
from trace import trace_wandb

# Serializes requests: invoke() temporarily places the caller's API key in
# os.environ, so concurrent Gradio requests must not interleave.
lock = threading.Lock()

# Load environment variables (e.g. DESCRIPTION) from a local .env, if present.
_ = load_dotenv(find_dotenv())

RAG_INGESTION = False # load, split, embed, and store documents

# Labels for the UI radio button selecting the RAG implementation.
RAG_OFF = "Off"
RAG_LANGCHAIN = "LangChain"
RAG_LLAMAINDEX = "LlamaIndex"

# Shared settings consumed by both RAG implementations.
config = {
    "chunk_overlap": 100, # split documents
    "chunk_size": 2000, # split documents
    "k": 2, # retrieve documents
    "model_name": "gpt-4-0314", # llm
    "temperature": 0 # llm
}

# basicConfig already attaches a stdout StreamHandler to the root logger;
# the original code added a second stdout handler on top, which made every
# log record appear twice. One call is sufficient.
logging.basicConfig(stream = sys.stdout, level = logging.INFO)
31
def invoke(openai_api_key, prompt, rag_option):
    """Handle one Gradio request.

    Validates the inputs, optionally runs document ingestion, executes the
    selected pipeline (LangChain RAG, LlamaIndex RAG, or a plain LLM chain),
    traces the call to Weights & Biases, and returns the completion text.

    Args:
        openai_api_key: the caller's OpenAI API key (required).
        prompt: the user prompt (required).
        rag_option: one of RAG_OFF, RAG_LANGCHAIN, RAG_LLAMAINDEX (required).

    Returns:
        The completion text produced by the selected pipeline.

    Raises:
        gr.Error: on missing inputs or any pipeline failure.
    """
    if not openai_api_key:
        raise gr.Error("OpenAI API Key is required.")
    if not prompt:
        raise gr.Error("Prompt is required.")
    if not rag_option:
        raise gr.Error("Retrieval-Augmented Generation is required.")

    with lock:
        # The key lives in the process environment only for the duration of
        # this call; the lock keeps concurrent requests from clobbering it.
        os.environ["OPENAI_API_KEY"] = openai_api_key

        try:
            if RAG_INGESTION:
                # One-off mode: load, split, embed, and store documents.
                if rag_option == RAG_LANGCHAIN:
                    LangChainRAG().ingestion(config)
                elif rag_option == RAG_LLAMAINDEX:
                    LlamaIndexRAG().ingestion(config)

            completion = ""
            result = ""
            callback = ""
            err_msg = ""

            try:
                start_time_ms = round(time.time() * 1000)

                if rag_option == RAG_LANGCHAIN:
                    rag = LangChainRAG()
                    completion, callback = rag.rag_chain(config, prompt)
                    result = completion["result"]
                elif rag_option == RAG_LLAMAINDEX:
                    rag = LlamaIndexRAG()
                    result, callback = rag.retrieval(config, prompt)
                else:
                    rag = LangChainRAG()
                    completion, callback = rag.llm_chain(config, prompt)
                    result = completion.generations[0][0].text
            except Exception as e:
                # Bug fix: record the message text (the original stored the
                # exception object itself) before surfacing the error.
                err_msg = str(e)
                raise gr.Error(str(e))
            finally:
                end_time_ms = round(time.time() * 1000)

                trace_wandb(
                    config,
                    rag_option,
                    prompt,
                    completion,
                    result,
                    callback,
                    err_msg,
                    start_time_ms,
                    end_time_ms
                )
        finally:
            # Bug fix: the original deleted the key only on the success path,
            # so any raised error left the user's credential sitting in
            # os.environ. Always scrub it.
            del os.environ["OPENAI_API_KEY"]

    return result
91
+
92
# Tear down any Gradio apps left over from a previous run.
gr.close_all()

# UI widgets: API-key field, prompt box, and a radio button that selects
# which RAG implementation (if any) handles the request.
api_key_box = gr.Textbox(label = "OpenAI API Key", type = "password", lines = 1)
prompt_box = gr.Textbox(label = "Prompt", value = "List GPT-4's exam scores and benchmark results.", lines = 1)
rag_radio = gr.Radio([RAG_OFF, RAG_LANGCHAIN, RAG_LLAMAINDEX], label = "Retrieval-Augmented Generation", value = RAG_LANGCHAIN)

# Canned example rows shown beneath the interface (not cached — each run
# needs the caller's own API key).
example_rows = [
    ["sk-<BringYourOwn>", "What are GPT-4's media capabilities in 5 emojis and 1 sentence?", RAG_LLAMAINDEX],
    ["sk-<BringYourOwn>", "List GPT-4's exam scores and benchmark results.", RAG_LANGCHAIN],
    ["sk-<BringYourOwn>", "Compare GPT-4 to GPT-3.5 in markdown table format.", RAG_LLAMAINDEX],
    ["sk-<BringYourOwn>", "Write a Python program that calls the GPT-4 API.", RAG_LANGCHAIN],
    ["sk-<BringYourOwn>", "What is the GPT-4 API's cost and rate limit? Answer in English, Arabic, Chinese, Hindi, and Russian in JSON format.", RAG_LLAMAINDEX],
]

demo = gr.Interface(
    fn = invoke,
    inputs = [api_key_box, prompt_box, rag_radio],
    outputs = [gr.Textbox(label = "Completion")],
    title = "Context-Aware Reasoning Application",
    description = os.environ["DESCRIPTION"],
    examples = example_rows,
    cache_examples = False
)

demo.launch()