File size: 2,758 Bytes
60d2583
 
5c38fee
daa80cf
5c38fee
6a5cc80
 
 
 
 
 
 
 
5ef932e
5c38fee
 
12178fd
5c38fee
c26729e
 
9e5685a
5c38fee
 
 
5ef932e
e5984fd
 
 
 
 
5c38fee
 
 
 
 
 
 
 
5ef932e
838b33c
6a5cc80
 
838b33c
82fd045
84df241
6a5cc80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e7435e
 
 
6a5cc80
 
5c38fee
 
 
 
 
 
2d67e97
e5984fd
 
6a5cc80
2d67e97
5c38fee
a086fab
5c38fee
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#TODO: Pydantic, image embedding, clean up data set

import gradio as gr
import logging, os, sys, threading

from custom_utils import (
    connect_to_database,
    inference,
    rag_ingestion,
    rag_retrieval_naive,
    rag_retrieval_advanced,
    rag_inference
)

# Guards request handling: invoke() holds this lock for the whole request,
# so only one request is processed at a time.
lock = threading.Lock()

# When True, invoke() runs rag_ingestion() instead of answering the prompt.
RAG_INGESTION = False

# Labels of the selectable Retrieval-Augmented Generation modes.
RAG_OFF      = "Off"
RAG_NAIVE    = "Naive RAG"
RAG_ADVANCED = "Advanced RAG"

# Route INFO-and-above log records to stdout.
logging.basicConfig(stream = sys.stdout, level = logging.INFO)

# basicConfig() is a no-op when the root logger was already configured
# (e.g. by an imported library), so make sure a stdout handler exists.
# Only add one when none is present: adding it unconditionally on top of
# the handler installed by basicConfig() printed every record twice.
_root_logger = logging.getLogger()
if not any(isinstance(handler, logging.StreamHandler)
           and getattr(handler, "stream", None) is sys.stdout
           for handler in _root_logger.handlers):
    _root_logger.addHandler(logging.StreamHandler(stream = sys.stdout))
    
def invoke(openai_api_key,
           prompt,
           accomodates,
           bedrooms,
           rag_option):
    """Handle one UI request and return the text to display.

    Depending on ``rag_option`` the prompt is answered directly
    (``RAG_OFF``) or after a retrieval step (``RAG_NAIVE`` /
    ``RAG_ADVANCED``). When the module-level ``RAG_INGESTION`` flag is
    truthy, the request runs ``rag_ingestion`` instead and its result is
    returned as-is, regardless of ``rag_option``.

    Args:
        openai_api_key: Key passed through to the inference/retrieval helpers.
        prompt: The user prompt.
        accomodates: Forwarded to ``rag_retrieval_advanced`` only.
        bedrooms: Forwarded to ``rag_retrieval_advanced`` only.
        rag_option: One of ``RAG_OFF``, ``RAG_NAIVE``, ``RAG_ADVANCED``.

    Returns:
        The result of ``inference`` / ``rag_inference``, or the result of
        ``rag_ingestion`` when ``RAG_INGESTION`` is set.

    Raises:
        gr.Error: If the API key, prompt, or RAG option is missing, or if
            ``rag_option`` is not one of the supported values.
    """
    if not openai_api_key:
        raise gr.Error("OpenAI API Key is required.")
    if not prompt:
        raise gr.Error("Prompt is required.")
    if not rag_option:
        raise gr.Error("Retrieval-Augmented Generation is required.")

    # The lock is held for the whole request, so requests run one at a time.
    with lock:
        db, collection = connect_to_database()

        if RAG_INGESTION:
            return rag_ingestion(collection)

        if rag_option == RAG_OFF:
            inference_result = inference(
                openai_api_key,
                prompt)
        else:
            if rag_option == RAG_NAIVE:
                retrieval_result = rag_retrieval_naive(
                    openai_api_key,
                    prompt,
                    db,
                    collection)
            elif rag_option == RAG_ADVANCED:
                retrieval_result = rag_retrieval_advanced(
                    openai_api_key,
                    prompt,
                    accomodates,
                    bedrooms,
                    db,
                    collection)
            else:
                # Previously an unknown option fell through every branch and
                # an empty completion was returned without any explanation.
                raise gr.Error(
                    f"Unsupported Retrieval-Augmented Generation option: {rag_option}")

            # Both retrieval modes feed the same augmented inference step.
            inference_result = rag_inference(
                openai_api_key,
                prompt,
                retrieval_result)

        print("###")
        print(inference_result)
        print("###")

        return inference_result

# Close any Gradio apps this process may still have running before building
# a new one (presumably to free the port on restart; confirm if relied upon).
gr.close_all()

# Input components, in the positional order invoke() expects its arguments.
# The default prompt is read from the PROMPT environment variable.
_api_key_input = gr.Textbox(label = "OpenAI API Key", type = "password", lines = 1)
_prompt_input = gr.Textbox(label = "Prompt", value = os.environ["PROMPT"], lines = 1)
_accomodates_input = gr.Number(label = "Accomodates", value = 2)
_bedrooms_input = gr.Number(label = "Bedrooms", value = 1)
_rag_option_input = gr.Radio(
    [RAG_OFF, RAG_NAIVE, RAG_ADVANCED],
    label = "Retrieval-Augmented Generation",
    value = RAG_ADVANCED)

# Output component; its initial content comes from the COMPLETION environment
# variable.
# NOTE(review): sanitize_html = False means the displayed text is not
# HTML-sanitized; confirm this is intended for model-generated output.
_completion_output = gr.Markdown(
    label = "Completion",
    value = os.environ["COMPLETION"],
    line_breaks = True,
    sanitize_html = False)

# The page description is read from the DESCRIPTION environment variable.
demo = gr.Interface(
    fn = invoke,
    inputs = [
        _api_key_input,
        _prompt_input,
        _accomodates_input,
        _bedrooms_input,
        _rag_option_input,
    ],
    outputs = [_completion_output],
    title = "Context-Aware Reasoning Application",
    description = os.environ["DESCRIPTION"]
)

demo.launch()