chatpdf / index.html
sonoisa's picture
Add a flag to enable/disable RAG feature
187910a verified
raw
history blame
47.7 kB
<!DOCTYPE html>
<!--
Copyright (c) 2024 Isao Sonobe
Released under the MIT license
https://opensource.org/license/mit/
-->
<html>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>Chat with your PDF</title>
<meta name="description" content="Chat with your PDF">
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/@gradio/lite@4.29.0/dist/lite.css" />
<style>
html, body {
margin: 0;
padding: 0;
height: 100%;
background: var(--body-background-fill);
}
footer {
display: none !important;
}
#chatbot {
height: auto !important;
min-height: 500px;
}
#chatbot h1 {
font-size: 2em;
margin-block-start: 0.67em;
margin-block-end: 0em;
margin-inline-start: 0px;
margin-inline-end: 0px;
font-weight: bold;
}
#chatbot h2 {
font-size: 1.5em;
margin-block-start: 0.83em;
margin-block-end: 0em;
margin-inline-start: 0px;
margin-inline-end: 0px;
font-weight: bold;
}
#chatbot h3 {
font-size: 1.17em;
margin-block-start: 1em;
margin-block-end: 0em;
margin-inline-start: 0px;
margin-inline-end: 0px;
font-weight: bold;
}
#chatbot h4 {
margin-block-start: 1.33em;
margin-block-end: 0em;
margin-inline-start: 0px;
margin-inline-end: 0px;
font-weight: bold;
}
#chatbot h5 {
margin-block-start: 1.67em;
margin-block-end: 0em;
margin-inline-start: 0px;
margin-inline-end: 0px;
font-weight: bold;
}
#chatbot h6 {
margin-block-start: 1.83em;
margin-block-end: 0em;
margin-inline-start: 0px;
margin-inline-end: 0px;
font-weight: bold;
}
/*
.chatbot {
white-space: pre-wrap;
}
*/
.gallery-item > .gallery {
max-width: 380px;
}
#context > label > textarea {
scrollbar-width: thin !important;
}
#cost_info {
border-style: none !important;
}
#cost_info > label > input {
background: var(--panel-background-fill) !important;
}
</style>
</head>
<body>
<gradio-lite>
<gradio-requirements>
pdfminer.six==20231228
pyodide-http==0.2.1
janome==0.5.0
rank_bm25==0.2.2
</gradio-requirements>
<gradio-file name="chat_history.json">
[[null, "ようこそ! PDFのテキストを参照しながら対話できるチャットボットです。\nPDFファイルをアップロードするとテキストが抽出されます。\nメッセージの中に{context}と書くと、抽出されたテキストがその部分に埋め込まれて対話が行われます。他にもPDFのページを検索して参照したり、ページ番号を指定して参照したりすることができます。一番下のExamplesにこれらの例があります。\nメッセージを書くときにShift+Enterを入力すると改行できます。"]]
</gradio-file>
<gradio-file name="app.py" entrypoint>
import os
# Gradioによるアナリティクスを無効化
os.putenv("GRADIO_ANALYTICS_ENABLED", "False")
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
# openaiライブラリのインストール方法は https://github.com/pyodide/pyodide/issues/4292 を参考にしました。
import micropip
await micropip.install("https://raw.githubusercontent.com/sonoisa/pyodide_wheels/main/multidict/multidict-4.7.6-py3-none-any.whl", keep_going=True)
await micropip.install("https://raw.githubusercontent.com/sonoisa/pyodide_wheels/main/frozenlist/frozenlist-1.4.0-py3-none-any.whl", keep_going=True)
await micropip.install("https://raw.githubusercontent.com/sonoisa/pyodide_wheels/main/aiohttp/aiohttp-4.0.0a2.dev0-py3-none-any.whl", keep_going=True)
await micropip.install("https://raw.githubusercontent.com/sonoisa/pyodide_wheels/main/openai/openai-1.3.7-py3-none-any.whl", keep_going=True)
await micropip.install("https://raw.githubusercontent.com/sonoisa/pyodide_wheels/main/urllib3/urllib3-2.1.0-py3-none-any.whl", keep_going=True)
await micropip.install("ssl")
import ssl
await micropip.install("httpx", keep_going=True)
import httpx
await micropip.install("https://raw.githubusercontent.com/sonoisa/pyodide_wheels/main/urllib3/urllib3-2.1.0-py3-none-any.whl", keep_going=True)
import urllib3
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
await micropip.install("https://raw.githubusercontent.com/sonoisa/pyodide_wheels/main/tiktoken/tiktoken-0.5.2-cp311-cp311-emscripten_3_1_46_wasm32.whl", keep_going=True)
import gradio as gr
import base64
import json
import unicodedata
import re
from pathlib import Path
from dataclasses import dataclass
import asyncio
import pyodide_http
pyodide_http.patch_all()
from pdfminer.pdfinterp import PDFResourceManager
from pdfminer.converter import TextConverter
from pdfminer.pdfinterp import PDFPageInterpreter
from pdfminer.pdfpage import PDFPage
from pdfminer.layout import LAParams
from io import StringIO
from janome.tokenizer import Tokenizer as JanomeTokenizer
from janome.analyzer import Analyzer as JanomeAnalyzer
from janome.tokenfilter import POSStopFilter, LowerCaseFilter
from rank_bm25 import BM25Okapi
from openai import OpenAI, AzureOpenAI
import tiktoken
import requests
class URLLib3Transport(httpx.BaseTransport):
"""
urllib3を使用してhttpxのリクエストを処理するカスタムトランスポートクラス
"""
def __init__(self):
self.pool = urllib3.PoolManager()
def handle_request(self, request: httpx.Request):
payload = json.loads(request.content.decode("utf-8"))
urllib3_response = self.pool.request(request.method, str(request.url), headers=request.headers, json=payload)
stream = httpx.ByteStream(urllib3_response.data)
return httpx.Response(urllib3_response.status, headers=urllib3_response.headers, stream=stream)
http_client = httpx.Client(transport=URLLib3Transport())
@dataclass
class Page:
"""
PDFのページ内容
"""
number: int
content: str
def load_tiktoken_model(model_url):
resp = requests.get(model_url)
resp.raise_for_status()
return resp.content
# OPENAI_TOKENIZER = tiktoken.get_encoding("cl100k_base")
OPENAI_TOKENIZER = tiktoken.Encoding(
name="cl100k_base",
pat_str=r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+""",
mergeable_ranks={
base64.b64decode(token): int(rank)
for token, rank in (line.split() for line in load_tiktoken_model("https://raw.githubusercontent.com/sonoisa/pyodide_wheels/main/tiktoken/cl100k_base.tiktoken").splitlines() if line)
},
special_tokens={
"&lt;|endoftext|&gt;": 100257,
"&lt;|fim_prefix|&gt;": 100258,
"&lt;|fim_middle|&gt;": 100259,
"&lt;|fim_suffix|&gt;": 100260,
"&lt;|endofprompt|&gt;": 100276,
}
)
JANOME_TOKENIZER = JanomeTokenizer()
JANOME_ANALYZER = JanomeAnalyzer(tokenizer=JANOME_TOKENIZER,
token_filters=[POSStopFilter(["記号,空白"]), LowerCaseFilter()])
def extract_pdf_pages(pdf_filename):
"""
PDFファイルからテキストを抽出する。
Args:
pdf_filename (str): 抽出するPDFファイルのパス
Returns:
list[Page]: PDFの各ページ内容のリスト
"""
pages = []
with open(pdf_filename, "rb") as pdf_file:
output = StringIO()
resource_manager = PDFResourceManager()
laparams = LAParams()
text_converter = TextConverter(resource_manager, output, laparams=laparams)
page_interpreter = PDFPageInterpreter(resource_manager, text_converter)
page_number = 0
for i_page in PDFPage.get_pages(pdf_file):
try:
page_number += 1
page_interpreter.process_page(i_page)
page_content = output.getvalue()
page_content = unicodedata.normalize('NFKC', page_content)
pages.append(Page(number=page_number, content=page_content))
output.truncate(0)
output.seek(0)
except Exception as e:
print(e)
pass
output.close()
text_converter.close()
return pages
def merge_pages_with_page_tag(pages):
"""
PDFの各ページ内容を一つの文字列にマージする。
ただし、chatpdf:pageというタグでページを括る。
extract_pages_from_page_tag()の逆変換である。
Args:
pages (list[Page]): PDFの各ページ内容のリスト
Returns:
str: PDFの各ページ内容をマージした文字列
"""
document_with_page_tag = ""
for page in pages:
document_with_page_tag += f'&lt;chatpdf:page number="{page.number}"&gt;\n{page.content}\n&lt;/chatpdf:page&gt;\n'
return document_with_page_tag
def extract_pages_from_page_tag(document_with_page_tag):
"""
chatpdf:pageというタグで括られた領域をPDFのページ内容と解釈して、Pageオブジェクトのリストに変換する。
merge_pages_with_page_tag()の逆変換である。
Args:
document_with_page_tag (str): chatpdf:pageというタグで各ページが括られた文字列
Returns:
list[Page]: Pageオブジェクトのリスト
"""
page_tag_pattern = r'&lt;chatpdf:page number="(\d+)"&gt;\n?(.*?)\n?&lt;\/chatpdf:page&gt;\n?'
matches = re.findall(page_tag_pattern, document_with_page_tag, re.DOTALL)
pages = [Page(number=int(number), content=content) for number, content in matches]
return pages
def escape_latex(unescaped_text):
"""
Chatbotのmarkdownで数式が表示されるように \\(, \\), \\[, \\] をバックスラッシュでエスケープする。
Args:
unescaped_text (str): エスケープ対象文字列
Returns:
str: エスケープされた文字列
"""
return re.sub(r"(\\[\(\)\[\]])", r"\\\1", unescaped_text)
def unescape_latex(escaped_text):
"""
Chatbotのmarkdownで数式が表示されるようにエスケープされていた \\(, \\), \\[, \\] をエスケープされていない元の括弧に変換する。
Args:
escaped_text (str): エスケープされた文字列
Returns:
str: エスケープされていない文字列
"""
return re.sub(r"\\(\\[\(\)\[\]])", r"\1", escaped_text)
def add_s(values):
"""
複数形のsを必要に応じて付けるために用いる関数。
与えられたリストの要素数が2以上なら"s"を返し、それ以外は""を返す。
Args:
values (list[any]): リスト
Returns:
str: 要素数が複数なら"s"、それ以外は""
"""
return "s" if len(values) > 1 else ""
def get_context_info(characters, tokens):
"""
文字数とトークン数の情報を文字列で返す。
Args:
characters (str): テキスト
tokens (list[str]): トークン
Returns:
str: 文字数とトークン数の情報を含む文字列
"""
char_count = len(characters)
token_count = len(tokens)
return f"{char_count:,} character{add_s(characters)}\n{token_count:,} token{add_s(tokens)}"
def update_context_element(pdf_file_obj):
"""
PDFファイルからテキストを抽出し、コンテキスト要素を更新する。
Args:
pdf_file_obj (File): アップロードされたPDFファイルオブジェクト
Returns:
Tuple: コンテキストテキストボックスに格納する抽出されたテキスト情報と、その文字数情報
"""
pages = extract_pdf_pages(pdf_file_obj.name)
document_with_tag = merge_pages_with_page_tag(pages)
return gr.update(value=document_with_tag, interactive=True), count_characters(document_with_tag)
def count_characters(document_with_tag):
"""
テキストの文字数とトークン数を計算する。
ただし、テキストはchatpdf:pageというタグでページが括られているとする。
Args:
document_with_tag (str): 文字数とトークン数を計算するテキスト
Returns:
str: 文字数とトークン数の情報を含む文字列
"""
text = "".join([page.content for page in extract_pages_from_page_tag(document_with_tag)])
tokens = OPENAI_TOKENIZER.encode(text)
return get_context_info(text, tokens)
class SearchEngine:
"""
検索エンジン
"""
def __init__(self, engine, pages):
self.engine = engine
self.pages = pages
SEARCH_ENGINE = None
def create_search_engine(context):
"""
検索エンジンを作る。
Args:
context (str): 検索対象となるテキスト。ただし、テキストはchatpdf:pageというタグでページが括られているとする。
"""
global SEARCH_ENGINE
pages = extract_pages_from_page_tag(context)
tokenized_pages = []
original_pages = []
for page in pages:
page_content = page.content.strip()
if page_content:
tokenized_page = [token.base_form for token in JANOME_ANALYZER.analyze(page_content)]
if tokenized_page:
tokenized_pages.append(tokenized_page)
original_pages.append(page)
if tokenized_pages:
bm25 = BM25Okapi(tokenized_pages)
SEARCH_ENGINE = SearchEngine(engine=bm25, pages=original_pages)
else:
SEARCH_ENGINE = None
def search_pages(keywords, page_limit):
"""
与えられたキーワードを含むページを検索する。
Args:
keywords (str): 検索キーワード
page_limit (int): 検索するページ数
Returns:
list[Page]: ヒットしたページ
"""
global SEARCH_ENGINE
if SEARCH_ENGINE is None:
return []
tokenized_query = [token.base_form for token in JANOME_ANALYZER.analyze(keywords)]
if not tokenized_query:
return []
found_pages = SEARCH_ENGINE.engine.get_top_n(tokenized_query, SEARCH_ENGINE.pages, n=page_limit)
return found_pages
def load_pages(page_numbers):
"""
与えられたページ番号のページを取得する。
Args:
page_numbers (list[int]): 取得するページ番号
Returns:
list[Page]: 取得したページ
"""
global SEARCH_ENGINE
if SEARCH_ENGINE is None:
return []
page_numbers = set(page_numbers)
found_pages = [page for page in SEARCH_ENGINE.pages if page.number in page_numbers]
return found_pages
# function calling用ツール
CHAT_TOOLS = [
# ページ検索
{
"type": "function",
"function": {
"name": "search_pages",
"description": "Searches for pages containing the given keywords.",
"parameters": {
"type": "object",
"properties": {
"keywords": {
"type": "string",
"description": 'Search keywords separated by spaces. For example, "Artificial General Intelligence 自律エージェント".'
},
"page_limit": {
"type": "number",
"description": "Maximum number of search results to return. For example, 3.",
"minimum": 1
}
}
},
"required": ["keywords"]
}
},
# ページ取得
{
"type": "function",
"function": {
"name": "load_pages",
"description": "Loads pages specified by their page numbers.",
"parameters": {
"type": "object",
"properties": {
"page_numbers": {
"type": "array",
"items": {
"type": "number"
},
"description": "List of page numbers to be load",
"minItems": 1
}
}
},
"required": ["page_numbers"]
}
}
]
# function callingなど、固定で消費するトークン数
CHAT_TOOLS_TOKENS = 139
def get_openai_messages(prompt, history, context):
"""
与えられた対話用データを、ChatGPT APIの入力に用いられるメッセージデータ形式に変換して返す。
Args:
prompt (str): ユーザーからの入力プロンプト
history (list[list[str]]): チャット履歴
context (str): チャットコンテキスト
Returns:
str: ChatGPT APIの入力に用いられるメッセージデータ
"""
global SEARCH_ENGINE
if SEARCH_ENGINE is not None:
context = "".join([page.content for page in SEARCH_ENGINE.pages])
messages = []
for user_message, assistant_message in history:
if user_message is not None and assistant_message is not None:
user_message = unescape_latex(user_message)
user_message = user_message.replace("{context}", context)
assistant_message = unescape_latex(assistant_message)
messages.append({ "role": "user", "content": user_message })
messages.append({ "role": "assistant", "content": assistant_message })
prompt = prompt.replace("{context}", context)
messages.append({ "role": "user", "content": prompt })
return messages
# それまでの全入力トークン数
actual_total_cost_prompt = 0
# それまでの全出力トークン数
actual_total_cost_completion = 0
async def process_prompt(prompt, history, context, platform, endpoint, azure_deployment, azure_api_version, api_key, model_name, max_tokens, temperature, enable_rag):
"""
ユーザーのプロンプトを処理し、ChatGPTによる生成結果を返す。
Args:
prompt (str): ユーザーからの入力プロンプト
history (list[list[str]]): チャット履歴
context (str): チャットコンテキスト
platform (str): 使用するAIプラットフォーム
endpoint (str): AIサービスのエンドポイント
azure_deployment (str): Azureのデプロイメント名
azure_api_version (str): Azure APIのバージョン
api_key (str): APIキー
model_name (str): 使用するAIモデルの名前
max_tokens (int): 生成する最大トークン数
temperature (float): クリエイティビティの度合いを示す温度パラメータ
enable_rag (bool): RAG機能を有効にするかどうか
Returns:
str: ChatGPTによる生成結果
"""
global actual_total_cost_prompt, actual_total_cost_completion
try:
messages = get_openai_messages(prompt, history, context)
if platform == "OpenAI":
openai_client = OpenAI(
base_url=endpoint,
api_key=api_key,
http_client=http_client
)
else: # Azure
openai_client = AzureOpenAI(
azure_endpoint=endpoint,
api_version=azure_api_version,
azure_deployment=azure_deployment,
api_key=api_key,
http_client=http_client
)
if enable_rag:
completion = openai_client.chat.completions.create(
messages=messages,
model=model_name,
max_tokens=max_tokens,
temperature=temperature,
tools=CHAT_TOOLS,
tool_choice="auto",
stream=False
)
else:
completion = openai_client.chat.completions.create(
messages=messages,
model=model_name,
max_tokens=max_tokens,
temperature=temperature,
stream=False
)
bot_response = ""
if hasattr(completion, "error"):
raise gr.Error(completion.error["message"])
response_message = completion.choices[0].message
tool_calls = response_message.tool_calls
actual_total_cost_prompt += completion.usage.prompt_tokens
actual_total_cost_completion += completion.usage.completion_tokens
if tool_calls:
messages.append(response_message)
for tool_call in tool_calls:
function_name = tool_call.function.name
function_args = json.loads(tool_call.function.arguments)
if function_name == "search_pages":
# ページ検索
keywords = function_args.get("keywords").strip()
page_limit = function_args.get("page_limit") or 3
bot_response += f'Searching for pages containing the keyword{add_s(keywords.split(" "))} "{keywords}".\n'
found_pages = search_pages(keywords, page_limit)
function_response = json.dumps({
"status": "found" if found_pages else "not found",
"found_pages": [{
"page_number": page.number,
"page_content": page.content
} for page in found_pages]
}, ensure_ascii=False)
messages.append({
"tool_call_id": tool_call.id,
"role": "tool",
"name": function_name,
"content": function_response
})
if found_pages:
bot_response += f'Found page{add_s(found_pages)}: {", ".join([str(page.number) for page in found_pages])}.\n\n'
else:
bot_response += "Page not found.\n\n"
elif function_name == "load_pages":
# ページ取得
page_numbers = function_args.get("page_numbers")
bot_response += f'Trying to load page{add_s(page_numbers)} {", ".join(map(str, page_numbers))}.\n'
found_pages = load_pages(page_numbers)
function_response = json.dumps({
"status": "found" if found_pages else "not found",
"found_pages": [{
"page_number": page.number,
"page_content": page.content
} for page in found_pages]
}, ensure_ascii=False)
messages.append({
"tool_call_id": tool_call.id,
"role": "tool",
"name": function_name,
"content": function_response
})
if found_pages:
bot_response += f'Found page{add_s(found_pages)}: {", ".join([str(page.number) for page in found_pages])}.\n\n'
else:
bot_response += "Page not found.\n\n"
else:
raise gr.Error(f"Unknown function calling '{function_name}'.")
yield bot_response + "Generating response. Please wait a moment...\n"
await asyncio.sleep(0.1)
completion = openai_client.chat.completions.create(
messages=messages,
model=model_name,
max_tokens=max_tokens,
temperature=temperature,
stream=False
)
actual_total_cost_prompt += completion.usage.prompt_tokens
actual_total_cost_completion += completion.usage.completion_tokens
if hasattr(completion, "error"):
raise gr.Error(completion.error["message"])
response_message = completion.choices[0].message
bot_response += response_message.content
yield bot_response
else:
bot_response += response_message.content
yield bot_response
except Exception as e:
if hasattr(e, "message"):
raise gr.Error(e.message)
else:
raise gr.Error(str(e))
def load_api_key(file_obj):
"""
APIキーファイルからAPIキーを読み込む。
Args:
file_obj (File): APIキーファイルオブジェクト
Returns:
str: 読み込まれたAPIキー文字列
"""
try:
with open(file_obj.name, "r", encoding="utf-8") as api_key_file:
return api_key_file.read().strip()
except Exception as e:
raise gr.Error(str(e))
def get_cost_info(prompt_token_count):
"""
チャットのトークン数情報を表示するための文字列を返す。
Args:
prompt_token_count (int): プロンプト(履歴込み)のトークン数
Returns:
str: チャットのトークン数情報を表示するための文字列
"""
return f"Estimated input cost: {prompt_token_count + CHAT_TOOLS_TOKENS:,} tokens, Actual total input cost: {actual_total_cost_prompt:,} tokens, Actual total output cost: {actual_total_cost_completion:,} tokens"
# デフォルト設定値
DEFAULT_SETTINGS = {
"setting_name": "Default",
"platform": "OpenAI",
"endpoint": "https://api.openai.com/v1",
"azure_deployment": "",
"azure_api_version": "",
"model_name": "gpt-4-turbo-preview",
"max_tokens": 4096,
"temperature": 0.2,
"enable_rag": True,
"save_chat_history_to_url": False
};
def main():
"""
アプリケーションのメイン関数。Gradioインターフェースを設定し、アプリケーションを起動する。
"""
try:
# クエリパラメータに保存されていることもあるチャット履歴を読み出す。
with open("chat_history.json", "r", encoding="utf-8") as f:
CHAT_HISTORY = json.load(f)
except Exception as e:
print(e)
CHAT_HISTORY = []
# localStorageから設定情報ををロードする。
js_define_utilities_and_load_settings = """() =&gt; {
const KEY_PREFIX = "serverless_chat_with_your_pdf:";
const loadSettings = () =&gt; {
const getItem = (key, defaultValue) =&gt; {
const jsonValue = localStorage.getItem(KEY_PREFIX + key);
if (jsonValue) {
return JSON.parse(jsonValue);
} else {
return defaultValue;
}
};
""" + "".join([f"""
const default_{setting_key} = {json.dumps(default_value, ensure_ascii=False)};
const {setting_key} = getItem("{setting_key}", default_{setting_key});
""" for setting_key, default_value in DEFAULT_SETTINGS.items()]) + """
const serialized_saved_settings = getItem("saved_settings", []);
const default_saved_settings = [[
""" + ", ".join([f"{json.dumps(default_value, ensure_ascii=False)}" for _, default_value in DEFAULT_SETTINGS.items()]) + """
]];
saved_settings = [];
for (let entry of serialized_saved_settings) {
saved_settings.push([
entry["setting_name"] || "",
entry["platform"] || default_platform,
entry["endpoint"] || default_endpoint,
entry["azure_deployment"] || default_azure_deployment,
entry["azure_api_version"] || default_azure_api_version,
entry["model_name"] || default_model_name,
entry["max_tokens"] || default_max_tokens,
entry["temperature"] || default_temperature,
entry["enable_rag"] || default_enable_rag,
entry["save_chat_history_to_url"] || default_save_chat_history_to_url
]);
}
if (saved_settings.length == 0) {
saved_settings = default_saved_settings;
}
return [setting_name, platform, endpoint, azure_deployment, azure_api_version, model_name, max_tokens, temperature, enable_rag, save_chat_history_to_url, saved_settings];
};
globalThis.resetSettings = () =&gt; {
for (let key in localStorage) {
if (key.startsWith(KEY_PREFIX) && !key.startsWith(KEY_PREFIX + "saved_settings") && !key.startsWith(KEY_PREFIX + "setting_name")) {
localStorage.removeItem(key);
}
}
return loadSettings();
};
globalThis.saveItem = (key, value) =&gt; {
localStorage.setItem(KEY_PREFIX + key, JSON.stringify(value));
};
return loadSettings();
}
"""
# should_saveがtrueであればURLにチャット履歴を保存し、falseであればチャット履歴を削除する。
save_or_delete_chat_history = '''(hist, should_save) =&gt; {
saveItem("save_chat_history_to_url", should_save);
if (!should_save) {
const url = new URL(window.location.href);
url.searchParams.delete("history");
window.history.replaceState({path:url.href}, '', url.href);
} else {
const compressedHistory = LZString.compressToEncodedURIComponent(JSON.stringify(hist));
const url = new URL(window.location.href);
url.searchParams.set("history", compressedHistory);
window.history.replaceState({path:url.href}, '', url.href);
}
}'''
# メッセージ例
examples = {
"要約 (論文)": '''制約条件に従い、以下の研究論文で提案されている技術や手法について要約してください。
# 制約条件
* 要約者: 大学教授
* 想定読者: 大学院生
* 要約結果の言語: 日本語
* 要約結果の構成(以下の各項目について500文字):
1. どんな研究であるか
2. 先行研究に比べて優れている点は何か
3. 提案されている技術や手法の重要な点は何か
4. どのような方法で有効であると評価したか
5. 何か議論はあるか
6. 次に読むべき論文は何か
# 研究論文
"""
{context}
"""
# 要約結果''',
"要約 (一般)": '''制約条件に従い、以下の文書の内容を要約してください。
# 制約条件
* 要約者: 技術コンサルタント
* 想定読者: 経営層、CTO、CIO
* 形式: 箇条書き
* 分量: 20項目
* 要約結果の言語: 日本語
# 文書
"""
{context}
"""
# 要約''',
"情報抽出": '''制約条件に従い、以下の文書から情報を抽出してください。
# 制約条件
* 抽出する情報: 課題や問題点について言及している全ての文。一つも見落とさないでください。
* 出力形式: 箇条書き
* 出力言語: 元の言語の文章と、その日本語訳
# 文書
"""
{context}
"""
# 抽出結果''',
"QA (日本語文書RAG)": '''次の質問に回答するために役立つページを検索して、その検索結果を使って回答して下さい。
# 制約条件
* 検索クエリの生成方法: 質問文の3つの言い換え(paraphrase)をカンマ区切りで連結した文字列
* 検索クエリの言語: 日本語
* 検索するページ数: 3
* 回答方法:
- 検索結果の情報のみを用いて回答すること。
- 回答に利用した文章のあるページ番号を、それぞれの回答文の文末に付与すること。形式: "(参考ページ番号: 71, 59, 47)"
- 回答に役立つ情報が検索結果内にない場合は「検索結果には回答に役立つ情報がありませんでした。」と回答すること。
* 回答の言語: 日本語
# 質問
どのような方法で、提案された手法が有効であると評価しましたか?
# 回答''',
"QA (英語文書RAG)": '''次の質問に回答するために役立つページを検索して、その検索結果を使って回答して下さい。
# 制約条件
* 検索クエリの生成方法: 質問文の3つの言い換え(paraphrase)をカンマ区切りで連結した文字列
* 検索クエリの言語: 英語
* 検索するページ数: 3
* 回答方法:
- 検索結果の情報のみを用いて回答すること。
- 回答に利用した文章のあるページ番号を、それぞれの回答文の文末に付与すること。形式: "(参考ページ番号: 71, 59, 47)"
- 回答に役立つ情報が検索結果内にない場合は「検索結果には回答に役立つ情報がありませんでした。」と回答すること。
* 回答の言語: 日本語
# 質問
どのような方法で、提案された手法が有効であると評価しましたか?
# 回答''',
"要約 (RAG)": '''次のキーワードを含むページを検索して、その検索結果をページごとに要約して下さい。
# 制約条件
* キーワード: dataset datasets
* 検索するページ数: 3
* 要約結果の言語: 日本語
* 要約の形式:
## ページ番号(例: 12ページ)
- 要約文1
- 要約文2
...
* 要約の分量: 各ページ3項目
# 要約''',
"翻訳 (RAG)": '''次のキーワードを含むページを検索して、その検索結果を日本語に翻訳して下さい。
# 制約条件
* キーワード: dataset datasets
* 検索するページ数: 1
# 翻訳結果''',
"要約 (ページ指定)": '''16〜17ページをページごとに箇条書きで要約して下さい。
# 制約条件
* 要約結果の言語: 日本語
* 要約の形式:
## ページ番号(例: 12ページ)
- 要約文1
- 要約文2
...
* 要約の分量: 各ページ5項目
# 要約''',
"続きを生成": "続きを生成してください。"
}
with gr.Blocks(theme=gr.themes.Default(), analytics_enabled=False) as app:
with gr.Tabs():
with gr.TabItem("Settings"):
with gr.Column():
with gr.Column(variant="panel"):
with gr.Row():
setting_name = gr.Textbox(label="Setting Name", value="Default", interactive=True)
setting_name.change(None, inputs=setting_name, outputs=None,
js='(x) =&gt; saveItem("setting_name", x)', show_progress="hidden")
with gr.Row():
platform = gr.Radio(label="Platform", interactive=True,
choices=["OpenAI", "Azure"], value="OpenAI")
platform.change(None, inputs=platform, outputs=None,
js='(x) =&gt; saveItem("platform", x)', show_progress="hidden")
with gr.Row():
endpoint = gr.Textbox(label="Endpoint", interactive=True)
endpoint.change(None, inputs=endpoint, outputs=None,
js='(x) =&gt; saveItem("endpoint", x)', show_progress="hidden")
azure_deployment = gr.Textbox(label="Azure Deployment", interactive=True)
azure_deployment.change(None, inputs=azure_deployment, outputs=None,
js='(x) =&gt; saveItem("azure_deployment", x)', show_progress="hidden")
azure_api_version = gr.Textbox(label="Azure API Version", interactive=True)
azure_api_version.change(None, inputs=azure_api_version, outputs=None,
js='(x) =&gt; saveItem("azure_api_version", x)', show_progress="hidden")
with gr.Group():
with gr.Row():
api_key_file = gr.File(file_count="single", file_types=["text"],
height=80, label="API Key File")
api_key = gr.Textbox(label="API Key", type="password", interactive=True)
# 注意: 秘密情報をlocalStorageに保存してはならない。他者に秘密情報が盗まれる危険性があるからである。
api_key_file.upload(load_api_key, inputs=api_key_file, outputs=api_key,
show_progress="hidden")
api_key_file.clear(lambda: None, inputs=None, outputs=api_key, show_progress="hidden")
with gr.Row():
model_name = gr.Textbox(label="Model", interactive=True)
model_name.change(None, inputs=model_name, outputs=None,
js='(x) =&gt; saveItem("model_name", x)', show_progress="hidden")
max_tokens = gr.Number(label="Max Tokens", interactive=True,
minimum=0, precision=0, step=1)
max_tokens.change(None, inputs=max_tokens, outputs=None,
js='(x) =&gt; saveItem("max_tokens", x)', show_progress="hidden")
temperature = gr.Slider(label="Temperature", interactive=True,
minimum=0.0, maximum=1.0, step=0.1)
temperature.change(None, inputs=temperature, outputs=None,
js='(x) =&gt; saveItem("temperature", x)', show_progress="hidden")
enable_rag = gr.Checkbox(label="Enable RAG (Retrieval Augmented Generation)", interactive=True)
enable_rag.change(None, inputs=enable_rag, outputs=None,
js='(x) =&gt; saveItem("enable_rag", x)', show_progress="hidden")
save_chat_history_to_url = gr.Checkbox(label="Save Chat History to URL", interactive=True)
reset_button = gr.Button("Reset Settings")
with gr.Column(variant="panel"):
default_saved_settings = list(DEFAULT_SETTINGS.values())
saved_settings_df = gr.Dataframe(
elem_id="saved_settings",
value=[default_saved_settings],
headers=["Name", "Platform", "Endpoint", "Azure Deployment", "Azure API Version", "Model", "Max Tokens", "Temperature", "Enable RAG", "Save Chat History to URL"],
row_count=(0, "dynamic"),
col_count=(10, "fixed"),
datatype=["str", "str", "str", "str", "str", "str", "number", "number", "bool", "bool"],
type="array",
label="Saved Settings",
show_label=True,
interactive=False
)
selected_setting = gr.State(None)
temp_selected_row_index = gr.JSON(value=None, visible=False)
def select_setting(event: gr.SelectData):
return (event.index[0], event.index[1]), event.index[0]
saved_settings_df.select(
select_setting, inputs=None, outputs=[selected_setting, temp_selected_row_index], queue=False, show_progress="hidden"
).then(
None, inputs=temp_selected_row_index, outputs=None, js='(row_index) =&gt; { for (let e of document.querySelectorAll("#saved_settings > div > div > button > svelte-virtual-table-viewport > table > tbody > tr")[row_index].children) { e.classList.add("focus"); } }', queue=False, show_progress="hidden"
)
with gr.Row():
load_saved_settings_button = gr.Button("Load")
append_or_overwrite_saved_settings_button = gr.Button("Append or Overwrite")
delete_saved_settings_button = gr.Button("Delete")
serialized_saved_settings_state = gr.JSON(visible=False)
def load_saved_setting(saved_settings, selected_setting):
if not selected_setting:
return saved_settings
def u(x):
return gr.update(value=x, interactive=True)
row_index = selected_setting[0]
setting_name, platform, endpoint, azure_deployment, azure_api_version, model_name, max_tokens, temperature, enable_rag, save_chat_history_to_url = saved_settings[row_index]
return u(setting_name), u(platform), u(endpoint), u(azure_deployment), u(azure_api_version), u(model_name), u(max_tokens), u(temperature), u(enable_rag), u(save_chat_history_to_url), None
load_saved_settings_button.click(load_saved_setting, inputs=[saved_settings_df, selected_setting], outputs=[setting_name, platform, endpoint, azure_deployment, azure_api_version, model_name, max_tokens, temperature, enable_rag, save_chat_history_to_url, selected_setting], queue=False, show_progress="hidden")
def append_or_overwrite_setting(saved_settings, setting_name, platform, endpoint, azure_deployment, azure_api_version, model_name, max_tokens, temperature, enable_rag, save_chat_history_to_url):
setting_name = setting_name.strip()
found = False
new_saved_settings = []
for entry in saved_settings:
if entry[0] == setting_name:
new_saved_settings.append([setting_name, platform, endpoint, azure_deployment, azure_api_version,model_name, max_tokens, temperature, enable_rag, save_chat_history_to_url])
found = True
else:
new_saved_settings.append(entry)
if not found:
new_saved_settings.append([setting_name, platform, endpoint, azure_deployment, azure_api_version,model_name, max_tokens, temperature, enable_rag, save_chat_history_to_url])
return new_saved_settings, None
def serialize_saved_settings(saved_settings):
serialization_keys = list(DEFAULT_SETTINGS.keys())
serialized_saved_settings = [
{ k: entry[i] for i, k in enumerate(serialization_keys) }
for entry in saved_settings
]
return serialized_saved_settings
append_or_overwrite_saved_settings_button.click(
append_or_overwrite_setting, inputs=[saved_settings_df, setting_name, platform, endpoint, azure_deployment, azure_api_version,model_name, max_tokens, temperature, enable_rag, save_chat_history_to_url], outputs=[saved_settings_df, selected_setting], queue=False, show_progress="hidden"
).then(
serialize_saved_settings, inputs=saved_settings_df, outputs=serialized_saved_settings_state, queue=False, show_progress="hidden",
).then(
None, inputs=serialized_saved_settings_state, outputs=None, js='(x) =&gt; saveItem("saved_settings", x)', queue=False, show_progress="hidden"
)
def delete_setting(saved_settings, selected_setting):
if not selected_setting:
return saved_settings
row_index = selected_setting[0]
new_saved_settings = saved_settings[0:row_index] + saved_settings[row_index + 1:]
if not new_saved_settings:
new_saved_settings.append(default_saved_settings)
return new_saved_settings, None
delete_saved_settings_button.click(
delete_setting, inputs=[saved_settings_df, selected_setting], outputs=[saved_settings_df, selected_setting], queue=False, show_progress="hidden"
).then(
serialize_saved_settings, inputs=saved_settings_df, outputs=serialized_saved_settings_state, queue=False, show_progress="hidden",
).then(
None, inputs=serialized_saved_settings_state, outputs=None, js='(x) =&gt; saveItem("saved_settings", x)', queue=False, show_progress="hidden"
)
temp_saved_settings = gr.JSON(visible=False)
temp_saved_settings.change(lambda x: x, inputs=temp_saved_settings, outputs=saved_settings_df, queue=False, show_progress="hidden")
setting_items = [setting_name, platform, endpoint, azure_deployment, azure_api_version, model_name, max_tokens, temperature, enable_rag, save_chat_history_to_url, temp_saved_settings]
reset_button.click(None, inputs=None, outputs=setting_items,
js="() =&gt; resetSettings()", show_progress="hidden")
with gr.TabItem("Chat"):
with gr.Row():
with gr.Column(scale=1):
pdf_file = gr.File(file_count="single", file_types=[".pdf"],
height=80, label="PDF")
context = gr.Textbox(elem_id="context", label="Context", lines=20,
interactive=True, autoscroll=False, show_copy_button=True)
char_counter = gr.Textbox(label="Statistics", value=get_context_info("", []),
lines=2, max_lines=2, interactive=False, container=True)
pdf_file.upload(update_context_element, inputs=pdf_file, outputs=[context, char_counter], queue=False)
pdf_file.clear(lambda: None, inputs=None, outputs=context, queue=False, show_progress="hidden")
with gr.Column(scale=2):
additional_inputs = [context, platform, endpoint, azure_deployment, azure_api_version, api_key, model_name, max_tokens, temperature, enable_rag]
with gr.Blocks() as chat:
gr.Markdown(f"# Chat with your PDF")
with gr.Column(variant="panel"):
chatbot = gr.Chatbot(
CHAT_HISTORY,
elem_id="chatbot", height=500, show_copy_button=True,
sanitize_html=True, render_markdown=True,
latex_delimiters=[
# { "left": "$$", "right": "$$", "display": True },
# { "left": "$", "right": "$", "display": False },
{ "left": "\\(", "right": "\\)", "display": False },
{ "left": "\\[", "right": "\\]", "display": True },
],
likeable=False, layout="bubble",
avatar_images=[None, "https://raw.githubusercontent.com/sonoisa/misc/main/resources/icons/chatbot_icon.png"]
)
message_state = gr.State()
chatbot_state = gr.State(chatbot.value) if chatbot.value else gr.State([])
with gr.Group():
with gr.Row():
message_textbox = gr.Textbox(placeholder="Type a message...",
container=False, show_label=False, autofocus=True, interactive=True, scale=7)
submit_button = gr.Button("Submit", variant="primary", scale=1, min_width=150)
stop_button = gr.Button("Stop", variant="stop", visible=False, scale=1, min_width=150)
cost_info = gr.Textbox(elem_id="cost_info", value=get_cost_info(0),
lines=1, max_lines=1, interactive=False, container=False, elem_classes="cost_info")
with gr.Row():
retry_button = gr.Button("🔄 Retry", variant="secondary", size="sm")
undo_button = gr.Button("↩️ Undo", variant="secondary", size="sm")
clear_button = gr.Button("🗑️ Clear", variant="secondary", size="sm")
def estimate_message_cost(prompt, history, context):
token_count = 0
messages = get_openai_messages(prompt, history, context)
for message in messages:
tokens = OPENAI_TOKENIZER.encode(message["content"])
token_count += len(tokens)
return gr.update(value=get_cost_info(token_count))
message_textbox.change(estimate_message_cost, inputs=[message_textbox, chatbot, context], outputs=cost_info, queue=False, show_progress="hidden")
example_title_textbox = gr.Textbox(visible=False, interactive=True)
gr.Examples([[k] for k, v in examples.items()],
inputs=example_title_textbox, outputs=message_textbox,
fn=lambda title: examples[title], run_on_click=True)
def append_message_to_history(message, history):
message = escape_latex(message)
history.append([message, None])
return history, history
def undo_chat(history):
if history:
message, _ = history.pop()
message = message or ""
else:
message = ""
return history, history, unescape_latex(message)
async def submit_message(message, history_with_input, *args):
history = history_with_input[:-1]
inputs = [message, history]
inputs.extend(args)
generator = process_prompt(*inputs)
message = escape_latex(message)
has_response = False
async for response in generator:
has_response = True
response = escape_latex(response)
update = history + [[message, response]]
yield update, update
if not has_response:
update = history + [[message, None]]
yield update, update
submit_triggers = [message_textbox.submit, submit_button.click]
submit_event = gr.events.on(
submit_triggers, lambda message: ("", message), inputs=[message_textbox], outputs=[message_textbox, message_state], queue=False
).then(
append_message_to_history, inputs=[message_state, chatbot_state], outputs=[chatbot, chatbot_state], queue=False
).then(
submit_message, inputs=[message_state, chatbot_state] + additional_inputs, outputs=[chatbot, chatbot_state]
).then(
estimate_message_cost, inputs=[message_textbox, chatbot, context], outputs=cost_info, show_progress="hidden"
)
for submit_trigger in submit_triggers:
submit_trigger(lambda: (gr.update(visible=False), gr.update(visible=True)),
inputs=None, outputs=[submit_button, stop_button], queue=False)
submit_event.then(lambda: (gr.update(visible=True), gr.update(visible=False)),
inputs=None, outputs=[submit_button, stop_button], queue=False)
stop_button.click(None, inputs=None, outputs=None, cancels=submit_event)
retry_button.click(
undo_chat, inputs=[chatbot_state], outputs=[chatbot, chatbot_state, message_state], queue=False
).then(
append_message_to_history, inputs=[message_state, chatbot_state], outputs=[chatbot, chatbot_state], queue=False
).then(
submit_message, inputs=[message_state, chatbot_state] + additional_inputs, outputs=[chatbot, chatbot_state]
).then(
estimate_message_cost, inputs=[message_textbox, chatbot, context], outputs=cost_info, show_progress="hidden"
)
undo_button.click(
undo_chat, inputs=[chatbot_state], outputs=[chatbot, chatbot_state, message_state], queue=False
).then(
lambda message: message, inputs=message_state, outputs=message_textbox, queue=False
).then(
estimate_message_cost, inputs=[message_textbox, chatbot, context], outputs=cost_info, show_progress="hidden"
)
clear_button.click(
lambda: ([], [], None), inputs=None, outputs=[chatbot, chatbot_state, message_state],
queue=False
).then(
estimate_message_cost, inputs=[message_textbox, chatbot, context], outputs=cost_info, show_progress="hidden"
)
chatbot.change(None, inputs=[chatbot, save_chat_history_to_url], outputs=None,
# チャット履歴をクエリパラメータに保存する。
js=save_or_delete_chat_history, queue=False, show_progress="hidden")
save_chat_history_to_url.change(None, inputs=[chatbot, save_chat_history_to_url], outputs=None,
js=save_or_delete_chat_history, queue=False, show_progress="hidden")
context.change(
count_characters, inputs=context, outputs=char_counter, queue=False, show_progress="hidden"
).then(
create_search_engine, inputs=context, outputs=None
).then(
estimate_message_cost, inputs=[message_textbox, chatbot, context], outputs=cost_info, show_progress="hidden"
)
app.load(None, inputs=None, outputs=setting_items, js=js_define_utilities_and_load_settings, show_progress="hidden")
app.queue().launch()
main()
</gradio-file>
</gradio-lite>
<script language="javascript" src="https://cdn.jsdelivr.net/npm/lz-string@1.5.0/libs/lz-string.min.js"></script>
<script language="javascript">
(function () {
// クエリパラメータにチャット履歴が記録されていたらそれをロードし、chat_history.jsonファイルに書き出す。
const url = new URL(window.location.href);
if (url.searchParams.has("history")) {
const compressedHistory = url.searchParams.get("history");
hist = LZString.decompressFromEncodedURIComponent(compressedHistory);
const chat_history_element = document.querySelector('gradio-file[name="chat_history.json"]');
chat_history_element.textContent = hist;
}
})();
</script>
<script type="module" crossorigin src="https://cdn.jsdelivr.net/npm/@gradio/lite@4.29.0/dist/lite.js"></script>
</body>
</html>