"""Gradio front end for an LLM / RAG demo, traced with Weights & Biases.

The UI collects an OpenAI API key, a retrieval-augmented-generation (RAG)
option and a prompt, runs the matching chain from the local ``rag`` module
and shows the completion text. Every invocation that reaches the chain,
successful or not, is reported through ``wandb_trace``.
"""

import os
import time

import gradio as gr
from dotenv import load_dotenv, find_dotenv

from rag import get_llm, llm_chain, rag_chain
# NOTE(review): this local module shadows the standard-library `trace`
# module; it only resolves here because the script directory comes first
# on sys.path. Consider renaming it.
from trace import wandb_trace

_ = load_dotenv(find_dotenv())

RAG_OFF = "Off"
RAG_CHROMA = "Chroma"
RAG_MONGODB = "MongoDB"


def _retrieve_documents(rag_option, openai_api_key, prompt):
    """Return the vector-store handle for a RAG option (Chroma or MongoDB).

    The retrieval helpers were referenced as bare names without ever being
    imported, so every RAG request failed with NameError. They are imported
    lazily from ``rag`` (NOTE(review): assumed location, next to
    ``rag_chain`` — confirm). If they are missing there, only the RAG request
    fails (surfaced as gr.Error by the caller), not app startup.
    """
    # Document loading/splitting/storage is expected to have been done
    # beforehand; this path only retrieves.
    if rag_option == RAG_CHROMA:
        from rag import document_retrieval_chroma
        return document_retrieval_chroma(openai_api_key, prompt)
    from rag import document_retrieval_mongodb
    return document_retrieval_mongodb(openai_api_key, prompt)


def invoke(openai_api_key, rag_option, prompt):
    """Run ``prompt`` through the selected chain and return the completion text.

    Args:
        openai_api_key: OpenAI API key entered in the UI; must be non-empty.
        rag_option: One of RAG_OFF, RAG_CHROMA or RAG_MONGODB.
        prompt: The user prompt; must be non-empty.

    Returns:
        The completion text. It is an empty string when the plain LLM
        returns no generation.

    Raises:
        gr.Error: If a required input is missing, or if the chain fails.
            A chain failure is still traced via ``wandb_trace`` before the
            error propagates.
    """
    if not openai_api_key:
        raise gr.Error("OpenAI API Key is required.")
    if rag_option is None:
        raise gr.Error("Retrieval Augmented Generation is required.")
    if not prompt:
        raise gr.Error("Prompt is required.")

    # Defaults passed to wandb_trace if the chain fails before producing them.
    chain = None
    completion = ""
    result = ""
    generation_info = ""
    llm_output = ""
    err_msg = ""

    start_time_ms = round(time.time() * 1000)
    try:
        if rag_option in (RAG_CHROMA, RAG_MONGODB):
            db = _retrieve_documents(rag_option, openai_api_key, prompt)
            completion, chain = rag_chain(openai_api_key, prompt, db)
            result = completion["result"]
        else:
            completion, chain = llm_chain(openai_api_key, prompt)
            generations = completion.generations
            # Guard against an empty/None first generation list instead of
            # letting it raise IndexError.
            if generations and generations[0] and generations[0][0] is not None:
                first = generations[0][0]
                result = first.text
                generation_info = first.generation_info
                llm_output = completion.llm_output
    except Exception as e:
        err_msg = e
        # gr.Error expects a message string; keep the original as the cause.
        raise gr.Error(str(e)) from e
    finally:
        # Runs on success and on failure, so every attempt is traced.
        end_time_ms = round(time.time() * 1000)
        wandb_trace(rag_option, prompt, completion, result, generation_info,
                    llm_output, chain, err_msg, start_time_ms, end_time_ms)

    return result


gr.close_all()

demo = gr.Interface(
    fn=invoke,
    inputs=[
        gr.Textbox(label="OpenAI API Key", type="password", lines=1),
        gr.Radio([RAG_OFF, RAG_CHROMA, RAG_MONGODB],
                 label="Retrieval Augmented Generation", value=RAG_OFF),
        gr.Textbox(label="Prompt",
                   value="What are GPT-4's media capabilities in 5 emojis and 1 sentence?",
                   lines=1),
    ],
    outputs=[gr.Textbox(label="Completion", lines=1)],
    title="Generative AI - LLM & RAG",
    # Fall back to an empty description instead of crashing at import when
    # DESCRIPTION is not set in the environment or .env file.
    description=os.environ.get("DESCRIPTION", ""),
)
demo.launch()