File size: 2,222 Bytes
5c38fee
5ef932e
d6c20fa
5c38fee
5ef932e
d6c20fa
838b33c
5ef932e
d6c20fa
 
5ef932e
d6c20fa
5c38fee
 
 
d6c20fa
5c38fee
c700119
5c38fee
c26729e
 
9e5685a
5c38fee
 
 
5ef932e
5c38fee
 
 
 
 
 
 
 
 
5ef932e
838b33c
 
dfc9647
5ef932e
aaa9ea2
b520825
5ef932e
 
b520825
5ef932e
 
 
 
b520825
5ef932e
 
b520825
 
5ef932e
aaa9ea2
5ef932e
b520825
aaa9ea2
03159c2
5c38fee
 
 
9a01e1d
838b33c
5c38fee
 
 
838b33c
9e5685a
4af4bf5
5c38fee
a086fab
5c38fee
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import gradio as gr
import pandas as pd
import logging, os, sys, threading

from datasets import load_dataset
#from dotenv import load_dotenv, find_dotenv
from custom_utils import connect_to_database, rag_ingestion, handle_user_prompt

#from pydantic import BaseModel
#from typing import Optional

#from IPython.display import display, HTML

# Process-wide lock; invoke() holds it for the whole request so concurrent
# Gradio requests run their database/ingestion/prompt work one at a time.
lock = threading.Lock()

# When True, invoke() calls rag_ingestion() before handling each prompt.
RAG_INGESTION = True

# Choices offered by the "Retrieval-Augmented Generation" radio button.
RAG_OFF      = "Off"
RAG_NAIVE    = "Naive RAG"
RAG_ADVANCED = "Advanced RAG"

# basicConfig already attaches a StreamHandler(sys.stdout) to the root logger.
# Do not add another stdout handler on top of it, or every record is printed twice.
logging.basicConfig(stream = sys.stdout, level = logging.INFO)
    
def invoke(openai_api_key, prompt, rag_option):
    """Validate the UI inputs and return the completion for *prompt*.

    Gradio calls this with the values of the three input components
    (API key textbox, prompt textbox, RAG-mode radio button).

    Args:
        openai_api_key: OpenAI API key, forwarded to ``handle_user_prompt``.
        prompt: The user's prompt; empty or whitespace-only text is rejected.
        rag_option: The selected RAG mode (one of RAG_OFF, RAG_NAIVE,
            RAG_ADVANCED); it must be non-empty.

    Returns:
        Whatever ``handle_user_prompt`` returns. It is displayed in the
        ``gr.Markdown`` output component.

    Raises:
        gr.Error: If any of the three inputs is missing.
    """
    if not openai_api_key:
        raise gr.Error("OpenAI API Key is required.")
    if not prompt or not prompt.strip():
        raise gr.Error("Prompt is required.")
    if not rag_option:
        raise gr.Error("Retrieval-Augmented Generation is required.")

    # NOTE(review): rag_option is validated above but not passed to
    # handle_user_prompt, so the selected mode has no effect on this call.
    # TODO: thread it through once handle_user_prompt accepts a mode argument.

    # Hold the shared lock so concurrent requests do not interleave database
    # connection, ingestion, and prompt handling.
    with lock:
        db, collection = connect_to_database()

        if RAG_INGESTION:
            # NOTE(review): this runs on every request while RAG_INGESTION is
            # True. Confirm rag_ingestion() is idempotent/cheap enough for that.
            rag_ingestion()

        return handle_user_prompt(openai_api_key, prompt, db, collection)

# Close any Gradio demos previously launched in this process before building a new one.
gr.close_all()

# Default text pre-filled in the "Prompt" textbox.
PROMPT = "I want to stay in a place that's modern and clean, walking distance from restaurants."

demo = gr.Interface(
    fn = invoke, 
    inputs = [gr.Textbox(label = "OpenAI API Key", type = "password", lines = 1), 
              gr.Textbox(label = "Prompt", value = PROMPT, lines = 1),
              gr.Radio([RAG_OFF, RAG_NAIVE, RAG_ADVANCED], label = "Retrieval-Augmented Generation", value = RAG_ADVANCED)],
    outputs = [gr.Markdown(label = "Completion")],
    title = "Context-Aware Reasoning Application",
    # Use an empty description rather than aborting startup with a bare
    # KeyError when the DESCRIPTION environment variable is not set.
    description = os.environ.get("DESCRIPTION", "")
)

demo.launch()