# app.py — Gradio front end for a context-aware (RAG) reasoning demo.
import gradio as gr
import os, time
from dotenv import load_dotenv, find_dotenv
from rag_langchain import llm_chain, rag_chain, rag_ingestion_langchain
from rag_llamaindex import rag_ingestion_llamaindex, rag_retrieval
from trace import trace_wandb
# Pull environment variables (OPENAI_API_KEY is set per-request, DESCRIPTION
# is read by the UI below) from the nearest .env file.
_ = load_dotenv(find_dotenv())

# When True, the selected backend re-ingests documents (load, split, embed,
# and store) before answering. Normally left off so the app reuses the
# existing vector store.
RAG_INGESTION = False

# Settings shared by both RAG backends and the plain LLM path.
config = {
    "k": 3,                      # number of documents to retrieve
    "model_name": "gpt-4-0314",  # llm
    "temperature": 0,            # llm
}

# Labels for the RAG-mode radio selector in the UI.
RAG_OFF = "Off"
RAG_LANGCHAIN = "LangChain"
RAG_LLAMAINDEX = "LlamaIndex"
def invoke(openai_api_key, prompt, rag_option):
    """Answer *prompt* via the selected pipeline and return the completion text.

    Args:
        openai_api_key: OpenAI key, exported to the environment for the
            downstream chains; must be non-empty.
        prompt: User question; must be non-empty.
        rag_option: One of RAG_OFF, RAG_LANGCHAIN, RAG_LLAMAINDEX.

    Returns:
        The model's answer as a string ("" if the LLM path yields nothing).

    Raises:
        gr.Error: On missing inputs or any failure inside the chains.
    """
    if not openai_api_key:
        raise gr.Error("OpenAI API Key is required.")
    if not prompt:
        raise gr.Error("Prompt is required.")
    if rag_option is None:
        raise gr.Error("Retrieval Augmented Generation is required.")
    os.environ["OPENAI_API_KEY"] = openai_api_key
    if RAG_INGESTION:
        if rag_option == RAG_LANGCHAIN:
            # Fixed typo: was rag_ingestion_llangchain (NameError) — the
            # imported name is rag_ingestion_langchain.
            rag_ingestion_langchain(config)
        elif rag_option == RAG_LLAMAINDEX:
            rag_ingestion_llamaindex(config)
    chain = None
    completion = ""
    result = ""
    callback = ""
    err_msg = ""
    try:
        start_time_ms = round(time.time() * 1000)
        if rag_option == RAG_LANGCHAIN:
            completion, chain, callback = rag_chain(config, prompt)
            result = completion["result"]
        elif rag_option == RAG_LLAMAINDEX:
            result = rag_retrieval(config, prompt)
        else:
            completion, chain, callback = llm_chain(config, prompt)
            # `is not None` instead of `!= None`; guard against an empty
            # generation before reading .text.
            if (completion.generations[0] is not None
                    and completion.generations[0][0] is not None):
                result = completion.generations[0][0].text
    except Exception as e:
        err_msg = e
        # Surface the failure in the UI; gr.Error expects a message string.
        raise gr.Error(str(e)) from e
    finally:
        end_time_ms = round(time.time() * 1000)
        # W&B tracing currently disabled; re-enable by uncommenting.
        #trace_wandb(config,
        #            rag_option,
        #            prompt,
        #            completion,
        #            result,
        #            callback,
        #            err_msg,
        #            start_time_ms,
        #            end_time_ms)
    return result
# Shut down any Gradio servers left over from a previous run, then build
# and launch the UI. (Removed the stray trailing "|" after demo.launch(),
# which was a syntax error introduced by the export.)
gr.close_all()

demo = gr.Interface(
    fn = invoke,
    inputs = [gr.Textbox(label = "OpenAI API Key", type = "password", lines = 1),
              gr.Textbox(label = "Prompt", value = "What are GPT-4's media capabilities in 5 emojis and 1 sentence?", lines = 1),
              gr.Radio([RAG_OFF, RAG_LANGCHAIN, RAG_LLAMAINDEX], label = "Retrieval-Augmented Generation", value = RAG_LANGCHAIN)],
    outputs = [gr.Textbox(label = "Completion", lines = 1)],
    title = "Context-Aware Reasoning Application",
    # NOTE(review): raises KeyError if DESCRIPTION is missing from the
    # environment/.env — presumably intentional fail-fast; confirm.
    description = os.environ["DESCRIPTION"],
    examples = [["", "What are GPT-4's media capabilities in 5 emojis and 1 sentence?", RAG_LANGCHAIN],
                ["", "List GPT-4's exam scores and benchmark results.", RAG_LANGCHAIN],
                ["", "Compare GPT-4 to GPT-3.5 in markdown table format.", RAG_LANGCHAIN],
                ["", "Write a Python program that calls the GPT-4 API.", RAG_LANGCHAIN],
                ["", "What is the GPT-4 API's cost and rate limit? Answer in English, Arabic, Chinese, Hindi, and Russian in JSON format.", RAG_LANGCHAIN]],
    cache_examples = False)

demo.launch()