Spaces:
Sleeping
Sleeping
File size: 5,159 Bytes
8842640 4b722ec adbebe0 4b722ec adbebe0 4b722ec adbebe0 4b722ec adbebe0 4b722ec adbebe0 4b722ec adbebe0 4b722ec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
from src.information_retrieval import info_retrieval as ir
import logging
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# NOTE(review): this configures the root logger (DEBUG level, UTF-8 encoding) as a side effect of importing
# this module - confirm that is intended when the module is imported and not only when it is run as a script.
logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
from src.augmentation.prompts import SYSTEM_PROMPT, SUSTAINABILITY_PROMPT, USER_PROMPT
def generate_prompt(query, context, template=None):
    """
    Build the chat messages for the LLM from the user query and the retrieved context.

    Args:
        - query: str; the user query, passed to USER_PROMPT as its first positional argument
        - context: the retrieved context, passed to USER_PROMPT as its second positional argument
        - template: str; optional system prompt; when it is None or empty, SYSTEM_PROMPT is used instead

    Returns:
        - list[dict]; a system message followed by a user message, each with the keys "role" and "content"
    """
    system_content = template or SYSTEM_PROMPT
    user_content = USER_PROMPT.format(query, context)
    return [
        {"role": "system", "content": system_content},
        {"role": "user", "content": user_content},
    ]
def format_context(context):
    """
    Function that formats the retrieved context as plain text that is easy for the LLM to understand.

    Args:
        - context: dict; retrieved context, mapping each city name to a dict with the keys 'country', 'text',
          optionally 'listings' (dicts with the keys 'type', 'name' and 'description') and optionally
          'sustainability' (dict with the key 'month' and, when data is available, 's-fairness')

    Returns:
        - str; one "Option N: ..." paragraph per city in iteration order, each followed by a blank line;
          an empty string when the context is empty
    """
    attraction_types = ('see', 'do', 'go', 'view')
    restaurant_types = ('eat', 'drink')

    options = []
    for i, (city, info) in enumerate(context.items()):
        attractions, restaurants, hotels = [], [], []
        # A missing or None 'listings' entry is treated as "no listings" instead of raising.
        for listing in info.get('listings') or ():
            if listing['type'] in attraction_types:
                bucket = attractions
            elif listing['type'] in restaurant_types:
                bucket = restaurants
            else:
                # NOTE(review): every other listing type is presented as a hotel - confirm that no
                # non-accommodation types can reach this branch.
                bucket = hotels
            bucket.append(f"{listing['name']} ({listing['description']})")

        # The trailing space keeps this sentence separated from whichever sentence follows it.
        text = f"Option {i + 1}: {city} is a city in {info['country']}. "

        # If we add sustainability in the end then it could get truncated because of context window
        if "sustainability" in info:
            sustainability = info['sustainability']
            if sustainability['month'] == 'No data available':
                text += "This city has no sustainability (or s-fairness) score available. "
            else:
                text += (
                    f"The sustainability (or s-fairness) score for {city} in {sustainability['month']} "
                    f"is {sustainability['s-fairness']}. \n "
                )

        text += f"Here is some information about the city. {info['text']}"

        # Joining avoids the dangling ", " that per-item concatenation leaves after the last listing.
        if attractions:
            text += "\nHere are some attractions: " + ", ".join(attractions)
        if restaurants:
            text += "\nHere are some places to eat/drink: " + ", ".join(restaurants)
        if hotels:
            text += "\nHere are some hotels: " + ", ".join(hotels)

        options.append(text + "\n\n ")

    return "".join(options)
def augment_prompt(query: str, starting_point: str, context: dict, **params: dict):
    """
    Function that combines the user query, the starting point and the already retrieved context into the chat
    messages that can be passed to the LLM.

    Args:
        - query: str; the user query
        - starting_point: str; the starting point of the trip, prepended to the query
        - context: dict; retrieved context; it is formatted here with format_context because the LLM cannot
          understand the nested dictionaries
        - params: keyword arguments; only the keyword `params` is read. It is expected to be a dict, and a
          truthy "sustainability" entry in it selects SUSTAINABILITY_PROMPT as the system prompt so that the
          LLM is instructed to use the s-fairness scores while generating the answer

    Returns:
        - the messages produced by generate_prompt
    """
    # what about the cities without s-fairness scores? i.e. they don't have seasonality data
    updated_query = f"With {starting_point} as the starting point, {query}"
    formatted_context = format_context(context)

    # Callers pass the settings as `params={...}`; a missing or None value falls back to the default
    # system prompt instead of raising a KeyError.
    settings = params.get("params") or {}
    if settings.get("sustainability"):
        return generate_prompt(updated_query, formatted_context, SUSTAINABILITY_PROMPT)
    return generate_prompt(updated_query, formatted_context)
def test(starting_point="Munich"):
    """
    Manual smoke test that retrieves context for a sample query and builds the prompt twice, first without
    and then with the sustainability setting.

    Args:
        - starting_point: str; starting point handed to augment_prompt, which requires it

    Returns:
        - the prompt built with the sustainability setting enabled
    """
    context_params = {
        'limit': 3,
        'reranking': 0,
        'sustainability': 0
    }
    query = "Suggest some places to visit during winter. I like hiking, nature and the mountains and I enjoy skiing " \
            "in winter. "

    # without sustainability; the result is discarded, the call only checks that this path runs
    context = ir.get_context(query, **context_params)
    augment_prompt(
        query=query,
        starting_point=starting_point,
        context=context,
        params=context_params
    )

    # with sustainability
    context_params.update({'sustainability': 1})
    s_context = ir.get_context(query, **context_params)
    with_sfairness = augment_prompt(
        query=query,
        starting_point=starting_point,
        context=s_context,
        params=context_params
    )
    return with_sfairness
# Script entry point: build the sample prompt and print it.
if __name__ == "__main__":
    print(test())
|