|
<!DOCTYPE html> |
|
|
|
|
|
|
|
|
|
|
|
<html> |
|
<head> |
|
<meta charset="UTF-8"> |
|
<meta name="viewport" content="width=device-width, initial-scale=1"> |
|
<title>Chat with your PDF</title> |
|
<meta name="description" content="Chat with your PDF"> |
|
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/@gradio/lite@4.29.0/dist/lite.css" /> |
|
<style> |
|
html, body { |
|
margin: 0; |
|
padding: 0; |
|
height: 100%; |
|
background: var(--body-background-fill); |
|
} |
|
|
|
footer { |
|
display: none !important; |
|
} |
|
|
|
#chatbot { |
|
height: auto !important; |
|
min-height: 500px; |
|
} |
|
|
|
#chatbot h1 { |
|
font-size: 2em; |
|
margin-block-start: 0.67em; |
|
margin-block-end: 0em; |
|
margin-inline-start: 0px; |
|
margin-inline-end: 0px; |
|
font-weight: bold; |
|
} |
|
|
|
#chatbot h2 { |
|
font-size: 1.5em; |
|
margin-block-start: 0.83em; |
|
margin-block-end: 0em; |
|
margin-inline-start: 0px; |
|
margin-inline-end: 0px; |
|
font-weight: bold; |
|
} |
|
|
|
#chatbot h3 { |
|
font-size: 1.17em; |
|
margin-block-start: 1em; |
|
margin-block-end: 0em; |
|
margin-inline-start: 0px; |
|
margin-inline-end: 0px; |
|
font-weight: bold; |
|
} |
|
|
|
#chatbot h4 { |
|
margin-block-start: 1.33em; |
|
margin-block-end: 0em; |
|
margin-inline-start: 0px; |
|
margin-inline-end: 0px; |
|
font-weight: bold; |
|
} |
|
|
|
#chatbot h5 { |
|
margin-block-start: 1.67em; |
|
margin-block-end: 0em; |
|
margin-inline-start: 0px; |
|
margin-inline-end: 0px; |
|
font-weight: bold; |
|
} |
|
|
|
#chatbot h6 { |
|
margin-block-start: 1.83em; |
|
margin-block-end: 0em; |
|
margin-inline-start: 0px; |
|
margin-inline-end: 0px; |
|
font-weight: bold; |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.gallery-item > .gallery { |
|
max-width: 380px; |
|
} |
|
|
|
#context > label > textarea { |
|
scrollbar-width: thin !important; |
|
} |
|
|
|
#cost_info { |
|
border-style: none !important; |
|
} |
|
|
|
#cost_info > label > input { |
|
background: var(--panel-background-fill) !important; |
|
} |
|
</style> |
|
</head> |
|
<body> |
|
<gradio-lite> |
|
<gradio-requirements> |
|
pdfminer.six==20231228 |
|
pyodide-http==0.2.1 |
|
janome==0.5.0 |
|
rank_bm25==0.2.2 |
|
</gradio-requirements> |
|
|
|
<gradio-file name="chat_history.json"> |
|
[[null, "ようこそ! PDFのテキストを参照しながら対話できるチャットボットです。\nPDFファイルをアップロードするとテキストが抽出されます。\nメッセージの中に{context}と書くと、抽出されたテキストがその部分に埋め込まれて対話が行われます。他にもPDFのページを検索して参照したり、ページ番号を指定して参照したりすることができます。一番下のExamplesにこれらの例があります。\nメッセージを書くときにShift+Enterを入力すると改行できます。"]] |
|
</gradio-file> |
|
|
|
<gradio-file name="app.py" entrypoint> |
|
import os

# Disable Gradio analytics/telemetry. Both calls are used: os.putenv() changes the
# process environment without updating os.environ, and the os.environ assignment
# covers code that reads the mapping directly.
os.putenv("GRADIO_ANALYTICS_ENABLED", "False")
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"

# Pre-built Pyodide wheels; the installation approach for the openai library
# follows https://github.com/pyodide/pyodide/issues/4292
import micropip
await micropip.install("https://raw.githubusercontent.com/sonoisa/pyodide_wheels/main/multidict/multidict-4.7.6-py3-none-any.whl", keep_going=True)
await micropip.install("https://raw.githubusercontent.com/sonoisa/pyodide_wheels/main/frozenlist/frozenlist-1.4.0-py3-none-any.whl", keep_going=True)
await micropip.install("https://raw.githubusercontent.com/sonoisa/pyodide_wheels/main/aiohttp/aiohttp-4.0.0a2.dev0-py3-none-any.whl", keep_going=True)
await micropip.install("https://raw.githubusercontent.com/sonoisa/pyodide_wheels/main/openai/openai-1.3.7-py3-none-any.whl", keep_going=True)
await micropip.install("https://raw.githubusercontent.com/sonoisa/pyodide_wheels/main/urllib3/urllib3-2.1.0-py3-none-any.whl", keep_going=True)
await micropip.install("ssl")
import ssl
await micropip.install("httpx", keep_going=True)
import httpx
# NOTE(review): this urllib3 wheel was already installed above — the second
# install is redundant (micropip installs are idempotent, so it is harmless).
await micropip.install("https://raw.githubusercontent.com/sonoisa/pyodide_wheels/main/urllib3/urllib3-2.1.0-py3-none-any.whl", keep_going=True)
import urllib3
# Suppress warnings for unverified HTTPS requests made through urllib3.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
await micropip.install("https://raw.githubusercontent.com/sonoisa/pyodide_wheels/main/tiktoken/tiktoken-0.5.2-cp311-cp311-emscripten_3_1_46_wasm32.whl", keep_going=True)
|
|
|
|
|
import gradio as gr |
|
import base64 |
|
import json |
|
import unicodedata |
|
import re |
|
from pathlib import Path |
|
from dataclasses import dataclass |
|
import asyncio |
|
|
|
import pyodide_http |
|
pyodide_http.patch_all() |
|
|
|
from pdfminer.pdfinterp import PDFResourceManager |
|
from pdfminer.converter import TextConverter |
|
from pdfminer.pdfinterp import PDFPageInterpreter |
|
from pdfminer.pdfpage import PDFPage |
|
from pdfminer.layout import LAParams |
|
from io import StringIO |
|
|
|
from janome.tokenizer import Tokenizer as JanomeTokenizer |
|
from janome.analyzer import Analyzer as JanomeAnalyzer |
|
from janome.tokenfilter import POSStopFilter, LowerCaseFilter |
|
from rank_bm25 import BM25Okapi |
|
|
|
from openai import OpenAI, AzureOpenAI |
|
import tiktoken |
|
import requests |
|
|
|
|
|
class URLLib3Transport(httpx.BaseTransport):
    """
    Custom httpx transport that performs requests via urllib3
    (patched for the Pyodide runtime by pyodide-http).
    """
    def __init__(self):
        # One shared connection pool for every request sent through this transport.
        self.pool = urllib3.PoolManager()

    def handle_request(self, request: httpx.Request) -> httpx.Response:
        """
        Send the httpx request through urllib3 and wrap the result as an httpx.Response.

        The raw request body is forwarded unchanged. (The previous implementation
        decoded the body as JSON and let urllib3 re-encode it, which both did
        redundant work and raised on requests without a JSON body, e.g. GETs.)
        """
        urllib3_response = self.pool.request(
            request.method,
            str(request.url),
            headers=dict(request.headers),
            body=request.content or None,
        )
        stream = httpx.ByteStream(urllib3_response.data)
        return httpx.Response(urllib3_response.status, headers=urllib3_response.headers, stream=stream)
|
|
|
# Single shared httpx client used for all OpenAI/Azure API calls below.
http_client = httpx.Client(transport=URLLib3Transport())
|
|
|
|
|
@dataclass
class Page:
    """
    Content of a single PDF page.
    """
    # 1-based page number within the source PDF.
    number: int
    # Extracted text of the page (NFKC-normalized by extract_pdf_pages).
    content: str
|
|
|
|
|
def load_tiktoken_model(model_url):
    """
    Download a tiktoken BPE ranks file and return its raw bytes.

    Args:
        model_url (str): URL of the ranks file

    Returns:
        bytes: raw content of the downloaded file

    Raises:
        requests.HTTPError: if the download fails
    """
    response = requests.get(model_url)
    response.raise_for_status()
    return response.content
|
|
|
# Manually constructed cl100k_base encoding.
# NOTE(review): get_encoding (commented out below) would download the ranks file
# itself; here the ranks are fetched explicitly from a mirrored copy instead —
# presumably so it works inside this browser/Pyodide environment. Verify the
# mirror stays in sync with the official cl100k_base ranks.
# OPENAI_TOKENIZER = tiktoken.get_encoding("cl100k_base")
OPENAI_TOKENIZER = tiktoken.Encoding(
    name="cl100k_base",
    pat_str=r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+""",
    mergeable_ranks={
        # Each non-empty line of the ranks file is "<base64-token> <rank>".
        base64.b64decode(token): int(rank)
        for token, rank in (line.split() for line in load_tiktoken_model("https://raw.githubusercontent.com/sonoisa/pyodide_wheels/main/tiktoken/cl100k_base.tiktoken").splitlines() if line)
    },
    special_tokens={
        "<|endoftext|>": 100257,
        "<|fim_prefix|>": 100258,
        "<|fim_middle|>": 100259,
        "<|fim_suffix|>": 100260,
        "<|endofprompt|>": 100276,
    }
)

# Japanese morphological tokenizer used to tokenize pages and queries for BM25.
JANOME_TOKENIZER = JanomeTokenizer()
# Analyzer pipeline: drop symbol/whitespace tokens, lower-case the rest.
JANOME_ANALYZER = JanomeAnalyzer(tokenizer=JANOME_TOKENIZER,
                                 token_filters=[POSStopFilter(["記号,空白"]), LowerCaseFilter()])
|
|
|
|
|
def extract_pdf_pages(pdf_filename):
    """
    Extract text from a PDF file, page by page.

    Args:
        pdf_filename (str): Path of the PDF file to extract.

    Returns:
        list[Page]: Content of each PDF page. A page whose extraction fails is
        skipped (its error is printed) but still consumes a page number.
    """
    pages = []
    with open(pdf_filename, "rb") as pdf_file:
        output = StringIO()
        resource_manager = PDFResourceManager()
        laparams = LAParams()
        text_converter = TextConverter(resource_manager, output, laparams=laparams)
        page_interpreter = PDFPageInterpreter(resource_manager, text_converter)

        try:
            page_number = 0
            for i_page in PDFPage.get_pages(pdf_file):
                page_number += 1
                try:
                    page_interpreter.process_page(i_page)
                    page_content = output.getvalue()
                    # NFKC-normalize so full-width/half-width variants compare equal.
                    page_content = unicodedata.normalize('NFKC', page_content)
                    pages.append(Page(number=page_number, content=page_content))
                except Exception as e:
                    # Best effort: report and continue with the remaining pages.
                    print(e)
                finally:
                    # Always reset the shared buffer; previously it was reset only
                    # on success, so a failed page leaked its partial text into
                    # the next page's content.
                    output.truncate(0)
                    output.seek(0)
        finally:
            # Release pdfminer resources even if page iteration itself raises.
            text_converter.close()
            output.close()

    return pages
|
|
|
|
|
def merge_pages_with_page_tag(pages):
    """
    Merge the PDF page contents into one string, wrapping each page in a
    chatpdf:page tag. Inverse of extract_pages_from_page_tag().

    Args:
        pages (list[Page]): content of each PDF page

    Returns:
        str: all page contents merged into a single tagged string
    """
    # str.join instead of repeated "+=" — avoids quadratic time on many pages.
    return "".join(
        f'<chatpdf:page number="{page.number}">\n{page.content}\n</chatpdf:page>\n'
        for page in pages
    )
|
|
|
|
|
def extract_pages_from_page_tag(document_with_page_tag):
    """
    Parse regions wrapped in chatpdf:page tags back into Page objects.
    Inverse of merge_pages_with_page_tag().

    Args:
        document_with_page_tag (str): string whose pages are wrapped in
            chatpdf:page tags

    Returns:
        list[Page]: one Page per tagged region, in document order
    """
    page_tag_pattern = re.compile(
        r'<chatpdf:page number="(\d+)">\n?(.*?)\n?<\/chatpdf:page>\n?',
        re.DOTALL,
    )
    return [
        Page(number=int(match.group(1)), content=match.group(2))
        for match in page_tag_pattern.finditer(document_with_page_tag)
    ]
|
|
|
|
|
def escape_latex(unescaped_text):
    """
    Escape the math delimiters \\(, \\), \\[, \\] with an extra backslash so
    that the Chatbot markdown renderer displays formulas correctly.

    Args:
        unescaped_text (str): text to escape

    Returns:
        str: the escaped text
    """
    return re.sub(
        r"(\\[\(\)\[\]])",
        lambda match: "\\" + match.group(1),
        unescaped_text,
    )
|
|
|
|
|
def unescape_latex(escaped_text):
    """
    Inverse of escape_latex(): turn the escaped math delimiters \\\\(, \\\\),
    \\\\[, \\\\] back into their unescaped forms.

    Args:
        escaped_text (str): escaped text

    Returns:
        str: the unescaped text
    """
    return re.sub(
        r"\\(\\[\(\)\[\]])",
        lambda match: match.group(1),
        escaped_text,
    )
|
|
|
|
|
def add_s(values):
    """
    Helper for attaching the English plural "s" where needed.

    Args:
        values (list[any]): a sized collection

    Returns:
        str: "s" when the collection has two or more elements, "" otherwise
    """
    if len(values) > 1:
        return "s"
    return ""
|
|
|
|
|
def get_context_info(characters, tokens):
    """
    Build the character-count / token-count display string.

    Args:
        characters (str): the text itself
        tokens (list[str]): its token list

    Returns:
        str: two lines, e.g. "1,234 characters\n567 tokens"
    """
    # Pluralize each unit independently (inlined plural helper).
    char_suffix = "s" if len(characters) > 1 else ""
    token_suffix = "s" if len(tokens) > 1 else ""
    return (
        f"{len(characters):,} character{char_suffix}\n"
        f"{len(tokens):,} token{token_suffix}"
    )
|
|
|
|
|
def update_context_element(pdf_file_obj):
    """
    Extract text from the uploaded PDF and refresh the context widgets.

    Args:
        pdf_file_obj (File): uploaded PDF file object (.name is its path)

    Returns:
        Tuple: (update for the context textbox containing the page-tagged text,
        character/token count string for the info element)
    """
    pages = extract_pdf_pages(pdf_file_obj.name)
    document_with_tag = merge_pages_with_page_tag(pages)
    return gr.update(value=document_with_tag, interactive=True), count_characters(document_with_tag)
|
|
|
|
|
def count_characters(document_with_tag):
    """
    Compute the character and token counts of a text whose pages are wrapped
    in chatpdf:page tags (the tags themselves are not counted).

    Args:
        document_with_tag (str): page-tagged text to measure

    Returns:
        str: display string with the character and token counts
    """
    pages = extract_pages_from_page_tag(document_with_tag)
    text = "".join(page.content for page in pages)
    return get_context_info(text, OPENAI_TOKENIZER.encode(text))
|
|
|
|
|
class SearchEngine:
    """
    Keyword search engine over the PDF pages.
    """
    def __init__(self, engine, pages):
        # BM25Okapi index built from the tokenized pages.
        self.engine = engine
        # Page objects, in the same order as the index corpus.
        self.pages = pages
|
|
|
|
|
# Module-level singleton; rebuilt by create_search_engine() whenever the context
# changes, and None while no searchable content exists.
SEARCH_ENGINE = None
|
|
|
def create_search_engine(context):
    """
    Build the BM25 search engine for the current context.

    Args:
        context (str): Text to index; each page is wrapped in a chatpdf:page tag.

    Side effects:
        Replaces the module-level SEARCH_ENGINE. It is set to None when no page
        yields any tokens (BM25Okapi cannot be built from an empty corpus).
    """
    global SEARCH_ENGINE

    pages = extract_pages_from_page_tag(context)
    tokenized_pages = []
    original_pages = []
    for page in pages:
        page_content = page.content.strip()
        if page_content:
            # Janome base forms; JANOME_ANALYZER drops symbols/whitespace and lower-cases.
            tokenized_page = [token.base_form for token in JANOME_ANALYZER.analyze(page_content)]
            if tokenized_page:
                # Keep index corpus and page list aligned: only pages that
                # produced tokens are retained.
                tokenized_pages.append(tokenized_page)
                original_pages.append(page)

    if tokenized_pages:
        bm25 = BM25Okapi(tokenized_pages)
        SEARCH_ENGINE = SearchEngine(engine=bm25, pages=original_pages)
    else:
        SEARCH_ENGINE = None
|
|
|
|
|
def search_pages(keywords, page_limit):
    """
    Search for pages containing the given keywords.

    Args:
        keywords (str): space-separated search keywords
        page_limit (int): maximum number of pages to return

    Returns:
        list[Page]: matching pages, best first; empty when no index exists or
        the query yields no tokens
    """
    global SEARCH_ENGINE
    if SEARCH_ENGINE is None:
        return []

    # Tokenize the query the same way the pages were tokenized at index time.
    tokenized_query = [token.base_form for token in JANOME_ANALYZER.analyze(keywords)]
    if not tokenized_query:
        return []

    found_pages = SEARCH_ENGINE.engine.get_top_n(tokenized_query, SEARCH_ENGINE.pages, n=page_limit)
    return found_pages
|
|
|
|
|
def load_pages(page_numbers):
    """
    Fetch the pages whose numbers appear in the given list.

    Args:
        page_numbers (list[int]): page numbers to fetch

    Returns:
        list[Page]: the matching pages (empty when no search engine exists)
    """
    if SEARCH_ENGINE is None:
        return []

    # Set membership keeps this O(pages + numbers) instead of O(pages * numbers).
    wanted = set(page_numbers)
    return [page for page in SEARCH_ENGINE.pages if page.number in wanted]
|
|
|
|
|
# Tool definitions exposed to the model for function calling.
CHAT_TOOLS = [
    # Keyword page search
    {
        "type": "function",
        "function": {
            "name": "search_pages",
            "description": "Searches for pages containing the given keywords.",
            "parameters": {
                "type": "object",
                "properties": {
                    "keywords": {
                        "type": "string",
                        "description": 'Search keywords separated by spaces. For example, "Artificial General Intelligence 自律エージェント".'
                    },
                    "page_limit": {
                        "type": "number",
                        "description": "Maximum number of search results to return. For example, 3.",
                        "minimum": 1
                    }
                },
                # BUG FIX: "required" belongs inside the "parameters" JSON-Schema
                # object; it previously sat at the function level, where the
                # OpenAI API ignores it.
                "required": ["keywords"]
            }
        }
    },
    # Page retrieval by page number
    {
        "type": "function",
        "function": {
            "name": "load_pages",
            "description": "Loads pages specified by their page numbers.",
            "parameters": {
                "type": "object",
                "properties": {
                    "page_numbers": {
                        "type": "array",
                        "items": {
                            "type": "number"
                        },
                        "description": "List of page numbers to be loaded",
                        "minItems": 1
                    }
                },
                "required": ["page_numbers"]
            }
        }
    }
]

# Fixed token overhead consumed by the function-calling tool definitions above.
CHAT_TOOLS_TOKENS = 139
|
|
|
|
|
def get_openai_messages(prompt, history, context):
    """
    Convert the dialog data into the message format used by the ChatGPT API.

    Args:
        prompt (str): user's input prompt
        history (list[list[str]]): chat history as [user, assistant] pairs
        context (str): chat context (page-tagged PDF text)

    Returns:
        list[dict]: messages in ChatGPT API format
    """
    global SEARCH_ENGINE
    if SEARCH_ENGINE is not None:
        # When an index exists, use the indexed pages' text as the context
        # instead of the passed-in string.
        context = "".join([page.content for page in SEARCH_ENGINE.pages])

    messages = []
    for user_message, assistant_message in history:
        # Skip pairs with a missing side (e.g. the initial greeting whose user turn is null).
        if user_message is not None and assistant_message is not None:
            user_message = unescape_latex(user_message)
            # Substitute the PDF text wherever the user wrote the {context} placeholder.
            user_message = user_message.replace("{context}", context)
            assistant_message = unescape_latex(assistant_message)
            messages.append({ "role": "user", "content": user_message })
            messages.append({ "role": "assistant", "content": assistant_message })

    prompt = prompt.replace("{context}", context)
    messages.append({ "role": "user", "content": prompt })

    return messages
|
|
|
|
|
# Cumulative prompt (input) tokens actually consumed across all API calls so far.
actual_total_cost_prompt = 0

# Cumulative completion (output) tokens actually consumed across all API calls so far.
actual_total_cost_completion = 0
|
|
|
|
|
async def process_prompt(prompt, history, context, platform, endpoint, azure_deployment, azure_api_version, api_key, model_name, max_tokens, temperature, enable_rag):
    """
    Process the user's prompt and yield the response generated by ChatGPT.

    Args:
        prompt (str): user's input prompt
        history (list[list[str]]): chat history
        context (str): chat context (page-tagged PDF text)
        platform (str): AI platform to use ("OpenAI" or anything else = Azure)
        endpoint (str): AI service endpoint
        azure_deployment (str): Azure deployment name
        azure_api_version (str): Azure API version
        api_key (str): API key
        model_name (str): name of the AI model to use
        max_tokens (int): maximum number of tokens to generate
        temperature (float): sampling temperature
        enable_rag (bool): whether to expose the search/load tools to the model

    Yields:
        str: progressively built assistant response (tool progress first,
        then the final answer)

    Raises:
        gr.Error: on any API error or unknown tool call
    """
    global actual_total_cost_prompt, actual_total_cost_completion

    try:
        messages = get_openai_messages(prompt, history, context)

        if platform == "OpenAI":
            openai_client = OpenAI(
                base_url=endpoint,
                api_key=api_key,
                http_client=http_client
            )
        else: # Azure
            openai_client = AzureOpenAI(
                azure_endpoint=endpoint,
                api_version=azure_api_version,
                azure_deployment=azure_deployment,
                api_key=api_key,
                http_client=http_client
            )

        # First round trip: with RAG the model may answer directly or request tools.
        if enable_rag:
            completion = openai_client.chat.completions.create(
                messages=messages,
                model=model_name,
                max_tokens=max_tokens,
                temperature=temperature,
                tools=CHAT_TOOLS,
                tool_choice="auto",
                stream=False
            )
        else:
            completion = openai_client.chat.completions.create(
                messages=messages,
                model=model_name,
                max_tokens=max_tokens,
                temperature=temperature,
                stream=False
            )

        bot_response = ""
        if hasattr(completion, "error"):
            raise gr.Error(completion.error["message"])

        response_message = completion.choices[0].message
        tool_calls = response_message.tool_calls
        # Track actual token usage for the cost display.
        actual_total_cost_prompt += completion.usage.prompt_tokens
        actual_total_cost_completion += completion.usage.completion_tokens

        if tool_calls:
            # Echo the assistant's tool-call message back into the conversation,
            # then execute each requested tool locally.
            messages.append(response_message)

            for tool_call in tool_calls:
                function_name = tool_call.function.name
                function_args = json.loads(tool_call.function.arguments)
                if function_name == "search_pages":
                    # Keyword page search
                    keywords = function_args.get("keywords").strip()
                    page_limit = function_args.get("page_limit") or 3

                    bot_response += f'Searching for pages containing the keyword{add_s(keywords.split(" "))} "{keywords}".\n'

                    found_pages = search_pages(keywords, page_limit)
                    function_response = json.dumps({
                        "status": "found" if found_pages else "not found",
                        "found_pages": [{
                            "page_number": page.number,
                            "page_content": page.content
                        } for page in found_pages]
                    }, ensure_ascii=False)
                    # Tool result message the model will see on the second round trip.
                    messages.append({
                        "tool_call_id": tool_call.id,
                        "role": "tool",
                        "name": function_name,
                        "content": function_response
                    })
                    if found_pages:
                        bot_response += f'Found page{add_s(found_pages)}: {", ".join([str(page.number) for page in found_pages])}.\n\n'
                    else:
                        bot_response += "Page not found.\n\n"

                elif function_name == "load_pages":
                    # Page retrieval by page number
                    page_numbers = function_args.get("page_numbers")

                    bot_response += f'Trying to load page{add_s(page_numbers)} {", ".join(map(str, page_numbers))}.\n'

                    found_pages = load_pages(page_numbers)
                    function_response = json.dumps({
                        "status": "found" if found_pages else "not found",
                        "found_pages": [{
                            "page_number": page.number,
                            "page_content": page.content
                        } for page in found_pages]
                    }, ensure_ascii=False)
                    messages.append({
                        "tool_call_id": tool_call.id,
                        "role": "tool",
                        "name": function_name,
                        "content": function_response
                    })
                    if found_pages:
                        bot_response += f'Found page{add_s(found_pages)}: {", ".join([str(page.number) for page in found_pages])}.\n\n'
                    else:
                        bot_response += "Page not found.\n\n"
                else:
                    raise gr.Error(f"Unknown function calling '{function_name}'.")

            # Show tool progress to the user before the (slow) second API call;
            # the short sleep lets the UI update.
            yield bot_response + "Generating response. Please wait a moment...\n"
            await asyncio.sleep(0.1)

            # Second round trip: let the model answer using the tool results.
            completion = openai_client.chat.completions.create(
                messages=messages,
                model=model_name,
                max_tokens=max_tokens,
                temperature=temperature,
                stream=False
            )
            actual_total_cost_prompt += completion.usage.prompt_tokens
            actual_total_cost_completion += completion.usage.completion_tokens

            if hasattr(completion, "error"):
                raise gr.Error(completion.error["message"])

            response_message = completion.choices[0].message
            bot_response += response_message.content
            yield bot_response

        else:
            # No tool calls: the first response already contains the answer.
            bot_response += response_message.content
            yield bot_response

    except Exception as e:
        # Surface any failure to the UI as a Gradio error dialog.
        if hasattr(e, "message"):
            raise gr.Error(e.message)
        else:
            raise gr.Error(str(e))
|
|
|
|
|
def load_api_key(file_obj):
    """
    Read the API key from the uploaded key file.

    Args:
        file_obj (File): API-key file object (.name is its path)

    Returns:
        str: the API key with surrounding whitespace stripped

    Raises:
        gr.Error: if the file cannot be read
    """
    try:
        return Path(file_obj.name).read_text(encoding="utf-8").strip()
    except Exception as e:
        raise gr.Error(str(e))
|
|
|
|
|
def get_cost_info(prompt_token_count):
    """
    Build the token-usage string shown in the cost info box.

    Args:
        prompt_token_count (int): token count of the prompt (including history)

    Returns:
        str: estimated input cost plus the actual cumulative input/output token totals
    """
    # CHAT_TOOLS_TOKENS adds the fixed overhead of the function-calling tool definitions.
    return f"Estimated input cost: {prompt_token_count + CHAT_TOOLS_TOKENS:,} tokens, Actual total input cost: {actual_total_cost_prompt:,} tokens, Actual total output cost: {actual_total_cost_completion:,} tokens"
|
|
|
|
|
# Default values for the Settings tab. The key order is also the canonical
# column order used when (de)serializing saved settings to localStorage.
DEFAULT_SETTINGS = {
    "setting_name": "Default",
    "platform": "OpenAI",
    "endpoint": "https://api.openai.com/v1",
    "azure_deployment": "",
    "azure_api_version": "",
    "model_name": "gpt-4-turbo-preview",
    "max_tokens": 4096,
    "temperature": 0.2,
    "enable_rag": True,
    "save_chat_history_to_url": False
}
|
|
|
|
|
def main(): |
|
""" |
|
アプリケーションのメイン関数。Gradioインターフェースを設定し、アプリケーションを起動する。 |
|
""" |
|
try: |
|
# クエリパラメータに保存されていることもあるチャット履歴を読み出す。 |
|
with open("chat_history.json", "r", encoding="utf-8") as f: |
|
CHAT_HISTORY = json.load(f) |
|
except Exception as e: |
|
print(e) |
|
CHAT_HISTORY = [] |
|
|
|
# localStorageから設定情報ををロードする。 |
|
js_define_utilities_and_load_settings = """() => { |
|
const KEY_PREFIX = "serverless_chat_with_your_pdf:"; |
|
|
|
const loadSettings = () => { |
|
const getItem = (key, defaultValue) => { |
|
const jsonValue = localStorage.getItem(KEY_PREFIX + key); |
|
if (jsonValue) { |
|
return JSON.parse(jsonValue); |
|
} else { |
|
return defaultValue; |
|
} |
|
}; |
|
""" + "".join([f""" |
|
const default_{setting_key} = {json.dumps(default_value, ensure_ascii=False)}; |
|
const {setting_key} = getItem("{setting_key}", default_{setting_key}); |
|
""" for setting_key, default_value in DEFAULT_SETTINGS.items()]) + """ |
|
const serialized_saved_settings = getItem("saved_settings", []); |
|
const default_saved_settings = [[ |
|
""" + ", ".join([f"{json.dumps(default_value, ensure_ascii=False)}" for _, default_value in DEFAULT_SETTINGS.items()]) + """ |
|
]]; |
|
|
|
saved_settings = []; |
|
for (let entry of serialized_saved_settings) { |
|
saved_settings.push([ |
|
entry["setting_name"] || "", |
|
entry["platform"] || default_platform, |
|
entry["endpoint"] || default_endpoint, |
|
entry["azure_deployment"] || default_azure_deployment, |
|
entry["azure_api_version"] || default_azure_api_version, |
|
entry["model_name"] || default_model_name, |
|
entry["max_tokens"] || default_max_tokens, |
|
entry["temperature"] || default_temperature, |
|
entry["enable_rag"] || default_enable_rag, |
|
entry["save_chat_history_to_url"] || default_save_chat_history_to_url |
|
]); |
|
} |
|
if (saved_settings.length == 0) { |
|
saved_settings = default_saved_settings; |
|
} |
|
|
|
return [setting_name, platform, endpoint, azure_deployment, azure_api_version, model_name, max_tokens, temperature, enable_rag, save_chat_history_to_url, saved_settings]; |
|
}; |
|
|
|
globalThis.resetSettings = () => { |
|
for (let key in localStorage) { |
|
if (key.startsWith(KEY_PREFIX) && !key.startsWith(KEY_PREFIX + "saved_settings") && !key.startsWith(KEY_PREFIX + "setting_name")) { |
|
localStorage.removeItem(key); |
|
} |
|
} |
|
|
|
return loadSettings(); |
|
}; |
|
|
|
globalThis.saveItem = (key, value) => { |
|
localStorage.setItem(KEY_PREFIX + key, JSON.stringify(value)); |
|
}; |
|
|
|
return loadSettings(); |
|
} |
|
""" |
|
|
|
# should_saveがtrueであればURLにチャット履歴を保存し、falseであればチャット履歴を削除する。 |
|
save_or_delete_chat_history = '''(hist, should_save) => { |
|
saveItem("save_chat_history_to_url", should_save); |
|
if (!should_save) { |
|
const url = new URL(window.location.href); |
|
url.searchParams.delete("history"); |
|
window.history.replaceState({path:url.href}, '', url.href); |
|
} else { |
|
const compressedHistory = LZString.compressToEncodedURIComponent(JSON.stringify(hist)); |
|
const url = new URL(window.location.href); |
|
url.searchParams.set("history", compressedHistory); |
|
window.history.replaceState({path:url.href}, '', url.href); |
|
} |
|
}''' |
|
|
|
# メッセージ例 |
|
examples = { |
|
"要約 (論文)": '''制約条件に従い、以下の研究論文で提案されている技術や手法について要約してください。 |
|
|
|
# 制約条件 |
|
* 要約者: 大学教授 |
|
* 想定読者: 大学院生 |
|
* 要約結果の言語: 日本語 |
|
* 要約結果の構成(以下の各項目について500文字): |
|
1. どんな研究であるか |
|
2. 先行研究に比べて優れている点は何か |
|
3. 提案されている技術や手法の重要な点は何か |
|
4. どのような方法で有効であると評価したか |
|
5. 何か議論はあるか |
|
6. 次に読むべき論文は何か |
|
|
|
# 研究論文 |
|
""" |
|
{context} |
|
""" |
|
|
|
# 要約結果''', |
|
"要約 (一般)": '''制約条件に従い、以下の文書の内容を要約してください。 |
|
|
|
# 制約条件 |
|
* 要約者: 技術コンサルタント |
|
* 想定読者: 経営層、CTO、CIO |
|
* 形式: 箇条書き |
|
* 分量: 20項目 |
|
* 要約結果の言語: 日本語 |
|
|
|
# 文書 |
|
""" |
|
{context} |
|
""" |
|
|
|
# 要約''', |
|
"情報抽出": '''制約条件に従い、以下の文書から情報を抽出してください。 |
|
|
|
# 制約条件 |
|
* 抽出する情報: 課題や問題点について言及している全ての文。一つも見落とさないでください。 |
|
* 出力形式: 箇条書き |
|
* 出力言語: 元の言語の文章と、その日本語訳 |
|
|
|
# 文書 |
|
""" |
|
{context} |
|
""" |
|
|
|
# 抽出結果''', |
|
"QA (日本語文書RAG)": '''次の質問に回答するために役立つページを検索して、その検索結果を使って回答して下さい。 |
|
|
|
# 制約条件 |
|
* 検索クエリの生成方法: 質問文の3つの言い換え(paraphrase)をカンマ区切りで連結した文字列 |
|
* 検索クエリの言語: 日本語 |
|
* 検索するページ数: 3 |
|
* 回答方法: |
|
- 検索結果の情報のみを用いて回答すること。 |
|
- 回答に利用した文章のあるページ番号を、それぞれの回答文の文末に付与すること。形式: "(参考ページ番号: 71, 59, 47)" |
|
- 回答に役立つ情報が検索結果内にない場合は「検索結果には回答に役立つ情報がありませんでした。」と回答すること。 |
|
* 回答の言語: 日本語 |
|
|
|
# 質問 |
|
どのような方法で、提案された手法が有効であると評価しましたか? |
|
|
|
# 回答''', |
|
"QA (英語文書RAG)": '''次の質問に回答するために役立つページを検索して、その検索結果を使って回答して下さい。 |
|
|
|
# 制約条件 |
|
* 検索クエリの生成方法: 質問文の3つの言い換え(paraphrase)をカンマ区切りで連結した文字列 |
|
* 検索クエリの言語: 英語 |
|
* 検索するページ数: 3 |
|
* 回答方法: |
|
- 検索結果の情報のみを用いて回答すること。 |
|
- 回答に利用した文章のあるページ番号を、それぞれの回答文の文末に付与すること。形式: "(参考ページ番号: 71, 59, 47)" |
|
- 回答に役立つ情報が検索結果内にない場合は「検索結果には回答に役立つ情報がありませんでした。」と回答すること。 |
|
* 回答の言語: 日本語 |
|
|
|
# 質問 |
|
どのような方法で、提案された手法が有効であると評価しましたか? |
|
|
|
# 回答''', |
|
"要約 (RAG)": '''次のキーワードを含むページを検索して、その検索結果をページごとに要約して下さい。 |
|
|
|
# 制約条件 |
|
* キーワード: dataset datasets |
|
* 検索するページ数: 3 |
|
* 要約結果の言語: 日本語 |
|
* 要約の形式: |
|
## ページ番号(例: 12ページ) |
|
- 要約文1 |
|
- 要約文2 |
|
... |
|
* 要約の分量: 各ページ3項目 |
|
|
|
# 要約''', |
|
"翻訳 (RAG)": '''次のキーワードを含むページを検索して、その検索結果を日本語に翻訳して下さい。 |
|
|
|
# 制約条件 |
|
* キーワード: dataset datasets |
|
* 検索するページ数: 1 |
|
|
|
# 翻訳結果''', |
|
"要約 (ページ指定)": '''16〜17ページをページごとに箇条書きで要約して下さい。 |
|
|
|
# 制約条件 |
|
* 要約結果の言語: 日本語 |
|
* 要約の形式: |
|
## ページ番号(例: 12ページ) |
|
- 要約文1 |
|
- 要約文2 |
|
... |
|
* 要約の分量: 各ページ5項目 |
|
|
|
# 要約''', |
|
"続きを生成": "続きを生成してください。" |
|
} |
|
|
|
with gr.Blocks(theme=gr.themes.Default(), analytics_enabled=False) as app: |
|
with gr.Tabs(): |
|
with gr.TabItem("Settings"): |
|
with gr.Column(): |
|
with gr.Column(variant="panel"): |
|
with gr.Row(): |
|
setting_name = gr.Textbox(label="Setting Name", value="Default", interactive=True) |
|
setting_name.change(None, inputs=setting_name, outputs=None, |
|
js='(x) => saveItem("setting_name", x)', show_progress="hidden") |
|
|
|
with gr.Row(): |
|
platform = gr.Radio(label="Platform", interactive=True, |
|
choices=["OpenAI", "Azure"], value="OpenAI") |
|
platform.change(None, inputs=platform, outputs=None, |
|
js='(x) => saveItem("platform", x)', show_progress="hidden") |
|
|
|
with gr.Row(): |
|
endpoint = gr.Textbox(label="Endpoint", interactive=True) |
|
endpoint.change(None, inputs=endpoint, outputs=None, |
|
js='(x) => saveItem("endpoint", x)', show_progress="hidden") |
|
|
|
azure_deployment = gr.Textbox(label="Azure Deployment", interactive=True) |
|
azure_deployment.change(None, inputs=azure_deployment, outputs=None, |
|
js='(x) => saveItem("azure_deployment", x)', show_progress="hidden") |
|
|
|
azure_api_version = gr.Textbox(label="Azure API Version", interactive=True) |
|
azure_api_version.change(None, inputs=azure_api_version, outputs=None, |
|
js='(x) => saveItem("azure_api_version", x)', show_progress="hidden") |
|
|
|
with gr.Group(): |
|
with gr.Row(): |
|
api_key_file = gr.File(file_count="single", file_types=["text"], |
|
height=80, label="API Key File") |
|
api_key = gr.Textbox(label="API Key", type="password", interactive=True) |
|
# 注意: 秘密情報をlocalStorageに保存してはならない。他者に秘密情報が盗まれる危険性があるからである。 |
|
|
|
api_key_file.upload(load_api_key, inputs=api_key_file, outputs=api_key, |
|
show_progress="hidden") |
|
api_key_file.clear(lambda: None, inputs=None, outputs=api_key, show_progress="hidden") |
|
|
|
with gr.Row(): |
|
model_name = gr.Textbox(label="Model", interactive=True) |
|
model_name.change(None, inputs=model_name, outputs=None, |
|
js='(x) => saveItem("model_name", x)', show_progress="hidden") |
|
|
|
max_tokens = gr.Number(label="Max Tokens", interactive=True, |
|
minimum=0, precision=0, step=1) |
|
max_tokens.change(None, inputs=max_tokens, outputs=None, |
|
js='(x) => saveItem("max_tokens", x)', show_progress="hidden") |
|
|
|
temperature = gr.Slider(label="Temperature", interactive=True, |
|
minimum=0.0, maximum=1.0, step=0.1) |
|
temperature.change(None, inputs=temperature, outputs=None, |
|
js='(x) => saveItem("temperature", x)', show_progress="hidden") |
|
|
|
enable_rag = gr.Checkbox(label="Enable RAG (Retrieval Augmented Generation)", interactive=True) |
|
enable_rag.change(None, inputs=enable_rag, outputs=None, |
|
js='(x) => saveItem("enable_rag", x)', show_progress="hidden") |
|
|
|
save_chat_history_to_url = gr.Checkbox(label="Save Chat History to URL", interactive=True) |
|
|
|
reset_button = gr.Button("Reset Settings") |
|
|
|
with gr.Column(variant="panel"): |
|
default_saved_settings = list(DEFAULT_SETTINGS.values()) |
|
|
|
saved_settings_df = gr.Dataframe( |
|
elem_id="saved_settings", |
|
value=[default_saved_settings], |
|
headers=["Name", "Platform", "Endpoint", "Azure Deployment", "Azure API Version", "Model", "Max Tokens", "Temperature", "Enable RAG", "Save Chat History to URL"], |
|
row_count=(0, "dynamic"), |
|
col_count=(10, "fixed"), |
|
datatype=["str", "str", "str", "str", "str", "str", "number", "number", "bool", "bool"], |
|
type="array", |
|
label="Saved Settings", |
|
show_label=True, |
|
interactive=False |
|
) |
|
selected_setting = gr.State(None) |
|
temp_selected_row_index = gr.JSON(value=None, visible=False) |
|
|
|
|
|
def select_setting(event: gr.SelectData): |
|
return (event.index[0], event.index[1]), event.index[0] |
|
|
|
|
|
saved_settings_df.select( |
|
select_setting, inputs=None, outputs=[selected_setting, temp_selected_row_index], queue=False, show_progress="hidden" |
|
).then( |
|
None, inputs=temp_selected_row_index, outputs=None, js='(row_index) => { for (let e of document.querySelectorAll("#saved_settings > div > div > button > svelte-virtual-table-viewport > table > tbody > tr")[row_index].children) { e.classList.add("focus"); } }', queue=False, show_progress="hidden" |
|
) |
|
|
|
with gr.Row(): |
|
load_saved_settings_button = gr.Button("Load") |
|
append_or_overwrite_saved_settings_button = gr.Button("Append or Overwrite") |
|
delete_saved_settings_button = gr.Button("Delete") |
|
|
|
serialized_saved_settings_state = gr.JSON(visible=False) |
|
|
|
|
|
def load_saved_setting(saved_settings, selected_setting): |
|
if not selected_setting: |
|
return saved_settings |
|
|
|
def u(x): |
|
return gr.update(value=x, interactive=True) |
|
|
|
row_index = selected_setting[0] |
|
|
|
setting_name, platform, endpoint, azure_deployment, azure_api_version, model_name, max_tokens, temperature, enable_rag, save_chat_history_to_url = saved_settings[row_index] |
|
|
|
return u(setting_name), u(platform), u(endpoint), u(azure_deployment), u(azure_api_version), u(model_name), u(max_tokens), u(temperature), u(enable_rag), u(save_chat_history_to_url), None |
|
|
|
|
|
load_saved_settings_button.click(load_saved_setting, inputs=[saved_settings_df, selected_setting], outputs=[setting_name, platform, endpoint, azure_deployment, azure_api_version, model_name, max_tokens, temperature, enable_rag, save_chat_history_to_url, selected_setting], queue=False, show_progress="hidden") |
|
|
|
|
|
def append_or_overwrite_setting(saved_settings, setting_name, platform, endpoint, azure_deployment, azure_api_version, model_name, max_tokens, temperature, enable_rag, save_chat_history_to_url):
    """Upsert one row into the saved-settings table, keyed by setting name.

    Rows whose first column equals the (stripped) setting name are replaced with
    the current widget values; if no row matches, the new row is appended.

    Returns:
        (new table value, None) — the None clears the selection state, matching
        the click handler's two outputs.
    """
    setting_name = setting_name.strip()

    # Build the replacement row once instead of duplicating the literal in both
    # branches (the two copies had already drifted in formatting).
    new_entry = [setting_name, platform, endpoint, azure_deployment, azure_api_version, model_name, max_tokens, temperature, enable_rag, save_chat_history_to_url]

    found = False
    new_saved_settings = []
    for entry in saved_settings:
        if entry[0] == setting_name:
            # Overwrite every row with a matching name.
            new_saved_settings.append(new_entry)
            found = True
        else:
            new_saved_settings.append(entry)

    if not found:
        new_saved_settings.append(new_entry)

    return new_saved_settings, None
|
|
|
|
|
def serialize_saved_settings(saved_settings):
    """Convert each saved-settings row (a positional list) into a dict keyed by
    the canonical setting names, ready to be persisted as JSON via saveItem().

    The key order comes from DEFAULT_SETTINGS, which defines the column layout
    of the saved-settings table.
    """
    keys = list(DEFAULT_SETTINGS)

    return [
        {name: row[position] for position, name in enumerate(keys)}
        for row in saved_settings
    ]
|
|
|
|
|
# Append/Overwrite: upsert the row, then serialize the new table and persist it
# in the browser via the JS saveItem() helper.
append_or_overwrite_saved_settings_button.click(

append_or_overwrite_setting, inputs=[saved_settings_df, setting_name, platform, endpoint, azure_deployment, azure_api_version,model_name, max_tokens, temperature, enable_rag, save_chat_history_to_url], outputs=[saved_settings_df, selected_setting], queue=False, show_progress="hidden"

).then(

serialize_saved_settings, inputs=saved_settings_df, outputs=serialized_saved_settings_state, queue=False, show_progress="hidden",

).then(

# JS-only persistence step: writes the serialized settings to browser storage.
None, inputs=serialized_saved_settings_state, outputs=None, js='(x) => saveItem("saved_settings", x)', queue=False, show_progress="hidden"

)
|
|
|
|
|
def delete_setting(saved_settings, selected_setting):
    """Remove the selected row from the saved-settings table.

    Args:
        saved_settings: 2D list backing the saved-settings dataframe.
        selected_setting: (row_index, col_index) tuple, or a falsy value when
            no row is selected.

    Returns:
        (new table value, new selection state) — matching the click handler's
        two outputs.
    """
    if not selected_setting:
        # Bug fix: this handler is wired to 2 outputs, so the no-selection path
        # must also return a pair; previously the bare table value was returned
        # and would have been mis-unpacked across both outputs.
        return saved_settings, selected_setting

    row_index = selected_setting[0]
    new_saved_settings = saved_settings[0:row_index] + saved_settings[row_index + 1:]

    # Never leave the table empty; fall back to the default row.
    # NOTE(review): default_saved_settings is defined elsewhere in this file —
    # appending it as a single element assumes it is one row, not a list of rows;
    # confirm at its definition.
    if not new_saved_settings:
        new_saved_settings.append(default_saved_settings)

    return new_saved_settings, None
|
|
|
|
|
# Delete: drop the selected row, then serialize and persist, mirroring the
# append/overwrite chain above.
delete_saved_settings_button.click(

delete_setting, inputs=[saved_settings_df, selected_setting], outputs=[saved_settings_df, selected_setting], queue=False, show_progress="hidden"

).then(

serialize_saved_settings, inputs=saved_settings_df, outputs=serialized_saved_settings_state, queue=False, show_progress="hidden",

).then(

# JS-only persistence step: writes the serialized settings to browser storage.
None, inputs=serialized_saved_settings_state, outputs=None, js='(x) => saveItem("saved_settings", x)', queue=False, show_progress="hidden"

)



# Hidden staging component: JS (e.g. resetSettings/loadSettings) writes here,
# and the change handler forwards the value into the visible dataframe.
temp_saved_settings = gr.JSON(visible=False)

temp_saved_settings.change(lambda x: x, inputs=temp_saved_settings, outputs=saved_settings_df, queue=False, show_progress="hidden")



# All components that JS-side settings loading/resetting writes into, in the
# order the JS helpers return them.
setting_items = [setting_name, platform, endpoint, azure_deployment, azure_api_version, model_name, max_tokens, temperature, enable_rag, save_chat_history_to_url, temp_saved_settings]

# Reset runs entirely in the browser: resetSettings() returns default values
# for every component in setting_items.
reset_button.click(None, inputs=None, outputs=setting_items,

js="() => resetSettings()", show_progress="hidden")
|
|
|
# --- Chat tab: PDF/context column on the left, chat UI on the right. ---
with gr.TabItem("Chat"):

with gr.Row():

with gr.Column(scale=1):

# Single-PDF upload; its extracted text becomes the chat context.
pdf_file = gr.File(file_count="single", file_types=[".pdf"],

height=80, label="PDF")

# Editable extracted text; also fed to RAG indexing and cost estimation.
context = gr.Textbox(elem_id="context", label="Context", lines=20,

interactive=True, autoscroll=False, show_copy_button=True)

char_counter = gr.Textbox(label="Statistics", value=get_context_info("", []),

lines=2, max_lines=2, interactive=False, container=True)



# Upload extracts text into `context`; clearing the file clears the context.
pdf_file.upload(update_context_element, inputs=pdf_file, outputs=[context, char_counter], queue=False)

pdf_file.clear(lambda: None, inputs=None, outputs=context, queue=False, show_progress="hidden")



with gr.Column(scale=2):



# Extra inputs appended to every submit/retry call of submit_message,
# in the order its *args expects them.
additional_inputs = [context, platform, endpoint, azure_deployment, azure_api_version, api_key, model_name, max_tokens, temperature, enable_rag]



with gr.Blocks() as chat:

gr.Markdown(f"# Chat with your PDF")



with gr.Column(variant="panel"):

# CHAT_HISTORY is pre-populated from the URL (see the page-level script
# that fills chat_history.json before the app boots).
chatbot = gr.Chatbot(

CHAT_HISTORY,

elem_id="chatbot", height=500, show_copy_button=True,

sanitize_html=True, render_markdown=True,

# Only \( \) and \[ \] delimiters are enabled; $-style delimiters are
# deliberately disabled to avoid false positives on plain dollar signs.
latex_delimiters=[

# { "left": "$$", "right": "$$", "display": True },

# { "left": "$", "right": "$", "display": False },

{ "left": "\\(", "right": "\\)", "display": False },

{ "left": "\\[", "right": "\\]", "display": True },

],

likeable=False, layout="bubble",

avatar_images=[None, "https://raw.githubusercontent.com/sonoisa/misc/main/resources/icons/chatbot_icon.png"]

)



# message_state: the message being submitted; chatbot_state: canonical
# history (list of [user, assistant] pairs) mirrored into the chatbot.
message_state = gr.State()

chatbot_state = gr.State(chatbot.value) if chatbot.value else gr.State([])



with gr.Group():

with gr.Row():

message_textbox = gr.Textbox(placeholder="Type a message...",

container=False, show_label=False, autofocus=True, interactive=True, scale=7)



# Submit and Stop swap visibility while a generation is streaming.
submit_button = gr.Button("Submit", variant="primary", scale=1, min_width=150)

stop_button = gr.Button("Stop", variant="stop", visible=False, scale=1, min_width=150)



# Live token/cost estimate for the prompt that would be sent.
cost_info = gr.Textbox(elem_id="cost_info", value=get_cost_info(0),

lines=1, max_lines=1, interactive=False, container=False, elem_classes="cost_info")



with gr.Row():

retry_button = gr.Button("🔄 Retry", variant="secondary", size="sm")

undo_button = gr.Button("↩️ Undo", variant="secondary", size="sm")

clear_button = gr.Button("🗑️ Clear", variant="secondary", size="sm")
|
|
|
|
|
def estimate_message_cost(prompt, history, context):
    """Refresh the displayed cost estimate for the prompt that would be sent.

    Builds the full OpenAI-format message list (system/context + history +
    prompt), counts tokens across every message with OPENAI_TOKENIZER, and
    renders the result through get_cost_info().
    """
    messages = get_openai_messages(prompt, history, context)
    total_tokens = sum(
        len(OPENAI_TOKENIZER.encode(message["content"]))
        for message in messages
    )
    return gr.update(value=get_cost_info(total_tokens))
|
|
|
|
|
# Recompute the cost estimate on every keystroke in the message box.
message_textbox.change(estimate_message_cost, inputs=[message_textbox, chatbot, context], outputs=cost_info, queue=False, show_progress="hidden")



# Clicking an example title looks up its full prompt text and drops it into
# the message box (titles shown, bodies resolved via the `examples` dict).
example_title_textbox = gr.Textbox(visible=False, interactive=True)

gr.Examples([[k] for k, v in examples.items()],

inputs=example_title_textbox, outputs=message_textbox,

fn=lambda title: examples[title], run_on_click=True)
|
|
|
|
|
def append_message_to_history(message, history):
    """Append the user's message as a new [user, assistant] pair with no reply yet.

    The message is LaTeX-escaped for safe chatbot rendering. The (mutated)
    history is returned twice: once for the chatbot display, once for the
    state component.
    """
    history.append([escape_latex(message), None])
    return history, history
|
|
|
|
|
def undo_chat(history):
    """Drop the most recent [user, assistant] pair from the history.

    Returns the shortened history twice (chatbot display + state) plus the
    removed user message, un-escaped so it can be re-edited in the textbox.
    An empty history yields an empty message.
    """
    last_user_message = ""
    if history:
        removed_pair = history.pop()
        last_user_message = removed_pair[0] or ""
    return history, history, unescape_latex(last_user_message)
|
|
|
|
|
async def submit_message(message, history_with_input, *args):
    """Stream the model's reply into the chat history.

    `history_with_input` already ends with the pending [message, None] pair
    appended by append_message_to_history; that pair is stripped here and
    rebuilt on every streamed chunk so the partial response replaces it.
    `*args` carries additional_inputs (context, platform, credentials, ...)
    straight through to process_prompt. Each yield emits the history twice:
    once for the chatbot display, once for the state component.
    """
    history = history_with_input[:-1]

    generator = process_prompt(message, history, *args)
    message = escape_latex(message)

    has_response = False
    async for response in generator:
        has_response = True
        updated = history + [[message, escape_latex(response)]]
        yield updated, updated

    # No chunks at all (e.g. the backend failed immediately): still show the
    # user's message with an empty assistant slot.
    if not has_response:
        updated = history + [[message, None]]
        yield updated, updated
|
|
|
|
|
# Enter in the textbox and the Submit button trigger the same pipeline.
submit_triggers = [message_textbox.submit, submit_button.click]



# Submit pipeline: clear the textbox (stashing the message in state), append
# the message to the history, stream the reply, then refresh the cost estimate.
submit_event = gr.events.on(

submit_triggers, lambda message: ("", message), inputs=[message_textbox], outputs=[message_textbox, message_state], queue=False

).then(

append_message_to_history, inputs=[message_state, chatbot_state], outputs=[chatbot, chatbot_state], queue=False

).then(

submit_message, inputs=[message_state, chatbot_state] + additional_inputs, outputs=[chatbot, chatbot_state]

).then(

estimate_message_cost, inputs=[message_textbox, chatbot, context], outputs=cost_info, show_progress="hidden"

)



# While streaming, hide Submit and show Stop; restore when the event finishes.
for submit_trigger in submit_triggers:

submit_trigger(lambda: (gr.update(visible=False), gr.update(visible=True)),

inputs=None, outputs=[submit_button, stop_button], queue=False)

submit_event.then(lambda: (gr.update(visible=True), gr.update(visible=False)),

inputs=None, outputs=[submit_button, stop_button], queue=False)



# Stop cancels the in-flight streaming event.
stop_button.click(None, inputs=None, outputs=None, cancels=submit_event)



# Retry: pop the last exchange, re-append the same user message, re-stream.
retry_button.click(

undo_chat, inputs=[chatbot_state], outputs=[chatbot, chatbot_state, message_state], queue=False

).then(

append_message_to_history, inputs=[message_state, chatbot_state], outputs=[chatbot, chatbot_state], queue=False

).then(

submit_message, inputs=[message_state, chatbot_state] + additional_inputs, outputs=[chatbot, chatbot_state]

).then(

estimate_message_cost, inputs=[message_textbox, chatbot, context], outputs=cost_info, show_progress="hidden"

)



# Undo: pop the last exchange and put the user's message back in the textbox.
undo_button.click(

undo_chat, inputs=[chatbot_state], outputs=[chatbot, chatbot_state, message_state], queue=False

).then(

lambda message: message, inputs=message_state, outputs=message_textbox, queue=False

).then(

estimate_message_cost, inputs=[message_textbox, chatbot, context], outputs=cost_info, show_progress="hidden"

)



# Clear: wipe history and pending message, then refresh the cost estimate.
clear_button.click(

lambda: ([], [], None), inputs=None, outputs=[chatbot, chatbot_state, message_state],

queue=False

).then(

estimate_message_cost, inputs=[message_textbox, chatbot, context], outputs=cost_info, show_progress="hidden"

)
|
|
|
# Persist (or delete) the chat history in the URL query parameter whenever the
# history or the opt-in toggle changes — handled entirely in JS.
chatbot.change(None, inputs=[chatbot, save_chat_history_to_url], outputs=None,

# Save the chat history into the URL query parameter.

js=save_or_delete_chat_history, queue=False, show_progress="hidden")



save_chat_history_to_url.change(None, inputs=[chatbot, save_chat_history_to_url], outputs=None,

js=save_or_delete_chat_history, queue=False, show_progress="hidden")



# Context edits: refresh character statistics, rebuild the RAG search index,
# then refresh the cost estimate.
context.change(

count_characters, inputs=context, outputs=char_counter, queue=False, show_progress="hidden"

).then(

create_search_engine, inputs=context, outputs=None

).then(

estimate_message_cost, inputs=[message_textbox, chatbot, context], outputs=cost_info, show_progress="hidden"

)



# On page load, run the JS bootstrap that defines helpers and restores any
# settings saved in the browser into setting_items.
app.load(None, inputs=None, outputs=setting_items, js=js_define_utilities_and_load_settings, show_progress="hidden")



app.queue().launch()

main()
|
</gradio-file> |
|
</gradio-lite> |
|
|
|
<script src="https://cdn.jsdelivr.net/npm/lz-string@1.5.0/libs/lz-string.min.js"></script>

<script>
|
(function () {
    // Restore a previously shared chat history from the URL's "history" query
    // parameter (LZ-compressed by the Python side's save-to-URL handler) into
    // the gradio-file element, before gradio-lite boots and reads it.

    const url = new URL(window.location.href);

    if (url.searchParams.has("history")) {
        const compressedHistory = url.searchParams.get("history");
        // Bug fix: `hist` was assigned without a declaration, creating an
        // accidental implicit global.
        const hist = LZString.decompressFromEncodedURIComponent(compressedHistory);

        const chat_history_element = document.querySelector('gradio-file[name="chat_history.json"]');
        // Guard: decompression returns null on malformed input, and the target
        // element may not exist; skip rather than throw and kill the script.
        if (hist && chat_history_element) {
            chat_history_element.textContent = hist;
        }
    }
})();
|
</script> |
|
<script type="module" crossorigin src="https://cdn.jsdelivr.net/npm/@gradio/lite@4.29.0/dist/lite.js"></script> |
|
</body> |
|
</html> |
|
|