#TODO: Pydantic, image embedding, clean up data set

import gradio as gr
import logging, os, sys, threading

from custom_utils import connect_to_database, rag_ingestion, rag_retrieval, rag_inference
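
# The helpers imported above live in custom_utils (not shown here); from how they are
# called below, their contract is assumed to be roughly:
#   connect_to_database()                                                  -> (db, collection)
#   rag_ingestion(collection)                                              -> ingestion status/result
#   rag_retrieval(api_key, prompt, accommodates, bedrooms, db, collection) -> retrieved context
#   rag_inference(api_key, prompt, retrieval_result)                       -> completion text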

# Serialize access to the shared database connection across concurrent Gradio requests.
lock = threading.Lock()

# When True, invoke() ingests the data set into the collection instead of answering the prompt.
RAG_INGESTION = False

RAG_OFF      = "Off"
RAG_NAIVE    = "Naive RAG"
RAG_ADVANCED = "Advanced RAG"

# basicConfig(stream=sys.stdout) already installs a stdout handler on the root logger;
# adding a second StreamHandler would emit every record twice.
logging.basicConfig(stream = sys.stdout, level = logging.INFO)
    
def invoke(openai_api_key,
           prompt,
           accommodates,
           bedrooms,
           rag_option):
    """Validate the UI inputs, then run either data set ingestion or RAG retrieval and inference."""
    if not openai_api_key:
        raise gr.Error("OpenAI API Key is required.")
    if not prompt:
        raise gr.Error("Prompt is required.")
    if not rag_option:
        raise gr.Error("Retrieval-Augmented Generation is required.")

    with lock:
        db, collection = connect_to_database()

        if RAG_INGESTION:
            return rag_ingestion(collection)
        else:
            retrieval_result = rag_retrieval(openai_api_key,
                                             prompt,
                                             accommodates,
                                             bedrooms,
                                             db,
                                             collection)
            inference_result = rag_inference(openai_api_key, 
                                             prompt, 
                                             retrieval_result)
            print("###")
            print(inference_result)
            print("###")
            return inference_result

gr.close_all()

demo = gr.Interface(
    fn = invoke, 
    inputs = [gr.Textbox(label = "OpenAI API Key", type = "password", lines = 1), 
              gr.Textbox(label = "Prompt", value = os.environ["PROMPT"], lines = 1),
              gr.Number(label = "Accomodates", value = 2),
              gr.Number(label = "Bedrooms", value = 1),
              gr.Radio([RAG_OFF, RAG_NAIVE, RAG_ADVANCED], label = "Retrieval-Augmented Generation", value = RAG_NAIVE)],
    outputs = [gr.Markdown(label = "Completion", value = os.environ["COMPLETION"], line_breaks = True, sanitize_html = False)],
    title = "Context-Aware Reasoning Application",
    description = os.environ["DESCRIPTION"]
)

demo.launch()
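
# A minimal sketch of a local run (file name and values assumed, not taken from the repo):
# the PROMPT, COMPLETION and DESCRIPTION environment variables read above must be set, e.g.
#   export PROMPT="..." COMPLETION="..." DESCRIPTION="..."
#   python app.py
# Gradio then serves the interface on http://127.0.0.1:7860 by default; invoke() can also be
# smoke-tested directly, e.g. invoke("sk-...", os.environ["PROMPT"], 2, 1, RAG_NAIVE).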