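"""Gradio front end for a context-aware reasoning (RAG) application.

A single invoke() entry point either sends the prompt straight to the model
or augments it with documents retrieved via naive or advanced RAG, using the
helpers provided by custom_utils.
"""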
import gradio as gr
import logging, os, sys, threading

from custom_utils import (
    connect_to_database,
    inference,
    rag_ingestion,
    rag_retrieval_naive,
    rag_retrieval_advanced,
    rag_inference
)
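
# Helpers defined in custom_utils (not shown here). Judging from how they are
# used below: connect_to_database returns a database handle and a collection,
# rag_ingestion loads documents into that collection, rag_retrieval_naive and
# rag_retrieval_advanced fetch context matching the prompt and filters, and
# inference / rag_inference call the model without and with that context.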

lock = threading.Lock()

# When True, invoke() only ingests documents into the collection instead of
# answering the prompt.
RAG_INGESTION = False

RAG_OFF      = "Off"
RAG_NAIVE    = "Naive RAG"
RAG_ADVANCED = "Advanced RAG"
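# Choices for the Radio component below; invoke() dispatches on the selected
# value to pick the retrieval strategy.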

# Log to stdout; basicConfig already installs a stdout handler, so adding a
# second one here would print every log line twice.
logging.basicConfig(stream = sys.stdout, level = logging.INFO)

def invoke(openai_api_key,
           prompt,
           accommodates,
           bedrooms,
           rag_option):
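    """Handle one request from the UI.

    Validates the inputs, then either calls the model directly (RAG off) or
    retrieves matching documents first (naive or advanced retrieval, filtered
    by accommodates and bedrooms) and passes them to rag_inference.
    """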
    if not openai_api_key:
        raise gr.Error("OpenAI API Key is required.")
    if not prompt:
        raise gr.Error("Prompt is required.")
    if not rag_option:
        raise gr.Error("Retrieval-Augmented Generation is required.")

    # Serialize requests so that only one database connection and inference
    # run is active at a time.
    with lock:
        db, collection = connect_to_database()

        inference_result = ""

        try:
            # One-off ingestion mode: only load the documents into the
            # collection and return, skipping inference entirely.
            if RAG_INGESTION:
                return rag_ingestion(collection)
            elif rag_option == RAG_OFF:
                inference_result = inference(
                    openai_api_key, 
                    prompt
                )
            elif rag_option == RAG_NAIVE:
                retrieval_result = rag_retrieval_naive(
                    openai_api_key, 
                    prompt,
                    accommodates,
                    bedrooms,
                    db, 
                    collection
                )
                inference_result = rag_inference(
                    openai_api_key, 
                    prompt, 
                    retrieval_result
                )        
            elif rag_option == RAG_ADVANCED:
                retrieval_result = rag_retrieval_advanced(
                    openai_api_key, 
                    prompt, 
                    accommodates,
                    bedrooms,
                    db, 
                    collection
                )
                inference_result = rag_inference(
                    openai_api_key, 
                    prompt, 
                    retrieval_result
                )
        except Exception as e:
            raise gr.Error(str(e))

        print("###")
        print(inference_result)
        print("###")
 
        return inference_result

gr.close_all()

demo = gr.Interface(
    fn = invoke, 
    inputs = [gr.Textbox(label = "OpenAI API Key", type = "password", lines = 1), 
              gr.Textbox(label = "Prompt", value = os.environ["PROMPT"], lines = 1),
              gr.Number(label = "Accomodates", value = 2),
              gr.Number(label = "Bedrooms", value = 1),
              gr.Radio([RAG_OFF, RAG_NAIVE, RAG_ADVANCED], label = "Retrieval-Augmented Generation", value = RAG_ADVANCED)],
    outputs = [gr.Markdown(label = "Completion", value = os.environ["COMPLETION"], line_breaks = True, sanitize_html = False)],
    title = "Context-Aware Reasoning Application",
    description = os.environ["DESCRIPTION"]
)

demo.launch()
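
# Running locally (a sketch; the PROMPT, COMPLETION and DESCRIPTION
# environment variables read above are assumed to be set, e.g. as Space
# secrets or shell exports):
#
#   export PROMPT="..." COMPLETION="..." DESCRIPTION="..."
#   python app.py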