|
import os |
|
from dataclasses import dataclass |
|
from datetime import datetime |
|
from typing import Generator, Dict, List |
|
|
|
from googleapiclient.discovery import build |
|
from streamlit import secrets |
|
|
|
# Fixed instruction text that rewrite_prompt places between the current date
# and the user's query in the final prompt.
INSTRUCTIONS = (
    "Instructions: "
    "Using the provided web search results, "
    "write a comprehensive reply to the given query. "
    "Make sure to cite results using [[number](URL)] notation after the reference. "
    "If the provided search results refer to multiple subjects with the same name, "
    "write separate answers for each subject."
)
|
|
|
|
|
def get_google_api_key():
    """Return the Google search API key.

    The key is looked up in Streamlit's secrets first. If that lookup fails
    because the secrets file is missing, or because the file has no
    ``google_search_api_key`` entry, the environment variable of the same
    name is used instead.

    Raises:
        KeyError: If the key is found in neither Streamlit's secrets nor the
            environment.
    """
    try:
        return secrets["google_search_api_key"]
    except (FileNotFoundError, KeyError):
        # FileNotFoundError: no secrets file at all.
        # KeyError: secrets are present but this entry is not.
        # Either way, fall back to the process environment.
        return os.environ["google_search_api_key"]
|
|
|
|
|
def get_google_cse_id():
    """Return the Google Custom Search Engine (CSE) ID.

    The ID is looked up in Streamlit's secrets first. If that lookup fails
    because the secrets file is missing, or because the file has no
    ``google_cse_id`` entry, the environment variable of the same name is
    used instead.

    Raises:
        KeyError: If the ID is found in neither Streamlit's secrets nor the
            environment.
    """
    try:
        return secrets["google_cse_id"]
    except (FileNotFoundError, KeyError):
        # FileNotFoundError: no secrets file at all.
        # KeyError: secrets are present but this entry is not.
        # Either way, fall back to the process environment.
        return os.environ["google_cse_id"]
|
|
|
|
|
def google_search(search_term, **kwargs) -> list:
    """Run a query against the Google Custom Search API.

    Args:
        search_term: Text to search for (sent as ``q``).
        **kwargs: Extra parameters forwarded unchanged to the ``cse().list``
            request.

    Returns:
        The ``items`` list of the response, or an empty list when the
        response has no ``items`` entry.
    """
    service = build("customsearch", "v1", developerKey=get_google_api_key())
    request = service.cse().list(q=search_term, cx=get_google_cse_id(), **kwargs)
    response = request.execute()
    # A response without an "items" key means there is nothing to return;
    # report that as an empty list instead of letting a KeyError escape.
    return response.get("items", [])
|
|
|
|
|
@dataclass
class SearchResult:
    """A single web search hit: its title, body text and URL."""

    # Fixed attribute set: instances carry no per-instance __dict__.
    __slots__ = ["title", "body", "url"]

    title: str
    body: str
    url: str
|
|
|
|
|
def get_web_search_results(
    query: str,
    num_results: int,
) -> Generator[SearchResult, None, None]:
    """Yield web search results for ``query`` using the Google search API.

    Args:
        query: Search text passed to ``google_search``.
        num_results: Number of results requested (sent as ``num``) and the
            upper bound on how many are yielded.

    Yields:
        One ``SearchResult`` per returned item, built from the item's
        ``title``, ``snippet`` and ``link`` entries. A trailing truncation
        marker (a non-breaking space followed by three dots) is removed from
        the snippet. An item without a ``snippet`` entry yields an empty body.
    """
    truncation_marker = "\xa0..."
    raw_results: List[Dict[str, str]] = google_search(
        search_term=query,
        num=num_results,
    )[:num_results]
    for item in raw_results:
        # Work on a local copy of the text so the API's result dicts are
        # left untouched, and tolerate items that have no snippet at all.
        snippet = item.get("snippet", "")
        if snippet.endswith(truncation_marker):
            snippet = snippet[: -len(truncation_marker)]
        yield SearchResult(
            title=item["title"],
            body=snippet,
            url=item["link"],
        )
|
|
|
|
|
def format_search_result(search_result: Generator[SearchResult, None, None]) -> str:
    """Render search results as numbered entries for the prompt.

    Each result becomes a ``[index] body`` line, a ``URL: ...`` line and a
    blank line. Indices start at 0 and entries are concatenated in order;
    no results gives an empty string.
    """
    entries = (
        f"[{index}] {hit.body}\nURL: {hit.url}\n\n"
        for index, hit in enumerate(search_result)
    )
    return "".join(entries)
|
|
|
|
|
def rewrite_prompt(
    prompt: str,
    num_results: int = 5,
) -> str:
    """Rewrite the prompt by adding web search results to it.

    Args:
        prompt: The user's query; used both as the search text and in the
            final ``Query:`` line.
        num_results: How many search results to request. Defaults to 5.

    Returns:
        Four newline-joined sections, in order: the formatted web search
        results, the current local date as ``DD/MM/YYYY``, ``INSTRUCTIONS``,
        and the query itself.
    """
    search_hits = get_web_search_results(
        query=prompt,
        num_results=num_results,
    )
    sections = [
        "Web search results:\n" + format_search_result(search_hits),
        "Current date: " + datetime.now().strftime("%d/%m/%Y"),
        INSTRUCTIONS,
        f"Query: {prompt}",
    ]
    return "\n".join(sections)
|
|