alKoGolik's picture
Upload 210 files
6fcd376 verified
from response_parser import *
import copy
import json
from tqdm import tqdm
import logging
import argparse
import os
def initialization(state_dict: Dict) -> None:
    """Prepare the working directory and the shared bot backend.

    Ensures a ``cache`` directory exists in the current working directory
    and, when ``state_dict["bot_backend"]`` is ``None``, stores a newly
    constructed ``BotBackend`` under that key.

    Args:
        state_dict: Mutable session state containing the key
            ``"bot_backend"``. It is updated in place.

    Raises:
        KeyError: If ``state_dict`` has no ``"bot_backend"`` entry.
        FileExistsError: If ``cache`` exists but is not a directory.
    """
    # exist_ok=True creates the directory atomically from the caller's point
    # of view, avoiding the race between a separate existence check and mkdir.
    os.makedirs('cache', exist_ok=True)

    if state_dict["bot_backend"] is None:
        state_dict["bot_backend"] = BotBackend()
        # Drop the API key from the process environment once the backend has
        # been built.
        # NOTE(review): presumably BotBackend() has already read whatever
        # credentials it needs by this point - confirm against its constructor.
        os.environ.pop('OPENAI_API_KEY', None)
def get_bot_backend(state_dict: Dict) -> BotBackend:
    """Return the backend object stored under ``"bot_backend"``.

    Args:
        state_dict: Session state mapping.

    Raises:
        KeyError: If ``state_dict`` has no ``"bot_backend"`` entry.
    """
    backend = state_dict["bot_backend"]
    return backend
def switch_to_gpt4(state_dict: Dict, whether_switch: bool) -> None:
    """Select which GPT model the session's backend should use.

    Args:
        state_dict: Session state holding the backend under ``"bot_backend"``.
        whether_switch: Truthy selects ``"GPT-4"``; falsy selects
            ``"GPT-3.5"``.
    """
    backend = get_bot_backend(state_dict)
    model_choice = "GPT-4" if whether_switch else "GPT-3.5"
    backend.update_gpt_model_choice(model_choice)
def add_text(state_dict, history, text):
    """Register a user message and open a new turn in the chat history.

    The text is passed to the backend through ``add_text_message`` and a
    ``[text, None]`` pair is placed at the end of a *new* history list; the
    list object passed in is left unmodified.

    Args:
        state_dict: Session state holding the backend under ``"bot_backend"``.
        history: List of ``[user_text, bot_text]`` pairs accumulated so far.
        text: The user's message.

    Returns:
        A ``(new_history, state_dict)`` tuple.
    """
    backend = get_bot_backend(state_dict)
    backend.add_text_message(user_text=text)

    pending_turn = [text, None]
    return history + [pending_turn], state_dict
def bot(state_dict, history):
    """Drive the backend until it no longer asks for another completion.

    While the backend's ``finish_reason`` is ``'new_input'`` or
    ``'function_call'``, an empty bot slot is prepared in ``history``, one
    chat completion is requested, and the response is handed to
    ``parse_response``, whose first return value becomes the new history.

    Args:
        state_dict: Session state holding the backend under ``"bot_backend"``.
        history: Non-empty list of ``[user_text, bot_text]`` pairs. It is
            mutated in place while preparing the bot slot.

    Returns:
        The history as returned by the last ``parse_response`` call, or the
        input ``history`` unchanged if the loop body never ran.
    """
    backend = get_bot_backend(state_dict)
    pending_reasons = ('new_input', 'function_call')

    while backend.finish_reason in pending_reasons:
        latest_turn = history[-1]
        if not latest_turn[1]:
            # The last turn has no bot text yet: reuse its slot.
            latest_turn[1] = ""
        else:
            # The last turn already carries bot text: open a fresh bot-only turn.
            history.append([None, ""])

        logging.info("Start chat completion")
        response = chat_completion(bot_backend=backend)
        logging.info(f"End chat completion, response: {response}")

        logging.info("Start parse response")
        history, _unused = parse_response(
            chunk=response,
            history=history,
            bot_backend=backend,
        )
        logging.info("End parse response")

    return history
def main(state, history, user_input):
    """Run one instruction through the bot until it produces nothing new.

    The user input is added to the conversation, then ``bot`` is invoked
    repeatedly. As soon as a round leaves the history equal to the snapshot
    taken before that round, the backend's conversation is returned with
    every entry whose ``"content"`` is falsy filtered out.

    Args:
        state: Session state holding the backend under ``"bot_backend"``.
        history: List of ``[user_text, bot_text]`` pairs to start from.
        user_input: The instruction to send as the user's message.

    Returns:
        list: The items of ``state["bot_backend"].conversation`` whose
        ``"content"`` value is truthy.
    """
    history, state = add_text(state, history, user_input)
    # ``bot`` mutates the history list in place, so a deep copy is needed to
    # have something stable to compare against after each round.
    previous_history = copy.deepcopy(history)

    while True:
        # Every round runs on GPT-4. The former ``first_turn_flag`` was
        # initialised to False and only ever tested, so its GPT-3.5 branch
        # could never execute; it has been removed.
        switch_to_gpt4(state, True)

        logging.info("Start bot")
        history = bot(state, history)
        logging.info("End bot")
        print(state["bot_backend"].conversation)

        # Comparing does not mutate either side, so no copy is needed here.
        if previous_history == history:
            logging.info("No new response, end conversation")
            return [
                item
                for item in state["bot_backend"].conversation
                if item["content"]
            ]

        logging.info("New response, continue conversation")
        previous_history = copy.deepcopy(history)
if __name__ == "__main__":
    # Batch driver: read one JSON object per line from --input_path, run each
    # object's "query" through the bot, and write the collected conversations
    # to --output_path as a JSON array.
    parser = argparse.ArgumentParser()
    # Both paths are mandatory; without ``required=True`` a missing argument
    # would only surface later as a TypeError from ``open(None)``.
    parser.add_argument('--input_path', type=str, required=True)
    parser.add_argument('--output_path', type=str, required=True)
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)
    logging.info("Initialization")
    state = {"bot_backend": None}
    # ``main`` never mutates this list (``add_text`` builds a new one), so the
    # same empty list serves as the starting history for every instruction.
    history = []
    initialization(state)
    switch_to_gpt4(state_dict=state, whether_switch=True)

    logging.info("Start")
    # Explicit UTF-8 keeps reading independent of the platform's locale;
    # blank lines are skipped instead of crashing ``json.loads``.
    with open(args.input_path, "r", encoding="utf-8") as f:
        instructions = [json.loads(line)["query"] for line in f if line.strip()]

    all_history = []
    logging.info(f"{len(instructions)} remaining instructions for {args.input_path}")
    for user_input_index, user_input in enumerate(tqdm(instructions)):
        logging.info(f"Start conversation {user_input_index}")
        conversation = main(state, history, user_input)
        all_history.append(
            {
                "instruction": user_input,
                "conversation": conversation,
            }
        )
        # Rewrite the full output after every instruction so that results
        # collected so far survive an interruption. UTF-8 is required because
        # ``ensure_ascii=False`` emits non-ASCII characters verbatim.
        with open(args.output_path, "w", encoding="utf-8") as f:
            json.dump(all_history, f, indent=4, ensure_ascii=False)
        # Reset the backend before the next instruction is processed.
        state["bot_backend"].restart()