Add a flag to enable/disable RAG feature
Browse files- index.html +40 -26
index.html
CHANGED
@@ -553,7 +553,7 @@ actual_total_cost_prompt = 0
|
|
553 |
actual_total_cost_completion = 0
|
554 |
|
555 |
|
556 |
-
async def process_prompt(prompt, history, context, platform, endpoint, azure_deployment, azure_api_version, api_key, model_name, max_tokens, temperature):
|
557 |
"""
|
558 |
ユーザーのプロンプトを処理し、ChatGPTによる生成結果を返す。
|
559 |
|
@@ -569,6 +569,7 @@ async def process_prompt(prompt, history, context, platform, endpoint, azure_dep
|
|
569 |
model_name (str): 使用するAIモデルの名前
|
570 |
max_tokens (int): 生成する最大トークン数
|
571 |
temperature (float): クリエイティビティの度合いを示す温度パラメータ
|
|
|
572 |
|
573 |
Returns:
|
574 |
str: ChatGPTによる生成結果
|
@@ -593,15 +594,24 @@ async def process_prompt(prompt, history, context, platform, endpoint, azure_dep
|
|
593 |
http_client=http_client
|
594 |
)
|
595 |
|
596 |
-
|
597 |
-
|
598 |
-
|
599 |
-
|
600 |
-
|
601 |
-
|
602 |
-
|
603 |
-
|
604 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
605 |
|
606 |
bot_response = ""
|
607 |
if hasattr(completion, "error"):
|
@@ -742,6 +752,7 @@ DEFAULT_SETTINGS = {
|
|
742 |
"model_name": "gpt-4-turbo-preview",
|
743 |
"max_tokens": 4096,
|
744 |
"temperature": 0.2,
|
|
|
745 |
"save_chat_history_to_url": False
|
746 |
};
|
747 |
|
@@ -791,6 +802,7 @@ def main():
|
|
791 |
entry["model_name"] || default_model_name,
|
792 |
entry["max_tokens"] || default_max_tokens,
|
793 |
entry["temperature"] || default_temperature,
|
|
|
794 |
entry["save_chat_history_to_url"] || default_save_chat_history_to_url
|
795 |
]);
|
796 |
}
|
@@ -798,7 +810,7 @@ def main():
|
|
798 |
saved_settings = default_saved_settings;
|
799 |
}
|
800 |
|
801 |
-
return [setting_name, platform, endpoint, azure_deployment, azure_api_version, model_name, max_tokens, temperature, save_chat_history_to_url, saved_settings];
|
802 |
};
|
803 |
|
804 |
globalThis.resetSettings = () => {
|
@@ -1007,6 +1019,10 @@ def main():
|
|
1007 |
temperature.change(None, inputs=temperature, outputs=None,
|
1008 |
js='(x) => saveItem("temperature", x)', show_progress="hidden")
|
1009 |
|
|
|
|
|
|
|
|
|
1010 |
save_chat_history_to_url = gr.Checkbox(label="Save Chat History to URL", interactive=True)
|
1011 |
|
1012 |
reset_button = gr.Button("Reset Settings")
|
@@ -1017,10 +1033,10 @@ def main():
|
|
1017 |
saved_settings_df = gr.Dataframe(
|
1018 |
elem_id="saved_settings",
|
1019 |
value=[default_saved_settings],
|
1020 |
-
headers=["Name", "Platform", "Endpoint", "Azure Deployment", "Azure API Version", "Model", "Max Tokens", "Temperature", "Save Chat History to URL"],
|
1021 |
row_count=(0, "dynamic"),
|
1022 |
-
col_count=(
|
1023 |
-
datatype=["str", "str", "str", "str", "str", "str", "number", "number", "bool"],
|
1024 |
type="array",
|
1025 |
label="Saved Settings",
|
1026 |
show_label=True,
|
@@ -1057,15 +1073,15 @@ def main():
|
|
1057 |
|
1058 |
row_index = selected_setting[0]
|
1059 |
|
1060 |
-
setting_name, platform, endpoint, azure_deployment, azure_api_version, model_name, max_tokens, temperature, save_chat_history_to_url = saved_settings[row_index]
|
1061 |
|
1062 |
-
return u(setting_name), u(platform), u(endpoint), u(azure_deployment), u(azure_api_version), u(model_name), u(max_tokens), u(temperature), u(save_chat_history_to_url), None
|
1063 |
|
1064 |
|
1065 |
-
load_saved_settings_button.click(load_saved_setting, inputs=[saved_settings_df, selected_setting], outputs=[setting_name, platform, endpoint, azure_deployment, azure_api_version, model_name, max_tokens, temperature, save_chat_history_to_url, selected_setting], queue=False, show_progress="hidden")
|
1066 |
|
1067 |
|
1068 |
-
def append_or_overwrite_setting(saved_settings, setting_name, platform, endpoint, azure_deployment, azure_api_version, model_name, max_tokens, temperature, save_chat_history_to_url):
|
1069 |
|
1070 |
setting_name = setting_name.strip()
|
1071 |
|
@@ -1073,13 +1089,13 @@ def main():
|
|
1073 |
new_saved_settings = []
|
1074 |
for entry in saved_settings:
|
1075 |
if entry[0] == setting_name:
|
1076 |
-
new_saved_settings.append([setting_name, platform, endpoint, azure_deployment, azure_api_version,model_name, max_tokens, temperature, save_chat_history_to_url])
|
1077 |
found = True
|
1078 |
else:
|
1079 |
new_saved_settings.append(entry)
|
1080 |
|
1081 |
if not found:
|
1082 |
-
new_saved_settings.append([setting_name, platform, endpoint, azure_deployment, azure_api_version,model_name, max_tokens, temperature, save_chat_history_to_url])
|
1083 |
|
1084 |
return new_saved_settings, None
|
1085 |
|
@@ -1095,7 +1111,7 @@ def main():
|
|
1095 |
|
1096 |
|
1097 |
append_or_overwrite_saved_settings_button.click(
|
1098 |
-
append_or_overwrite_setting, inputs=[saved_settings_df, setting_name, platform, endpoint, azure_deployment, azure_api_version,model_name, max_tokens, temperature, save_chat_history_to_url], outputs=[saved_settings_df, selected_setting], queue=False, show_progress="hidden"
|
1099 |
).then(
|
1100 |
serialize_saved_settings, inputs=saved_settings_df, outputs=serialized_saved_settings_state, queue=False, show_progress="hidden",
|
1101 |
).then(
|
@@ -1127,8 +1143,7 @@ def main():
|
|
1127 |
temp_saved_settings = gr.JSON(visible=False)
|
1128 |
temp_saved_settings.change(lambda x: x, inputs=temp_saved_settings, outputs=saved_settings_df, queue=False, show_progress="hidden")
|
1129 |
|
1130 |
-
setting_items = [setting_name, platform, endpoint, azure_deployment, azure_api_version, model_name, max_tokens,
|
1131 |
-
temperature, save_chat_history_to_url, temp_saved_settings]
|
1132 |
reset_button.click(None, inputs=None, outputs=setting_items,
|
1133 |
js="() => resetSettings()", show_progress="hidden")
|
1134 |
|
@@ -1147,8 +1162,7 @@ def main():
|
|
1147 |
|
1148 |
with gr.Column(scale=2):
|
1149 |
|
1150 |
-
additional_inputs = [context, platform, endpoint, azure_deployment, azure_api_version, api_key,
|
1151 |
-
model_name, max_tokens, temperature]
|
1152 |
|
1153 |
with gr.Blocks() as chat:
|
1154 |
gr.Markdown(f"# Chat with your PDF")
|
@@ -1326,4 +1340,4 @@ main()
|
|
1326 |
</script>
|
1327 |
<script type="module" crossorigin src="https://cdn.jsdelivr.net/npm/@gradio/lite@4.29.0/dist/lite.js"></script>
|
1328 |
</body>
|
1329 |
-
</html>
|
|
|
553 |
actual_total_cost_completion = 0
|
554 |
|
555 |
|
556 |
+
async def process_prompt(prompt, history, context, platform, endpoint, azure_deployment, azure_api_version, api_key, model_name, max_tokens, temperature, enable_rag):
|
557 |
"""
|
558 |
ユーザーのプロンプトを処理し、ChatGPTによる生成結果を返す。
|
559 |
|
|
|
569 |
model_name (str): 使用するAIモデルの名前
|
570 |
max_tokens (int): 生成する最大トークン数
|
571 |
temperature (float): クリエイティビティの度合いを示す温度パラメータ
|
572 |
+
enable_rag (bool): RAG機能を有効にするかどうか
|
573 |
|
574 |
Returns:
|
575 |
str: ChatGPTによる生成結果
|
|
|
594 |
http_client=http_client
|
595 |
)
|
596 |
|
597 |
+
if enable_rag:
|
598 |
+
completion = openai_client.chat.completions.create(
|
599 |
+
messages=messages,
|
600 |
+
model=model_name,
|
601 |
+
max_tokens=max_tokens,
|
602 |
+
temperature=temperature,
|
603 |
+
tools=CHAT_TOOLS,
|
604 |
+
tool_choice="auto",
|
605 |
+
stream=False
|
606 |
+
)
|
607 |
+
else:
|
608 |
+
completion = openai_client.chat.completions.create(
|
609 |
+
messages=messages,
|
610 |
+
model=model_name,
|
611 |
+
max_tokens=max_tokens,
|
612 |
+
temperature=temperature,
|
613 |
+
stream=False
|
614 |
+
)
|
615 |
|
616 |
bot_response = ""
|
617 |
if hasattr(completion, "error"):
|
|
|
752 |
"model_name": "gpt-4-turbo-preview",
|
753 |
"max_tokens": 4096,
|
754 |
"temperature": 0.2,
|
755 |
+
"enable_rag": True,
|
756 |
"save_chat_history_to_url": False
|
757 |
};
|
758 |
|
|
|
802 |
entry["model_name"] || default_model_name,
|
803 |
entry["max_tokens"] || default_max_tokens,
|
804 |
entry["temperature"] || default_temperature,
|
805 |
+
entry["enable_rag"] || default_enable_rag,
|
806 |
entry["save_chat_history_to_url"] || default_save_chat_history_to_url
|
807 |
]);
|
808 |
}
|
|
|
810 |
saved_settings = default_saved_settings;
|
811 |
}
|
812 |
|
813 |
+
return [setting_name, platform, endpoint, azure_deployment, azure_api_version, model_name, max_tokens, temperature, enable_rag, save_chat_history_to_url, saved_settings];
|
814 |
};
|
815 |
|
816 |
globalThis.resetSettings = () => {
|
|
|
1019 |
temperature.change(None, inputs=temperature, outputs=None,
|
1020 |
js='(x) => saveItem("temperature", x)', show_progress="hidden")
|
1021 |
|
1022 |
+
enable_rag = gr.Checkbox(label="Enable RAG (Retrieval Augmented Generation)", interactive=True)
|
1023 |
+
enable_rag.change(None, inputs=enable_rag, outputs=None,
|
1024 |
+
js='(x) => saveItem("enable_rag", x)', show_progress="hidden")
|
1025 |
+
|
1026 |
save_chat_history_to_url = gr.Checkbox(label="Save Chat History to URL", interactive=True)
|
1027 |
|
1028 |
reset_button = gr.Button("Reset Settings")
|
|
|
1033 |
saved_settings_df = gr.Dataframe(
|
1034 |
elem_id="saved_settings",
|
1035 |
value=[default_saved_settings],
|
1036 |
+
headers=["Name", "Platform", "Endpoint", "Azure Deployment", "Azure API Version", "Model", "Max Tokens", "Temperature", "Enable RAG", "Save Chat History to URL"],
|
1037 |
row_count=(0, "dynamic"),
|
1038 |
+
col_count=(10, "fixed"),
|
1039 |
+
datatype=["str", "str", "str", "str", "str", "str", "number", "number", "bool", "bool"],
|
1040 |
type="array",
|
1041 |
label="Saved Settings",
|
1042 |
show_label=True,
|
|
|
1073 |
|
1074 |
row_index = selected_setting[0]
|
1075 |
|
1076 |
+
setting_name, platform, endpoint, azure_deployment, azure_api_version, model_name, max_tokens, temperature, enable_rag, save_chat_history_to_url = saved_settings[row_index]
|
1077 |
|
1078 |
+
return u(setting_name), u(platform), u(endpoint), u(azure_deployment), u(azure_api_version), u(model_name), u(max_tokens), u(temperature), u(enable_rag), u(save_chat_history_to_url), None
|
1079 |
|
1080 |
|
1081 |
+
load_saved_settings_button.click(load_saved_setting, inputs=[saved_settings_df, selected_setting], outputs=[setting_name, platform, endpoint, azure_deployment, azure_api_version, model_name, max_tokens, temperature, enable_rag, save_chat_history_to_url, selected_setting], queue=False, show_progress="hidden")
|
1082 |
|
1083 |
|
1084 |
+
def append_or_overwrite_setting(saved_settings, setting_name, platform, endpoint, azure_deployment, azure_api_version, model_name, max_tokens, temperature, enable_rag, save_chat_history_to_url):
|
1085 |
|
1086 |
setting_name = setting_name.strip()
|
1087 |
|
|
|
1089 |
new_saved_settings = []
|
1090 |
for entry in saved_settings:
|
1091 |
if entry[0] == setting_name:
|
1092 |
+
new_saved_settings.append([setting_name, platform, endpoint, azure_deployment, azure_api_version,model_name, max_tokens, temperature, enable_rag, save_chat_history_to_url])
|
1093 |
found = True
|
1094 |
else:
|
1095 |
new_saved_settings.append(entry)
|
1096 |
|
1097 |
if not found:
|
1098 |
+
new_saved_settings.append([setting_name, platform, endpoint, azure_deployment, azure_api_version,model_name, max_tokens, temperature, enable_rag, save_chat_history_to_url])
|
1099 |
|
1100 |
return new_saved_settings, None
|
1101 |
|
|
|
1111 |
|
1112 |
|
1113 |
append_or_overwrite_saved_settings_button.click(
|
1114 |
+
append_or_overwrite_setting, inputs=[saved_settings_df, setting_name, platform, endpoint, azure_deployment, azure_api_version,model_name, max_tokens, temperature, enable_rag, save_chat_history_to_url], outputs=[saved_settings_df, selected_setting], queue=False, show_progress="hidden"
|
1115 |
).then(
|
1116 |
serialize_saved_settings, inputs=saved_settings_df, outputs=serialized_saved_settings_state, queue=False, show_progress="hidden",
|
1117 |
).then(
|
|
|
1143 |
temp_saved_settings = gr.JSON(visible=False)
|
1144 |
temp_saved_settings.change(lambda x: x, inputs=temp_saved_settings, outputs=saved_settings_df, queue=False, show_progress="hidden")
|
1145 |
|
1146 |
+
setting_items = [setting_name, platform, endpoint, azure_deployment, azure_api_version, model_name, max_tokens, temperature, enable_rag, save_chat_history_to_url, temp_saved_settings]
|
|
|
1147 |
reset_button.click(None, inputs=None, outputs=setting_items,
|
1148 |
js="() => resetSettings()", show_progress="hidden")
|
1149 |
|
|
|
1162 |
|
1163 |
with gr.Column(scale=2):
|
1164 |
|
1165 |
+
additional_inputs = [context, platform, endpoint, azure_deployment, azure_api_version, api_key, model_name, max_tokens, temperature, enable_rag]
|
|
|
1166 |
|
1167 |
with gr.Blocks() as chat:
|
1168 |
gr.Markdown(f"# Chat with your PDF")
|
|
|
1340 |
</script>
|
1341 |
<script type="module" crossorigin src="https://cdn.jsdelivr.net/npm/@gradio/lite@4.29.0/dist/lite.js"></script>
|
1342 |
</body>
|
1343 |
+
</html>
|