Spaces:
Sleeping
Sleeping
from src.information_retrieval import info_retrieval as ir | |
import logging | |
logger = logging.getLogger(__name__) | |
logging.basicConfig(encoding='utf-8', level=logging.DEBUG) | |
from src.augmentation.prompts import SYSTEM_PROMPT, SUSTAINABILITY_PROMPT, USER_PROMPT | |
def generate_prompt(query, context, template=None):
    """
    Build the chat messages for the LLM from the user query and the retrieved context.

    The system message is `template` when a truthy template is provided, otherwise the default
    SYSTEM_PROMPT. The user message is USER_PROMPT filled positionally with `query` and `context`.

    Args:
        - query: str; the (possibly augmented) user query
        - context: str; retrieved context, already formatted as text (e.g. by format_context)
        - template: str, optional; system prompt that overrides SYSTEM_PROMPT

    Returns:
        list[dict]; a system message followed by a user message, each {"role": ..., "content": ...}
    """
    system_prompt = template if template else SYSTEM_PROMPT
    # USER_PROMPT is filled positionally, so it presumably has two `{}` placeholders
    # (query first, then context) — TODO confirm against src.augmentation.prompts.
    user_prompt = USER_PROMPT.format(query, context)
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
def format_context(context):
    """
    Format the retrieved context as plain text that is easy for the LLM to understand.

    Args:
        - context: dict; retrieved context keyed by city name. Each value must contain 'country' and
          'text'; it may contain 'listings' (list of dicts with 'type', 'name' and 'description') and
          'sustainability' (dict with 'month' and, when data exists, 's-fairness').

    Returns:
        str; one paragraph per city ("Option 1: ...", "Option 2: ..."), each terminated by a blank
        line. An empty context yields ''.
    """
    attraction_types = ('see', 'do', 'go', 'view')
    food_types = ('eat', 'drink')
    sections = []
    for i, (city, info) in enumerate(context.items(), start=1):
        text = f"Option {i}: {city} is a city in {info['country']}. "

        # If we add sustainability in the end then it could get truncated because of the context
        # window, so it goes right after the header sentence.
        if "sustainability" in info:
            sustainability = info['sustainability']
            if sustainability['month'] == 'No data available':
                text += "This city has no sustainability (or s-fairness) score available. "
            else:
                text += (f"The sustainability (or s-fairness) score for {city} in "
                         f"{sustainability['month']} is {sustainability['s-fairness']}. \n ")

        text += f"Here is some information about the city. {info['text']}"

        attractions, restaurants, hotels = [], [], []
        for listing in info.get('listings') or []:
            entry = f"{listing['name']} ({listing['description']})"
            if listing['type'] in attraction_types:
                attractions.append(entry)
            elif listing['type'] in food_types:
                restaurants.append(entry)
            else:
                # NOTE(review): every other listing type is treated as a hotel — confirm intended.
                hotels.append(entry)

        # Only mention a category when it has at least one listing.
        for heading, entries in (
            ("Here are some attractions: ", attractions),
            ("Here are some places to eat/drink: ", restaurants),
            ("Here are some hotels: ", hotels),
        ):
            if entries:
                text += f"\n{heading}{', '.join(entries)}"

        sections.append(text + "\n\n")
    return "".join(sections)
def augment_prompt(query: str, starting_point: str, context: dict, **params):
    """
    Augment the user query with the retrieved context and build the chat messages for the LLM.

    Args:
        - query: str; the user query
        - starting_point: str; travel starting point, prepended to the query as
          "With <starting_point> as the starting point, <query>". If empty/None the query is used as-is.
        - context: dict; retrieved context. It is formatted here with format_context, because the LLM
          cannot understand the raw nested dictionaries.
        - params: keyword arguments. `params=<dict>` is the dict of parameters passed to get_context
          (e.g. limit, reranking, sustainability); a direct `sustainability=<bool>` keyword is also
          accepted. When either sustainability flag is truthy, SUSTAINABILITY_PROMPT is used as the
          system prompt so the LLM takes s-fairness scores into account. Missing flags mean False.

    Returns:
        list[dict]; chat messages as produced by generate_prompt.
    """
    # TODO: cities without seasonality data have no s-fairness score; format_context only mentions
    # sustainability when a 'sustainability' entry is present for the city.
    if starting_point:
        updated_query = f"With {starting_point} as the starting point, {query}"
    else:
        updated_query = query

    formatted_context = format_context(context)

    # Tolerate a missing `params=` kwarg instead of raising KeyError.
    context_params = params.get("params") or {}
    use_sustainability = context_params.get("sustainability") or params.get("sustainability")
    if use_sustainability:
        return generate_prompt(updated_query, formatted_context, SUSTAINABILITY_PROMPT)
    return generate_prompt(updated_query, formatted_context)
def test(starting_point="Munich"):
    """
    Manual smoke test: retrieve context for a sample query and build prompts with and without
    sustainability. Needs whatever backend ir.get_context relies on.

    Args:
        - starting_point: str; starting point forwarded to augment_prompt (augment_prompt requires
          one). The default is only an example value — adjust as needed.

    Returns:
        list[dict]; the sustainability-aware prompt messages.
    """
    context_params = {
        'limit': 3,
        'reranking': 0,
        'sustainability': 0
    }
    query = ("Suggest some places to visit during winter. I like hiking, nature and the mountains "
             "and I enjoy skiing in winter. ")

    # without sustainability
    context = ir.get_context(query, **context_params)
    without_sfairness = augment_prompt(
        query=query,
        starting_point=starting_point,
        context=context,
        params=context_params
    )
    logger.debug("Prompt without sustainability: %s", without_sfairness)

    # with sustainability; build a new dict so the baseline params are not mutated
    s_context_params = {**context_params, 'sustainability': 1}
    s_context = ir.get_context(query, **s_context_params)
    with_sfairness = augment_prompt(
        query=query,
        starting_point=starting_point,
        context=s_context,
        params=s_context_params
    )
    return with_sfairness
if __name__ == "__main__":
    # Script entry point: run the smoke test and print the resulting prompt messages.
    print(test())