rag / app.py
bstraehle's picture
Update app.py
8ded9c8
raw
history blame
No virus
3.54 kB
import gradio as gr
import logging, os, sys, time
from dotenv import load_dotenv, find_dotenv
from rag_langchain import LangChainRAG
from rag_llamaindex import LlamaIndexRAG
from trace import trace_wandb
# Populate os.environ from a .env file, if one can be found.
_ = load_dotenv(find_dotenv())

# When True, invoke() re-runs document ingestion (load, split, embed, and
# store documents) for the selected RAG framework before answering.
RAG_INGESTION = False

# Values offered by the "Retrieval-Augmented Generation" radio button.
RAG_OFF = "Off"
RAG_LANGCHAIN = "LangChain"
RAG_LLAMAINDEX = "LlamaIndex"

# Shared settings passed to both the LangChain and LlamaIndex pipelines.
config = {
    "chunk_overlap": 100,  # split documents
    "chunk_size": 2000,  # split documents
    "k": 3,  # retrieve documents
    # llm; NOTE(review): dated model snapshot — confirm OpenAI still serves it.
    "model_name": "gpt-4-0314",
    "temperature": 0,  # llm
}

# basicConfig already attaches a StreamHandler writing to stdout on the root
# logger. Adding a second stdout StreamHandler would print every record twice.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
def invoke(openai_api_key, prompt, rag_option):
    """Answer ``prompt`` with the selected pipeline and trace the call.

    Args:
        openai_api_key: OpenAI API key entered in the UI. It is exported to the
            ``OPENAI_API_KEY`` environment variable before any model call.
        prompt: User question sent to the LLM.
        rag_option: ``RAG_LANGCHAIN`` or ``RAG_LLAMAINDEX`` to answer with
            retrieval. Any other non-None value (e.g. ``RAG_OFF``) falls
            through to a plain LangChain LLM call without retrieval.

    Returns:
        The completion text shown in the "Completion" textbox.

    Raises:
        gr.Error: If an input is missing or blank, or if the LLM/RAG call
            fails.
    """
    # Reject None and whitespace-only input, not just the exact empty string.
    if not (openai_api_key or "").strip():
        raise gr.Error("OpenAI API Key is required.")
    if not (prompt or "").strip():
        raise gr.Error("Prompt is required.")
    if rag_option is None:
        raise gr.Error("Retrieval Augmented Generation is required.")

    # NOTE(review): this mutates process-wide state. With concurrent Gradio
    # sessions, one user's key can be picked up by another user's request.
    # Passing the key explicitly to the RAG classes would avoid this.
    os.environ["OPENAI_API_KEY"] = openai_api_key

    if RAG_INGESTION:
        if rag_option == RAG_LANGCHAIN:
            LangChainRAG().ingestion(config)
        elif rag_option == RAG_LLAMAINDEX:
            LlamaIndexRAG().ingestion(config)

    # Defaults reported to the tracer if the call fails before producing them.
    completion = ""
    result = ""
    callback = ""
    err_msg = ""

    start_time_ms = round(time.time() * 1000)
    try:
        if rag_option == RAG_LANGCHAIN:
            completion, callback = LangChainRAG().rag_chain(config, prompt)
            result = completion["result"]
        elif rag_option == RAG_LLAMAINDEX:
            result, callback = LlamaIndexRAG().retrieval(config, prompt)
        else:
            # Retrieval off: plain LLM call.
            completion, callback = LangChainRAG().llm_chain(config, prompt)
            result = completion.generations[0][0].text
    except Exception as e:
        err_msg = e
        # Show the message in the UI and keep the original exception as
        # __cause__ so the server-side traceback stays complete.
        raise gr.Error(str(e)) from e
    finally:
        end_time_ms = round(time.time() * 1000)
        try:
            trace_wandb(
                config,
                rag_option,
                prompt,
                completion,
                result,
                callback,
                err_msg,
                start_time_ms,
                end_time_ms,
            )
        except Exception:
            # Tracing is best-effort. A telemetry failure must neither replace
            # the completion nor mask the original error raised above.
            logging.exception("trace_wandb failed")

    return result
# Close any Gradio servers still running in this process (e.g. from a
# previous launch) before starting a new one.
gr.close_all()

# Pre-filled prompt in the textbox; it is also the first example below.
DEFAULT_PROMPT = "What are GPT-4's media capabilities in 5 emojis and 1 sentence?"

demo = gr.Interface(
    fn=invoke,
    inputs=[
        gr.Textbox(label="OpenAI API Key", type="password", lines=1),
        gr.Textbox(label="Prompt", value=DEFAULT_PROMPT, lines=1),
        gr.Radio(
            [RAG_OFF, RAG_LANGCHAIN, RAG_LLAMAINDEX],
            label="Retrieval-Augmented Generation",
            value=RAG_LANGCHAIN,
        ),
    ],
    outputs=[gr.Textbox(label="Completion", lines=1)],
    title="Context-Aware Reasoning Application",
    # DESCRIPTION is optional. Indexing os.environ directly crashed the app
    # at startup with KeyError when it was unset.
    description=os.environ.get("DESCRIPTION", ""),
    # The API key column is left blank on purpose: users must supply their own.
    examples=[
        ["", DEFAULT_PROMPT, RAG_LLAMAINDEX],
        ["", "List GPT-4's exam scores and benchmark results.", RAG_LANGCHAIN],
        ["", "Compare GPT-4 to GPT-3.5 in markdown table format.", RAG_LLAMAINDEX],
        ["", "Write a Python program that calls the GPT-4 API.", RAG_LANGCHAIN],
        ["", "What is the GPT-4 API's cost and rate limit? Answer in English, Arabic, Chinese, Hindi, and Russian in JSON format.", RAG_LLAMAINDEX],
    ],
    cache_examples=False,
)

demo.launch()