|
from response_parser import * |
|
import copy |
|
import json |
|
from tqdm import tqdm |
|
import logging |
|
import argparse |
|
import os |
|
|
|
def initialization(state_dict: Dict) -> None:
    """Prepare the working directory, chat backend, and environment for a session.

    - Ensures a ``cache`` directory exists in the current working directory.
    - Creates a ``BotBackend`` and stores it under ``state_dict["bot_backend"]``
      only if the slot is still ``None``; an existing backend is left untouched.
    - Removes ``OPENAI_API_KEY`` from the process environment if it is set.
      NOTE(review): presumably so the backend does not pick the key up from
      the environment — confirm against BotBackend.

    Args:
        state_dict: Mutable session state; must already contain the
            ``"bot_backend"`` key (``KeyError`` otherwise).
    """
    # exist_ok replaces the racy exists()-then-mkdir() pair. Unlike the old
    # check, this raises FileExistsError if 'cache' exists but is not a directory.
    os.makedirs('cache', exist_ok=True)

    if state_dict["bot_backend"] is None:
        state_dict["bot_backend"] = BotBackend()

    # Single call: unset the variable if present, no-op otherwise.
    os.environ.pop('OPENAI_API_KEY', None)
|
|
|
def get_bot_backend(state_dict: Dict) -> BotBackend:
    """Fetch the session's chat backend from *state_dict*.

    The value stored under ``"bot_backend"`` is returned unchanged; it is
    ``None`` until ``initialization`` has populated it. A missing key
    raises ``KeyError``.
    """
    backend = state_dict["bot_backend"]
    return backend
|
|
|
def switch_to_gpt4(state_dict: Dict, whether_switch: bool) -> None:
    """Select which GPT model the session backend should use.

    Args:
        state_dict: Session state holding the backend under ``"bot_backend"``.
        whether_switch: Truthy selects ``"GPT-4"``; falsy selects ``"GPT-3.5"``.
    """
    backend = get_bot_backend(state_dict)
    model_name = "GPT-4" if whether_switch else "GPT-3.5"
    backend.update_gpt_model_choice(model_name)
|
|
|
def add_text(state_dict, history, text):
    """Register a user message with the backend and extend the chat history.

    Args:
        state_dict: Session state holding the backend under ``"bot_backend"``.
        history: List of ``[user_text, assistant_text]`` turns. It is not
            mutated; a new list is built.
        text: The user's message.

    Returns:
        ``(new_history, state_dict)``, where *new_history* is *history* plus a
        trailing ``[text, None]`` turn awaiting the assistant's reply.
    """
    backend = get_bot_backend(state_dict)
    backend.add_text_message(user_text=text)

    pending_turn = [text, None]
    extended_history = history + [pending_turn]
    return extended_history, state_dict
|
|
|
def bot(state_dict, history):
    """Keep requesting completions while the backend reports more work to do.

    Each pass runs while ``finish_reason`` is ``'new_input'`` or
    ``'function_call'`` (presumably updated by ``parse_response`` — confirm).
    A pass does three things:
      1. Prepares an empty assistant slot in the last history turn.
      2. Calls ``chat_completion``.
      3. Lets ``parse_response`` fold the result into the history.

    Args:
        state_dict: Session state holding the backend under ``"bot_backend"``.
        history: Non-empty list of ``[user_text, assistant_text]`` turns. It is
            mutated in place and may be replaced by ``parse_response``.

    Returns:
        The updated history list.
    """
    backend = get_bot_backend(state_dict)
    continue_reasons = ('new_input', 'function_call')

    while backend.finish_reason in continue_reasons:
        last_turn = history[-1]
        if not last_turn[1]:
            # Assistant slot is still empty/None: reuse it for this completion.
            last_turn[1] = ""
        else:
            # Assistant slot already has text: open a new assistant-only turn.
            history.append([None, ""])

        logging.info("Start chat completion")
        completion = chat_completion(bot_backend=backend)
        logging.info(f"End chat completion, response: {completion}")

        logging.info("Start parse response")
        history, _unused = parse_response(chunk=completion, history=history, bot_backend=backend)
        logging.info("End parse response")

    return history
|
|
|
def main(state, history, user_input):
    """Run one user instruction until the bot stops producing new output.

    The user message is added via ``add_text``. Then ``bot`` runs repeatedly,
    with GPT-4 selected each pass, until a pass leaves the chat history
    unchanged.

    Args:
        state: Session state holding the backend under ``"bot_backend"``.
        history: List of ``[user_text, assistant_text]`` turns. The caller's
            list is not mutated, because ``add_text`` returns a new list.
        user_input: The instruction text to send.

    Returns:
        The entries of the backend's ``conversation`` whose ``"content"`` is
        present and truthy.
    """
    history, state = add_text(state, history, user_input)
    # bot() mutates the history in place, so a deep snapshot is required to
    # detect whether a pass produced anything new.
    last_history = copy.deepcopy(history)

    while True:
        # NOTE: GPT-4 is always selected. A former "first turn -> GPT-3.5"
        # toggle was unreachable (its flag started False and was never set).
        switch_to_gpt4(state, True)

        logging.info("Start bot")
        history = bot(state, history)
        logging.info("End bot")

        bot_backend = get_bot_backend(state)
        print(bot_backend.conversation)

        # List equality compares nested lists by value; no extra copy needed.
        if history == last_history:
            logging.info("No new response, end conversation")
            # .get tolerates entries without a "content" key (assumes dict entries).
            return [item for item in bot_backend.conversation if item.get("content")]

        logging.info("New response, continue conversation")
        last_history = copy.deepcopy(history)
|
|
|
if __name__ == "__main__":
    # Script entry: replay every "query" from a JSONL file through the bot and
    # write each instruction together with its resulting conversation to JSON.
    parser = argparse.ArgumentParser()
    # required=True: previously a missing flag surfaced later as open(None).
    parser.add_argument('--input_path', type=str, required=True,
                        help="JSONL file; each non-blank line is an object with a 'query' field.")
    parser.add_argument('--output_path', type=str, required=True,
                        help="JSON output file, rewritten after every conversation.")
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)
    logging.info("Initialization")

    state = {"bot_backend": None}
    # Shared starting history. It stays empty because add_text returns a new list.
    history = []

    initialization(state)
    switch_to_gpt4(state_dict=state, whether_switch=True)

    logging.info("Start")
    with open(args.input_path, "r", encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing newline) that json.loads would reject.
        instructions = [json.loads(line)["query"] for line in f if line.strip()]
    all_history = []
    logging.info(f"{len(instructions)} remaining instructions for {args.input_path}")

    for user_input_index, user_input in enumerate(tqdm(instructions)):
        logging.info(f"Start conversation {user_input_index}")
        conversation = main(state, history, user_input)
        all_history.append(
            {
                "instruction": user_input,
                "conversation": conversation
            }
        )
        # Rewriting the whole file each iteration keeps results up to the
        # current conversation on disk. ensure_ascii=False writes raw
        # non-ASCII characters, so pin UTF-8 instead of the platform default.
        with open(f"{args.output_path}", "w", encoding="utf-8") as f:
            json.dump(all_history, f, indent=4, ensure_ascii=False)
        # Reset the backend before the next instruction (see BotBackend.restart).
        get_bot_backend(state).restart()
|
|