File size: 2,043 Bytes
5c38fee
daa80cf
5c38fee
4b2f569
5ef932e
5c38fee
 
82fd045
5c38fee
c26729e
 
9e5685a
5c38fee
 
 
5ef932e
5c38fee
 
 
 
 
 
 
 
 
5ef932e
838b33c
 
82fd045
0835210
f8ac3f0
 
3cc300e
 
 
 
 
 
0835210
3cc300e
 
120d45f
 
 
82fd045
5c38fee
 
 
aabb4c2
838b33c
5c38fee
 
 
838b33c
9e5685a
4af4bf5
5c38fee
a086fab
5c38fee
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import gradio as gr
import logging, os, sys, threading

from custom_utils import connect_to_database, rag_ingestion, rag_retrieval, rag_inference

# Serializes the database / retrieval / inference work done in invoke()
# (it is acquired there with `with lock:`).
lock = threading.Lock()

# When True, invoke() returns rag_ingestion(collection) instead of answering
# the prompt.
RAG_INGESTION = False

# Choices offered by the "Retrieval-Augmented Generation" radio button.
RAG_OFF      = "Off"
RAG_NAIVE    = "Naive RAG"
RAG_ADVANCED = "Advanced RAG"

# basicConfig attaches a StreamHandler(sys.stdout) to the root logger itself.
# Do not add another stdout StreamHandler here, or every record is printed twice.
logging.basicConfig(stream = sys.stdout, level = logging.INFO)
    
def invoke(openai_api_key, prompt, rag_option):
    """Answer ``prompt`` via ``rag_inference``, optionally grounded by retrieval.

    Gradio callback for the interface built at the bottom of this file. Inputs
    are validated first; all remaining work runs while holding the module-level
    ``lock``, so concurrent requests are handled one at a time.

    Args:
        openai_api_key: OpenAI API key entered by the user.
        prompt: The user's question.
        rag_option: One of ``RAG_OFF``, ``RAG_NAIVE`` or ``RAG_ADVANCED``.

    Returns:
        ``rag_ingestion(collection)`` when ``RAG_INGESTION`` is set; otherwise
        the result of ``rag_inference`` (shown in the Markdown output).

    Raises:
        gr.Error: If a required input is missing or ``rag_option`` is not one
            of the supported options.
    """
    if not openai_api_key:
        raise gr.Error("OpenAI API Key is required.")
    if not prompt:
        raise gr.Error("Prompt is required.")
    if not rag_option:
        raise gr.Error("Retrieval-Augmented Generation is required.")
    if rag_option not in (RAG_OFF, RAG_NAIVE, RAG_ADVANCED):
        raise gr.Error(f"Unsupported Retrieval-Augmented Generation option: {rag_option}")

    with lock:
        db, collection = connect_to_database()

        # Maintenance mode: populate the collection instead of answering.
        if RAG_INGESTION:
            return rag_ingestion(collection)

        if rag_option == RAG_OFF:
            # The selected option used to be ignored, so "Off" still ran
            # retrieval and injected context. Skip retrieval entirely here.
            # NOTE(review): assumes rag_inference tolerates an empty context
            # string — confirm against custom_utils.
            return rag_inference(openai_api_key, prompt, "")

        # Extra stages handed to rag_retrieval (presumably appended to its
        # search pipeline), e.g. a pre-retrieval index filter or post-retrieval
        # result filter such as:
        #     {"$match": {"accommodates": {"$eq": 2}, "bedrooms": {"$eq": 1}}}
        # TODO(review): RAG_NAIVE and RAG_ADVANCED currently behave the same;
        # Advanced RAG presumably should supply such stages.
        additional_stages = []

        search_results = rag_retrieval(openai_api_key, prompt, db, collection, additional_stages)
        return rag_inference(openai_api_key, prompt, search_results)

# Close any Gradio apps already running in this process before starting a new one.
gr.close_all()

# Default question pre-filled in the Prompt textbox.
PROMPT = "Recommend a place that's modern, spacious, and within walking distance from restaurants."

demo = gr.Interface(
    fn = invoke, 
    inputs = [gr.Textbox(label = "OpenAI API Key", type = "password", lines = 1), 
              gr.Textbox(label = "Prompt", value = PROMPT, lines = 1),
              gr.Radio([RAG_OFF, RAG_NAIVE, RAG_ADVANCED], label = "Retrieval-Augmented Generation", value = RAG_ADVANCED)],
    outputs = [gr.Markdown(label = "Completion")],
    title = "Context-Aware Reasoning Application",
    # Display-only text: a missing DESCRIPTION env var should not abort startup
    # with a KeyError, so fall back to an empty description.
    description = os.environ.get("DESCRIPTION", "")
)

demo.launch()