# gschatbot_1 / chatbot_utils.py
# Last commit 6af28fa (verified) by songhune: "이게 맞지" ("Is this right?")
import os
from openai import OpenAI
import json
from datetime import datetime
from scenario_handler import ScenarioHandler
import time
# Single OpenAI client shared by every request made from this module.
# The key is taken from the environment variable literally named "api_key".
# The SDK's conventional name is OPENAI_API_KEY; per the openai-python client,
# a None api_key makes it fall back to that variable.
client = OpenAI(api_key=os.environ.get("api_key"))
def chatbot_response(response, handler_type='offender', n=1, *,
                     model="gpt-4", temperature=0.8, max_tokens=300):
    """Generate chat completion(s) for ``response`` in a scenario role.

    The prompt sent to the model is a fixed system message, followed by the
    scenario messages for the requested role (from ``ScenarioHandler``),
    followed by ``response`` as the final user message.

    Args:
        response: Text sent as the final user message.
        handler_type: 'offender' selects the offender scenario; any other
            value selects the victim scenario.
        n: Number of completions to request.
        model: Chat model name (default "gpt-4").
        temperature: Sampling temperature (default 0.8).
        max_tokens: Per-completion token cap (default 300).

    Returns:
        ``(first_choice, all_choices)``, where ``all_choices`` is the list of
        message contents of every returned completion.
    """
    scenario_handler = ScenarioHandler()
    if handler_type == 'offender':
        scenario_messages = scenario_handler.handle_offender()
    else:
        scenario_messages = scenario_handler.handle_victim()

    # NOTE(review): handle_* presumably return lists of {"role", "content"}
    # dicts. Verify against ScenarioHandler.
    messages = [
        {"role": "system", "content": "You are a chatbot."},
        *scenario_messages,
        {"role": "user", "content": response},
    ]

    api_response = client.chat.completions.create(
        model=model,
        temperature=temperature,
        top_p=0.9,
        max_tokens=max_tokens,
        n=n,
        frequency_penalty=0.5,
        presence_penalty=0.5,
        messages=messages,
    )
    choices = [choice.message.content for choice in api_response.choices]
    return choices[0], choices
def save_history(history, log_dir='logs'):
    """Write the chat history to a new timestamped JSON file and return its path.

    Files are named ``chat_history_YYYYmmdd_HHMMSS.json`` inside ``log_dir``,
    which is created if missing. If that name is already taken (for example,
    two sessions ending within the same second), a numeric suffix (``_1``,
    ``_2``, ...) is appended rather than overwriting the earlier log.

    Args:
        history: JSON-serializable chat history. Tuples are written as arrays.
        log_dir: Target directory; defaults to ``'logs'`` relative to the
            current working directory.

    Returns:
        The path of the file that was written.
    """
    os.makedirs(log_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    attempt = 0
    while True:
        suffix = f'_{attempt}' if attempt else ''
        filename = os.path.join(log_dir, f'chat_history_{timestamp}{suffix}.json')
        try:
            # 'x' = exclusive creation: raise instead of truncating an existing log.
            file = open(filename, 'x', encoding='utf-8')
        except FileExistsError:
            attempt += 1
            continue
        break
    with file:
        # ensure_ascii=False keeps Korean text human-readable in the log.
        json.dump(history, file, ensure_ascii=False, indent=4)
    print(f"History saved to {filename}")
    return filename
def process_user_input(user_input, chatbot_history):
    """Handle one participant message and prepare the next turn.

    Typing the exit keyword "종료" ("end") saves the history via
    ``save_history`` and appends a closing message. That message thanks the
    participant and asks them to follow the next instructions.

    Any other input is appended to the history with an empty reply slot. An
    offender reply is then generated for it, followed by three candidate
    victim replies to that offender reply.

    Args:
        user_input: Raw text typed by the participant.
        chatbot_history: List of (user_message, bot_message) pairs so far.

    Returns:
        Always a 3-tuple ``(history, offender_response, victim_choices)``.
        On exit, ``offender_response`` is None and ``victim_choices`` is [].
    """
    if user_input.strip().lower() == "종료":
        save_history(chatbot_history)
        closing_turn = ("종료", "실험에 참가해 주셔서 감사합니다. 후속 지시를 따라주세요")
        # Same 3-value shape as the normal path so callers can always unpack it.
        return chatbot_history + [closing_turn], None, []

    # Show the user's message right away. The reply slot is left empty;
    # presumably the caller fills it later via delayed_offender_response.
    new_history = chatbot_history + [(user_input, None)]
    # NOTE: only the latest input is sent to the model, not earlier turns.
    offender_response, _ = chatbot_response(user_input, 'offender', n=1)
    # Candidate victim replies offered to the participant for the next turn.
    _, victim_choices = chatbot_response(offender_response, 'victim', n=3)
    return new_history, offender_response, victim_choices
def delayed_offender_response(history, offender_response):
    """Return a copy of ``history`` with the offender's reply appended.

    The reply is added as a ``(None, offender_response)`` pair, i.e. a bot
    message with no accompanying user message. ``history`` itself is left
    unmodified. Meant to be called after a delay, so the reply shows up
    separately from the user's message.
    """
    bot_only_turn = (None, offender_response)
    updated_history = history + [bot_only_turn]
    return updated_history