File size: 5,159 Bytes
8842640
4b722ec
 
 
 
adbebe0
4b722ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
adbebe0
4b722ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
adbebe0
4b722ec
 
 
 
 
 
 
 
 
 
 
 
adbebe0
 
4b722ec
 
 
 
 
adbebe0
4b722ec
adbebe0
4b722ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
from src.information_retrieval import info_retrieval as ir
import logging

# Module-level logger; basicConfig turns on DEBUG output process-wide
# (note: basicConfig here is a module-import side effect).
logger = logging.getLogger(__name__)
logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
# Prompt templates used to build the system/user chat messages below.
from src.augmentation.prompts import SYSTEM_PROMPT, SUSTAINABILITY_PROMPT, USER_PROMPT

def generate_prompt(query, context, template=None):
    """
    Build the chat-message list for the LLM from the user query and the
    retrieved context.

    Args:
        query: str; the (possibly augmented) user query.
        context: str; retrieved context already flattened to plain text
            (callers pass the output of format_context).
        template: str | None; optional system-prompt override. Falls back
            to the default SYSTEM_PROMPT when not provided.

    Returns:
        list[dict]: messages in chat format (one system, one user entry).
    """
    # Caller-supplied template wins; otherwise use the default system prompt.
    system_prompt = template if template else SYSTEM_PROMPT

    # USER_PROMPT carries two positional placeholders: query, then context.
    # (The original wrapped this in a redundant f-string.)
    formatted_prompt = USER_PROMPT.format(query, context)

    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": formatted_prompt},
    ]


def format_context(context):
    """
    Flatten the retrieved context into plain text the LLM can read.

    Args:
        context: dict; maps city name -> info dict with keys 'country',
            'text', 'listings' (list of dicts with 'type', 'name',
            'description'), and optionally 'sustainability' (dict with
            'month' and 's-fairness').
            (Note: the previous docstring said list[dict]; the code
            iterates ``context.items()``, so it is a dict.)

    Returns:
        str: one "Option N: ..." paragraph per city, separated by blank
        lines.
    """
    formatted_context = ''

    for i, (city, info) in enumerate(context.items()):
        text = f"Option {i + 1}: {city} is a city in {info['country']}."
        info_text = f"Here is some information about the city. {info['text']}"

        attractions_text = "Here are some attractions: "
        restaurants_text = "Here are some places to eat/drink: "
        hotels_text = "Here are some hotels: "
        has_attractions = has_restaurants = has_hotels = False

        # Bucket each listing by type; an empty listings list just skips this.
        for listing in info['listings']:
            entry = f"{listing['name']} ({listing['description']}), "
            if listing['type'] in ('see', 'do', 'go', 'view'):
                has_attractions = True
                attractions_text += entry
            elif listing['type'] in ('eat', 'drink'):
                has_restaurants = True
                restaurants_text += entry
            else:
                # Any type that is neither attraction nor food/drink is
                # treated as accommodation.
                has_hotels = True
                hotels_text += entry

        # Sustainability goes right after the intro line so it cannot be
        # truncated away if the prompt exceeds the model's context window.
        if "sustainability" in info:
            if info['sustainability']['month'] == 'No data available':
                text += "This city has no sustainability (or s-fairness) score available."
            else:
                text += (
                    f"The sustainability (or s-fairness) score for {city} "
                    f"in {info['sustainability']['month']} is "
                    f"{info['sustainability']['s-fairness']}. \n "
                )

        text += info_text

        if has_attractions:
            text += f"\n{attractions_text}"

        if has_restaurants:
            text += f"\n{restaurants_text}"

        if has_hotels:
            text += f"\n{hotels_text}"

        formatted_context += text + "\n\n "

    return formatted_context


def augment_prompt(query: str, starting_point: str = None, context: dict = None, **params: dict):
    """
    Augment the user query with retrieved context and build the LLM prompt.

    Args:
        query: str; the raw user query.
        starting_point: str | None; the trip's starting city. When given,
            the query is prefixed with it; when None the query is used
            as-is. (Defaulted to None because existing callers omit it —
            the original signature made it required and crashed.)
        context: dict; retrieved context (city -> info dict), formatted
            here via format_context before being embedded in the prompt.
        params: expects a 'params' key holding the retrieval options; when
            params['params']['sustainability'] is truthy, the
            sustainability system prompt is used so the LLM weighs
            s-fairness scores.

    Returns:
        list[dict]: chat messages ready to send to the LLM.
    """
    # NOTE(review): cities without seasonality data have no s-fairness
    # score — the prompt must cope with missing scores in the context.
    if starting_point:
        query = f"With {starting_point} as the starting point, {query}"

    # Flatten the nested context dict; the LLM cannot read raw dicts.
    formatted_context = format_context(context)

    # Swap in the sustainability-aware system prompt when requested.
    # .get() avoids a KeyError when the 'params' kwarg was not supplied.
    opts = params.get("params", {})
    if opts.get("sustainability"):
        return generate_prompt(query, formatted_context, SUSTAINABILITY_PROMPT)
    return generate_prompt(query, formatted_context)


def test():
    """
    Smoke test: build prompts with and without sustainability scoring.

    Returns:
        list[dict]: the sustainability-aware prompt messages.
    """
    context_params = {
        'limit': 3,
        'reranking': 0,
        'sustainability': 0
    }

    query = "Suggest some places to visit during winter. I like hiking, nature and the mountains and I enjoy skiing " \
            "in winter. "

    # without sustainability
    context = ir.get_context(query, **context_params)

    # fix: augment_prompt requires a starting point — the original calls
    # omitted it and raised TypeError.
    without_sfairness = augment_prompt(
        query=query,
        starting_point="Munich",
        context=context,
        params=context_params
    )
    logger.debug(without_sfairness)

    # with sustainability
    context_params.update({'sustainability': 1})
    s_context = ir.get_context(query, **context_params)

    with_sfairness = augment_prompt(
        query=query,
        starting_point="Munich",
        context=s_context,
        params=context_params
    )

    return with_sfairness


if __name__ == "__main__":
    # Run the smoke test and show the generated prompt.
    print(test())