File size: 4,879 Bytes
5c38fee
daa80cf
5c38fee
4b2f569
5ef932e
5c38fee
 
82fd045
5c38fee
c26729e
 
9e5685a
5c38fee
 
 
5ef932e
5c38fee
 
 
 
 
 
 
 
 
5ef932e
838b33c
 
82fd045
0835210
f8ac3f0
 
3cc300e
 
 
 
 
 
0835210
3cc300e
6eee7c9
ddeba7a
6eee7c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ddeba7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6eee7c9
 
120d45f
 
 
82fd045
5c38fee
 
 
aabb4c2
838b33c
5c38fee
 
 
838b33c
9e5685a
4af4bf5
5c38fee
a086fab
5c38fee
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import gradio as gr
import logging, os, sys, threading

from custom_utils import connect_to_database, rag_ingestion, rag_retrieval, rag_inference

# Serializes request handling: only one invoke() call at a time runs its
# database / retrieval / inference work.
lock = threading.Lock()

# When True, invoke() runs rag_ingestion(collection) instead of answering the prompt.
RAG_INGESTION = False

# Choices offered by the "Retrieval-Augmented Generation" radio button.
RAG_OFF      = "Off"
RAG_NAIVE    = "Naive RAG"
RAG_ADVANCED = "Advanced RAG"

# basicConfig attaches a single stdout StreamHandler to the root logger (it is a
# no-op if the root logger already has handlers). Adding a second stdout
# StreamHandler on top of it made every log record print twice.
logging.basicConfig(stream = sys.stdout, level = logging.INFO)
    
def invoke(openai_api_key, prompt, rag_option):
    if not openai_api_key:
        raise gr.Error("OpenAI API Key is required.")
    if not prompt:
        raise gr.Error("Prompt is required.")
    if not rag_option:
        raise gr.Error("Retrieval-Augmented Generation is required.")

    with lock:
        db, collection = connect_to_database()

        if (RAG_INGESTION):
            return rag_ingestion(collection)
        else:
            ### Pre-retrieval processing: index filter
            ### Post-retrieval processing: result filter
            #match_stage = {
            #    "$match": {
            #        "accommodates": { "$eq": 2},
            #        "bedrooms": { "$eq": 1}
            #    }
            #}
    
            #additional_stages = [match_stage]
            ###
            """
            projection_stage = {
                "$project": {
                    "_id": 0,  
                    "name": 1,
                    "accommodates": 1,
                    "address.street": 1,
                    "address.government_area": 1, 
                    "address.market": 1,
                    "address.country": 1, 
                    "address.country_code": 1, 
                    "address.location.type": 1, 
                    "address.location.coordinates": 1,  
                    "address.location.is_location_exact": 1,
                    "summary": 1,
                    "space": 1,  
                    "neighborhood_overview": 1, 
                    "notes": 1, 
                    "score": {"$meta": "vectorSearchScore"} 
                }
            }
            
            additional_stages = [projection_stage]
            """
            ###
            review_average_stage = {
                "$addFields": {
                    "averageReviewScore": {
                        "$divide": [
                            {
                                "$add": [
                                    "$review_scores.review_scores_accuracy",
                                    "$review_scores.review_scores_cleanliness",
                                    "$review_scores.review_scores_checkin",
                                    "$review_scores.review_scores_communication",
                                    "$review_scores.review_scores_location",
                                    "$review_scores.review_scores_value",
                                ]
                            },
                            6  # Divide by the number of review score types to get the average
                        ]
                    },
                    # Calculate a score boost factor based on the number of reviews
                    "reviewCountBoost": "$number_of_reviews"
                }
            }

            weighting_stage = {
                "$addFields": {
                    "combinedScore": {
                        # Example formula that combines average review score and review count boost
                        "$add": [
                            {"$multiply": ["$averageReviewScore", 0.9]},  # Weighted average review score
                            {"$multiply": ["$reviewCountBoost", 0.1]}   # Weighted review count boost
                        ]
                    }
                }
            }

            # Apply the combinedScore for sorting
            sorting_stage_sort = {
                "$sort": {"combinedScore": -1}  # Descending order to boost higher combined scores
            }

            additional_stages = [review_average_stage, weighting_stage, sorting_stage_sort]
            ###
            #additional_stages = []
            ###
            
            search_results = rag_retrieval(openai_api_key, prompt, db, collection, additional_stages)
            return rag_inference(openai_api_key, prompt, search_results)

# Close any Gradio apps already running in this process before building a new one.
gr.close_all()

# Default question pre-filled in the prompt box.
PROMPT = "Recommend a place that's modern, spacious, and within walking distance from restaurants."

demo = gr.Interface(
    fn = invoke, 
    inputs = [gr.Textbox(label = "OpenAI API Key", type = "password", lines = 1), 
              gr.Textbox(label = "Prompt", value = PROMPT, lines = 1),
              gr.Radio([RAG_OFF, RAG_NAIVE, RAG_ADVANCED], label = "Retrieval-Augmented Generation", value = RAG_ADVANCED)],
    outputs = [gr.Markdown(label = "Completion")],
    title = "Context-Aware Reasoning Application",
    # DESCRIPTION is optional: without the fallback, a missing variable raised
    # KeyError and the app failed to start.
    description = os.environ.get("DESCRIPTION", "")
)

demo.launch()