grouped-sampling-demo / prompt_engeneering.py
yonikremer's picture
improved error-handling
d029425
raw
history blame
2.82 kB
import os
from dataclasses import dataclass
from datetime import datetime
from typing import Generator, Dict, List
from googleapiclient.discovery import build
from streamlit import secrets
# Fixed instruction text placed between the search results and the query in
# every rewritten prompt: answer from the results, cite them inline, and split
# the answer when one name refers to several subjects.
INSTRUCTIONS = (
    "Instructions: "
    "Using the provided web search results, "
    "write a comprehensive reply to the given query. "
    "Make sure to cite results using [[number](URL)] notation after the reference. "
    "If the provided search results refer to multiple subjects with the same name, "
    "write separate answers for each subject."
)
def get_google_api_key():
    """Return the Google Custom Search API key.

    The key is read from Streamlit's secrets under ``google_search_api_key``.
    If there is no secrets file, or the file does not define that key, the
    environment variable of the same name is used instead.

    Raises:
        KeyError: if the key is in neither Streamlit's secrets nor the
            environment.
    """
    try:
        return secrets["google_search_api_key"]
    except (FileNotFoundError, KeyError):
        # FileNotFoundError: no secrets file at all. KeyError: a secrets file
        # exists but lacks this entry. Both should fall back to the environment
        # instead of failing outright.
        return os.environ["google_search_api_key"]
def get_google_cse_id():
    """Return the Google Custom Search Engine (CSE) ID.

    The ID is read from Streamlit's secrets under ``google_cse_id``. If there
    is no secrets file, or the file does not define that key, the environment
    variable of the same name is used instead.

    Raises:
        KeyError: if the ID is in neither Streamlit's secrets nor the
            environment.
    """
    try:
        return secrets["google_cse_id"]
    except (FileNotFoundError, KeyError):
        # FileNotFoundError: no secrets file at all. KeyError: a secrets file
        # exists but lacks this entry. Both should fall back to the environment
        # instead of failing outright.
        return os.environ["google_cse_id"]
def google_search(search_term, **kwargs) -> list:
    """Run a Google Custom Search query and return its result items.

    Args:
        search_term: the query string, sent to the API as ``q``.
        **kwargs: extra request parameters forwarded to ``cse().list``
            (for example ``num``).

    Returns:
        The ``items`` list from the API response, or an empty list when the
        response contains no ``items`` key.
    """
    service = build("customsearch", "v1", developerKey=get_google_api_key())
    response = service.cse().list(
        q=search_term,
        cx=get_google_cse_id(),
        **kwargs,
    ).execute()
    # The Custom Search JSON API omits "items" entirely when nothing matched,
    # so indexing it directly would raise KeyError on an empty result set.
    return response.get("items", [])
@dataclass
class SearchResult:
    """One web search hit, reduced to the parts used when building a prompt."""

    __slots__ = ["title", "body", "url"]

    # Title of the result page.
    title: str
    # Text shown for the result (filled from the API's "snippet" field).
    body: str
    # Address of the result page (filled from the API's "link" field).
    url: str
def get_web_search_results(
    query: str,
    num_results: int,
) -> Generator[SearchResult, None, None]:
    """Yield up to ``num_results`` web search results for ``query``.

    Results are fetched with ``google_search``. A trailing non-breaking space
    followed by ``...`` is stripped from each snippet.

    Args:
        query: the search query.
        num_results: the maximum number of results to yield; also sent to the
            API as ``num``.

    Yields:
        A ``SearchResult`` for each item, built from its title, snippet and link.
    """
    raw_results: List[Dict[str, str]] = google_search(
        search_term=query,
        num=num_results,
    )[:num_results]
    for item in raw_results:
        # Not every item is guaranteed to carry a snippet; use an empty body
        # rather than aborting the whole generator with a KeyError. Working on
        # a local copy also avoids mutating the API's result dicts.
        snippet = item.get("snippet", "")
        # "\xa0..." appears to be the API's truncation marker; drop it (4 chars).
        if snippet.endswith("\xa0..."):
            snippet = snippet[:-4]
        yield SearchResult(
            title=item["title"],
            body=snippet,
            url=item["link"],
        )
def format_search_result(search_result: Generator[SearchResult, None, None]) -> str:
    """Render search results as a numbered text block for the prompt.

    Each result becomes ``[index] body``, then a ``URL: url`` line, then a
    blank line. Numbering starts at 0. No results yield an empty string.
    """
    entries = (
        f"[{index}] {entry.body}\nURL: {entry.url}\n\n"
        for index, entry in enumerate(search_result)
    )
    return "".join(entries)
def rewrite_prompt(
    prompt: str,
    num_results: int = 5,
) -> str:
    """Return ``prompt`` augmented with web search results.

    The result is four sections joined by newlines: the formatted web search
    results, the current date, ``INSTRUCTIONS``, and the original query.

    Args:
        prompt: the user's query; it is also used as the web search query.
        num_results: how many search results to fetch and include
            (default 5).

    Returns:
        The rewritten prompt text.
    """
    search_results = get_web_search_results(
        query=prompt,
        num_results=num_results,
    )
    sections = [
        "Web search results:\n" + format_search_result(search_results),
        # Day/month/year from datetime.now(), i.e. the machine's local time.
        "Current date: " + datetime.now().strftime("%d/%m/%Y"),
        INSTRUCTIONS,
        f"Query: {prompt}",
    ]
    return "\n".join(sections)