Spaces:
Runtime error
Runtime error
# https://github.com/anthropics/anthropic-cookbook/blob/main/third_party/Brave/web_search_using_brave.ipynb | |
import asyncio | |
import html | |
import json | |
import os | |
from typing import List | |
from anthropic import Anthropic | |
import requests | |
import streamlit as st | |
from googleapiclient.discovery import build | |
st.title("Qiitaに聞いた!!")

# Build the Anthropic client only on the first run of a Streamlit session and
# reuse the stored instance on later reruns; `client` is the module-level
# handle that the helper functions call.
if "client" not in st.session_state:
    st.session_state["client"] = Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))
client = st.session_state["client"]
# Generate search queries for the user's question
def generate_search_queries(question: str) -> dict:
    """Ask Claude for Google search queries relevant to ``question``.

    The queries are requested in Japanese and returned as the parsed JSON
    object from the model.

    Args:
        question: The user's question.

    Returns:
        The parsed JSON object, expected to look like
        ``{"queries": ["...", ...]}`` (callers index ``["queries"]``).

    Raises:
        json.JSONDecodeError: if the reply contains no valid JSON object.
    """
    # NOTE(review): the system prompt asks for two queries while the format
    # example lists three placeholders; confirm which count is intended.
    GENERATE_QUERIES = (
        "\n"
        "User question: {{question}}\n"
        'Format: {"queries": ["query_1", "query_2", "query_3"]}\n'
    )
    response = client.messages.create(
        max_tokens=1024,
        system="You are an expert at generating search queries for the Google search engine. Generate two search queries that are relevant to this question in Japanese. Output only valid JSON.",
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": GENERATE_QUERIES.replace("{{question}}", question),
                    }
                ],
            }
        ],
        temperature=0,
        model="claude-3-haiku-20240307",
    )
    text = response.content[0].text
    # The model may still wrap the object in prose or a ```json fence despite
    # the "Output only valid JSON" instruction; parse just the outermost {...}.
    start, end = text.find("{"), text.rfind("}")
    if start != -1 and end > start:
        text = text[start : end + 1]
    return json.loads(text)
# Search Qiita via Google Custom Search
def search_qiita(search_query: str) -> list:
    """Search Qiita articles through the Google Custom Search API.

    The query is restricted to qiita.com with a ``site:`` operator, and at
    most three hits are requested.

    Args:
        search_query: Free-text search query.

    Returns:
        A list of dicts with ``title``, ``link`` and ``snippet`` keys, in the
        order returned by the API; empty when the search has no hits.
    """
    service = build("customsearch", "v1", developerKey=os.environ.get("GOOGLE_API_KEY"))
    res = (
        service.cse()
        .list(
            q=f"{search_query} site:qiita.com",
            cx=os.environ.get("GOOGLE_CSE_ID"),
            num=3,
        )
        .execute()
    )
    # The response may omit "items" when nothing matches, and an item may
    # lack "snippet"; default both instead of raising KeyError.
    return [
        {
            "title": item["title"],
            "link": item["link"],
            "snippet": item.get("snippet", ""),
        }
        for item in res.get("items", [])
    ]
# Async helper that attaches each article's Markdown source to a search result
async def add_markdown(search_result: dict, timeout: float = 10.0) -> dict:
    """Fetch ``<link>.md`` and store its HTML-escaped body on the result.

    The blocking ``requests.get`` call runs in the event loop's default
    executor, so several calls awaited together with ``asyncio.gather``
    overlap instead of running one after another on the loop thread.

    Args:
        search_result: Dict with at least a ``link`` key (and optionally a
            ``snippet``). It is mutated in place.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        The same dict, with a ``markdown`` key holding HTML-escaped text.

    Raises:
        requests.RequestException: on connection errors or timeout.
    """
    url = search_result["link"]
    loop = asyncio.get_running_loop()
    response = await loop.run_in_executor(
        None, lambda: requests.get(f"{url}.md", timeout=timeout)
    )
    if response.ok:
        markdown = response.text
    else:
        # A non-OK status (e.g. a URL with no ".md" variant) would otherwise
        # put an error page into the prompt; use the search snippet instead.
        markdown = search_result.get("snippet", "")
    # Escaped so the text can be embedded inside the XML prompt.
    search_result["markdown"] = html.escape(markdown)
    return search_result
# Convert search results into an XML document string
def create_xml_documents(documents: list) -> str:
    """Serialise search results into XML for the answer prompt.

    Each document becomes
    ``<doc title="..."><link>...</link><markdown>...</markdown></doc>``, and
    all of them are wrapped in one ``<documents>`` element.

    Args:
        documents: Dicts with ``title``, ``link`` and ``markdown`` keys. The
            ``markdown`` value is inserted verbatim, so it must already be
            escaped (``add_markdown`` escapes it with ``html.escape``).

    Returns:
        The XML string; ``"<documents></documents>"`` for an empty list.
    """
    # Title and link come straight from search results and may contain
    # quotes, "&" or "<", which would break the attribute or element.
    docs = "".join(
        f'<doc title="{html.escape(doc["title"])}">'
        f'<link>{html.escape(doc["link"])}</link>'
        f'<markdown>{doc["markdown"]}</markdown>'
        "</doc>"
        for doc in documents
    )
    return f"<documents>{docs}</documents>"
# Generate an answer to the question
def generate_answer(question: str, documents: list) -> str:
    """Answer ``question`` with Claude using only the given search results.

    Args:
        question: The user's question.
        documents: Search results with ``title``, ``link`` and ``markdown``
            keys (the list built in ``main``), serialised via
            ``create_xml_documents``.

    Returns:
        The text of the first content block of the model's reply.
    """
    xml_docs = create_xml_documents(documents=documents)
    # The search results are placed before the question in the prompt.
    ANSWER_QUESTION = (
        "I have provided you with the following search results:\n"
        f"{xml_docs}\n"
        "Please answer the user's question using only information from the search results.\n"
        "Keep your answer concise.\n"
        "Answer is always in Japanese!\n"
        f"User's question: {question}\n"
    )
    response = client.messages.create(
        max_tokens=1024,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": ANSWER_QUESTION,
                    }
                ],
            }
        ],
        temperature=0.1,
        model="claude-3-haiku-20240307",
    )
    return response.content[0].text
# Main function (Streamlit UI)
async def main():
    """Render the question form and answer the submitted question from Qiita.

    On submit:
    1. Ask Claude for search queries.
    2. Search Qiita for each query.
    3. Fetch every distinct hit's Markdown.
    4. Ask Claude for an answer grounded in those documents.
    5. Show the answer followed by links to the sources.
    """
    with st.form("Form"):
        question = st.text_input("質問")
        if st.form_submit_button("質問する"):
            # An empty question would still trigger Claude calls and searches;
            # stop early instead.
            if not question.strip():
                st.warning("質問を入力してください")
                return
            with st.status("処理中...", expanded=True) as status:
                search_queries = generate_search_queries(question=question)
                st.write("検索クエリ: " + str(search_queries["queries"]))
                # Different queries can return the same article; keep only the
                # first hit per link so it is fetched, prompted and listed once.
                search_results = []
                seen_links = set()
                for search_query in search_queries["queries"]:
                    for hit in search_qiita(search_query=search_query):
                        if hit["link"] not in seen_links:
                            seen_links.add(hit["link"])
                            search_results.append(hit)
                # Fetch all Markdown sources together rather than per query.
                documents = list(
                    await asyncio.gather(*[add_markdown(hit) for hit in search_results])
                )
                st.write("検索完了")
                st.write("回答生成中...")
                answer = generate_answer(question=question, documents=documents)
                status.update(label="complete!", state="complete", expanded=False)
            st.markdown(answer)
            st.divider()
            st.markdown("参照ドキュメント")
            for document in documents:
                # Assumes URLs of the form https://qiita.com/<user>/..., so path
                # segment 3 is the author; omit the byline if the URL is shorter.
                segments = document["link"].split("/")
                byline = f" by {segments[3]}" if len(segments) > 3 else ""
                st.markdown(f'[{document["title"]}]({document["link"]}){byline}')


if __name__ == "__main__":
    asyncio.run(main())