youngtsai's picture
role1_content = f"{structured_dialogue[i]['content']}"
cb56cc7
raw
history blame
5.74 kB
import gradio as gr
from gtts import gTTS
import json
import os
import openai
import re
# Required secrets, read from the environment once at import time. Indexing
# os.environ (rather than using .get) raises KeyError for a missing variable,
# so the app refuses to start unconfigured.
PASSWORD, OPEN_AI_KEY = (os.environ[var] for var in ("PASSWORD", "OPEN_AI_KEY"))
def validate_and_correct_chat(data, roles=("A", "B"), rounds=2):
    """
    Correct chat data so every entry has an expected role and the chat has at
    most ``rounds`` rounds.

    An entry whose role is not in ``roles`` is reassigned the role that follows
    the previous entry's role (cycling through ``roles``). The first entry has
    no predecessor, so it falls back to ``roles[0]``. Role corrections are made
    in place on the dicts in ``data``.

    Parameters:
    - data (list): The chat data list of dicts, e.g. [{"role": "A", "content": "Hi"}, ...]
    - roles (sequence): The expected roles, default is ("A", "B")
    - rounds (int | float | str): The number of rounds expected. It is converted
      with int() because UI widgets may deliver it as a float.

    Returns:
    - list: Corrected chat data, trimmed to at most ``rounds * len(roles)`` entries.
    """
    roles = list(roles)
    rounds_count = int(rounds)

    # Validate role names. Entries are fixed in order, so the previous entry's
    # role is always valid by the time it is consulted.
    for idx, item in enumerate(data):
        if item['role'] in roles:
            continue
        print(f"Invalid role '{item['role']}' detected. Correcting it.")
        if idx == 0:
            item['role'] = roles[0]
        else:
            prev_index = roles.index(data[idx - 1]['role'])
            item['role'] = roles[(prev_index + 1) % len(roles)]

    # Validate number of rounds: one round = one entry per role.
    expected_entries = rounds_count * len(roles)
    if len(data) > expected_entries:
        print(f"Too many rounds detected. Trimming the chat to {rounds_count} rounds.")
        data = data[:expected_entries]
    return data
def extract_json_from_response(response_text):
    """
    Extract the first JSON array of objects embedded in a model reply.

    Every position where a ``[`` is followed (after optional whitespace) by a
    ``{`` is tried as the start of a JSON document. The first one that decodes
    successfully is returned. Decoding with the json module (instead of
    matching the closing bracket with a regex) means string content that
    contains "}]" is not truncated. A non-JSON bracketed snippet earlier in the
    reply (e.g. an echoed format hint) is skipped instead of aborting.

    Parameters:
    - response_text (str): raw text returned by the chat model.

    Returns:
    - list: the decoded JSON array.

    Raises:
    - ValueError: if no decodable JSON array of objects is found.
    """
    decoder = json.JSONDecoder()
    # Candidate starts: "[" followed by "{" (same opening the old regex required).
    for match in re.finditer(r'\[\s*\{', response_text):
        try:
            candidate, _end = decoder.raw_decode(response_text[match.start():])
        except json.JSONDecodeError:
            continue
        return candidate
    raise ValueError("JSON dialogue not found in the response.")
def create_chat_dialogue(rounds, role1, role2, theme, language):
    """
    Ask the OpenAI chat model to write a two-person dialogue and return it.

    Parameters:
    - rounds: number of rounds (one round = one line from each role). It is
      converted with int() because UI widgets may deliver it as a float.
    - role1, role2 (str): names of the two speakers.
    - theme (str): topic of the conversation.
    - language (str): language the dialogue should be written in.

    Returns:
    - list: [{"role": ..., "content": ...}, ...], parsed from the model reply
      and passed through validate_and_correct_chat.

    Raises:
    - KeyError: if the OPEN_AI_KEY environment variable is not set.
    - ValueError: if no JSON dialogue can be extracted from the reply.
    """
    openai.api_key = os.environ["OPEN_AI_KEY"]

    # Build the prompt.
    rounds_count = int(rounds)
    sentences_count = rounds_count * 2
    sys_content = f"你是一個{language}家教,請用{language}生成對話"
    # The format example uses quoted keys so the model is shown valid JSON;
    # the reply is parsed as strict JSON, which rejects unquoted keys.
    prompt = (
        f"您將進行一場以{theme}為主題的對話。{role1}{role2}將是參與者。"
        f"請依次交談{rounds_count}輪。(1輪對話的定義是 {role1}{role2} 各說一句話,"
        f"總共 {sentences_count} 句話。)以json格式儲存對話。並回傳對話JSON文件。"
        f"格式為:[{{\"role\":\"{role1}\", \"content\": \".....\"}}, "
        f"{{\"role\":\"{role2}\", \"content\": \".....\"}}]"
    )
    messages = [
        {"role": "system", "content": sys_content},
        {"role": "user", "content": prompt},
    ]
    print("=====messages=====")
    print(messages)
    print("=====messages=====")

    response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=messages,
        max_tokens=500 * rounds_count,  # generous budget that scales with the number of rounds
    )
    print(response)

    response_text = response.choices[0].message['content'].strip()
    parsed = extract_json_from_response(response_text)
    dialogue = validate_and_correct_chat(data=parsed, roles=[role1, role2], rounds=rounds_count)
    print(dialogue)
    # Already a Python list of dicts (parsed above), not a JSON string.
    return dialogue
def generate_dialogue(rounds, method, role1, role2, theme, language):
    """
    Produce a dialogue either from the model or as a manual-entry placeholder.

    When ``method`` is exactly "auto", the dialogue is generated by
    create_chat_dialogue. Any other value yields two placeholder lines,
    one per role.

    Returns:
    - list: [{"role": ..., "content": ...}, ...]
    """
    if method != "auto":
        return [
            {"role": role1, "content": "手動輸入文本 1"},
            {"role": role2, "content": "手動輸入文本 2"},
        ]
    return create_chat_dialogue(rounds, role1, role2, theme, language)
def main_function(password: str, theme: str, language: str, method: str, rounds: int, role1: str, role2: str):
    """
    Gradio handler: check the password, build a dialogue, and render it.

    Parameters:
    - password: password typed by the user, compared with the PASSWORD
      environment variable.
    - theme, language, method, rounds, role1, role2: forwarded to generate_dialogue.

    Returns:
    - tuple: (chatbot_pairs, audio_path, file_name).
      - chatbot_pairs is a list of (first_line, second_line) string tuples for
        the Chatbot component.
      - audio_path comes from dialogue_to_audio.
      - file_name is the JSON dump written to "dialogue_output.txt".
      - On a rejected password the tuple is ([("", <error message>)], None, None),
        so Chatbot still receives a list of pairs and Audio/File receive None
        instead of an invalid empty path.
    """
    import hmac

    expected = os.environ.get("PASSWORD", "")
    supplied = password or ""
    # Fail closed when no password is configured; compare in constant time.
    # Encode first because compare_digest only accepts ASCII-only str.
    if not expected or not hmac.compare_digest(supplied.encode("utf-8"), expected.encode("utf-8")):
        return [("", "错误的密码,请重新输入。")], None, None

    structured_dialogue = generate_dialogue(rounds, method, role1, role2, theme, language)

    # Pair consecutive lines so the Chatbot shows them side by side; an odd
    # trailing line is paired with "".
    contents = [f"{item['content']}" for item in structured_dialogue]
    chatbot_dialogue = [
        (contents[i], contents[i + 1] if i + 1 < len(contents) else "")
        for i in range(0, len(contents), 2)
    ]

    audio_path = dialogue_to_audio(structured_dialogue)

    # Save the dialogue as JSON for download.
    json_output = json.dumps({"dialogue": structured_dialogue}, ensure_ascii=False, indent=4)
    file_name = "dialogue_output.txt"
    with open(file_name, "w", encoding="utf-8") as f:
        f.write(json_output)
    return chatbot_dialogue, audio_path, file_name
def dialogue_to_audio(dialogue, lang='zh-tw', file_path="temp_audio.mp3"):
    """
    Read the whole dialogue aloud with gTTS and save it as a single MP3.

    Each entry is spoken as "<role>: <content>", and entries are joined with spaces.

    Parameters:
    - dialogue (list): [{"role": ..., "content": ...}, ...]
    - lang (str): gTTS language code. Defaults to 'zh-tw', the previously
      hard-coded value; pass another code for non-Chinese dialogues.
    - file_path (str): where to write the MP3. Defaults to "temp_audio.mp3".

    Returns:
    - str | None: ``file_path``, or None when ``dialogue`` is empty. gTTS
      refuses empty text, and None is an acceptable "no audio" value.
    """
    if not dialogue:
        return None
    text = " ".join(f"{item['role']}: {item['content']}" for item in dialogue)
    gTTS(text=text, lang=lang).save(file_path)
    return file_path
if __name__ == "__main__":
    gr.Interface(
        main_function,
        [
            gr.components.Textbox(label="输入密码", type="password"),
            # Theme input box, pre-filled with the default '購物' (shopping).
            gr.components.Textbox(label="對話主題", value="購物"),
            # A default language keeps None out of the generated prompt when
            # the user never touches the dropdown.
            gr.components.Dropdown(choices=["中文", "英文"], label="語言", value="中文"),
            gr.components.Dropdown(choices=["auto", "manual"], label="生成方式"),
            gr.components.Slider(minimum=2, maximum=6, step=2, label="對話輪數"),
            gr.components.Textbox(label="角色 1 名稱"),
            gr.components.Textbox(label="角色 2 名稱"),
        ],
        [
            gr.components.Chatbot(label="生成的對話"),
            gr.components.Audio(type="filepath", label="對話朗讀"),
            gr.components.File(label="下載對話 JSON 文件"),
        ],
    ).launch()