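"""Label conversations with category tags by querying an OpenAI-compatible API.

A sketch of the expected config YAML, inferred from the keys this script reads;
the values below are illustrative assumptions, not shipped defaults:

    input_file: data/conversations.json
    cache_file: data/cache.json        # falsy value skips the cache
    output_file: data/labels.jsonl
    convert_to_json: true              # also write a merged .json at the end
    task_name: [example_category]      # names passed to Category.create_category
    model_name: gpt-4o-mini
    max_token: 512
    temperature: 0.0
    parallel: 8
    max_retry: 3
    retry_sleep: 10
    error_output: "$ERROR$"
    endpoints:                         # falsy value uses the default OpenAI client
      - api_base: https://api.openai.com/v1
        api_key: sk-...
"""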
import argparse
import concurrent.futures
import json
import os
import random
import threading
import time

import orjson
import pandas as pd
import tqdm
import yaml

from category import Category

LOCK = threading.RLock()

# Shared state populated in __main__ and read by worker threads
TASKS = None
CACHE_DICT = None
OUTPUT_DICT = None

# API setting constants, filled in from the config file
API_MAX_RETRY = None
API_RETRY_SLEEP = None
API_ERROR_OUTPUT = None
# Load config args from a config YAML file
def make_config(config_file: str) -> dict:
    with open(config_file, "r") as f:
        config_kwargs = yaml.load(f, Loader=yaml.SafeLoader)
    return config_kwargs
def get_endpoint(endpoint_list):
    if endpoint_list is None:
        return None
    # Randomly pick one endpoint so load is spread across the list
    return random.choice(endpoint_list)
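# Query the chat completions endpoint with simple retry logic: transient
# failures (rate limits, connection errors, server errors) sleep and retry
# up to API_MAX_RETRY times; bad requests and unknown errors give up early.
# On failure the function returns API_ERROR_OUTPUT.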
def chat_completion_openai(model, messages, temperature, max_tokens, api_dict=None):
    import openai

    if api_dict:
        client = openai.OpenAI(
            base_url=api_dict["api_base"],
            api_key=api_dict["api_key"],
        )
    else:
        client = openai.OpenAI()

    output = API_ERROR_OUTPUT
    for _ in range(API_MAX_RETRY):
        try:
            completion = client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            output = completion.choices[0].message.content
            break
        except openai.RateLimitError as e:
            print(type(e), e)
            time.sleep(API_RETRY_SLEEP)
        except openai.BadRequestError as e:
            print(messages)
            print(type(e), e)
            break
        except openai.APIConnectionError as e:
            print(messages)
            print(type(e), e)
            time.sleep(API_RETRY_SLEEP)
        except openai.InternalServerError as e:
            print(messages)
            print(type(e), e)
            time.sleep(API_RETRY_SLEEP)
        except Exception as e:
            print(type(e), e)
            break
    return output
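# Label a single conversation row: run every requested category through the
# model, merge the results into the row's category_tag, and append the row
# as one JSONL line to the answer file.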
def get_answer(
    question: dict,
    model_name: str,
    max_tokens: int,
    temperature: float,
    answer_file: str,
    api_dict: dict,
    categories: list,
    testing: bool,
):
    if "category_tag" in question:
        category_tag = question["category_tag"]
    else:
        category_tag = {}

    output_log = {}
    for category in categories:
        conv = category.pre_process(question["prompt"])
        output = chat_completion_openai(
            model=model_name,
            messages=conv,
            temperature=temperature,
            max_tokens=max_tokens,
            api_dict=api_dict,
        )
        # Dump answers
        category_tag[category.name_tag] = category.post_process(output)
        if testing:
            output_log[category.name_tag] = output

    question["category_tag"] = category_tag
    if testing:
        question["output_log"] = output_log

    # question is a pandas Series; drop the helper fields before writing
    question.drop(["prompt", "uid", "required_tasks"], inplace=True)

    # Serialize appends so concurrent workers don't interleave JSONL lines
    with LOCK:
        with open(answer_file, "a") as fout:
            fout.write(json.dumps(question.to_dict()) + "\n")
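# Merge category tags for one row: tags the row already carries win, then
# the cache file, then freshly generated output.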
def category_merge(row):
    uid = row["uid"]
    input_category = row["category_tag"] if "category_tag" in row else {}
    cache_category = CACHE_DICT[uid]["category_tag"] if uid in CACHE_DICT else {}
    output_category = OUTPUT_DICT[uid]["category_tag"] if uid in OUTPUT_DICT else {}

    # tries to fill in missing categories using cache first, then output
    for name in TASKS:
        if name not in input_category:
            if name in cache_category:
                input_category[name] = cache_category[name]
                continue
            if name in output_category:
                input_category[name] = output_category[name]

    return input_category
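# Return the category names this row still needs labeled, i.e. tasks absent
# from the row itself, the cache, and any existing output.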
def find_required_tasks(row):
    uid = row["uid"]
    input_category = row["category_tag"] if "category_tag" in row else {}
    cache_category = CACHE_DICT[uid]["category_tag"] if uid in CACHE_DICT else {}
    output_category = OUTPUT_DICT[uid]["category_tag"] if uid in OUTPUT_DICT else {}

    return [
        name
        for name in TASKS
        if not (
            name in input_category or name in cache_category or name in output_category
        )
    ]
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--testing", action="store_true")
    args = parser.parse_args()

    enter = input(
        "Make sure your config file is properly configured. Press enter to continue."
    )
    if enter != "":
        exit()

    config = make_config(args.config)

    API_MAX_RETRY = config["max_retry"]
    API_RETRY_SLEEP = config["retry_sleep"]
    API_ERROR_OUTPUT = config["error_output"]

    categories = [Category.create_category(name) for name in config["task_name"]]
    TASKS = config["task_name"]
    print(
        f"Following categories will be labeled:\n{[category.name_tag for category in categories]}"
    )

    print("loading input data (might take a minute)")
    with open(config["input_file"], "rb") as f:
        data = orjson.loads(f.read())
    input_data = pd.DataFrame(data)

    # Build a unique id per conversation; .map(str) is much faster than pd.apply
    input_data["uid"] = input_data.question_id.map(str) + input_data.tstamp.map(str)
    assert len(input_data) == len(input_data.uid.unique())
    print(f"{len(input_data)} rows of input data loaded")
    if config["cache_file"]:
        print("loading cache data")
        with open(config["cache_file"], "rb") as f:
            data = orjson.loads(f.read())
        cache_data = pd.DataFrame(data)
        cache_data["uid"] = cache_data.question_id.map(str) + cache_data.tstamp.map(str)
        assert len(cache_data) == len(cache_data.uid.unique())
        print(f"{len(cache_data)} rows of cache data loaded")

        assert "category_tag" in cache_data.columns
        cache_dict = cache_data[["uid", "category_tag"]].set_index("uid")
        print("finalizing cache_dict (should take less than 30 sec)")
        CACHE_DICT = cache_dict.to_dict("index")
    else:
        CACHE_DICT = {}
    if os.path.isfile(config["output_file"]):
        print("loading existing output")
        output_data = pd.read_json(config["output_file"], lines=True)
        output_data["uid"] = output_data.question_id.map(str) + output_data.tstamp.map(
            str
        )
        assert len(output_data) == len(output_data.uid.unique())
        print(f"{len(output_data)} rows of existing output loaded")

        assert "category_tag" in output_data.columns
        output_dict = output_data[["uid", "category_tag"]].set_index("uid")
        print("finalizing output_dict (should take less than 30 sec)")
        OUTPUT_DICT = output_dict.to_dict("index")
    else:
        OUTPUT_DICT = {}
    print(
        "finding tasks needed to run... (should take around 1 minute or less on a large dataset)"
    )
    input_data["required_tasks"] = input_data.apply(find_required_tasks, axis=1)
    not_labeled = input_data[input_data.required_tasks.map(lambda x: len(x) > 0)].copy()

    print(f"{len(not_labeled)} conversations need to be labeled")
    for name in TASKS:
        print(
            f"{name}: {len(not_labeled[not_labeled.required_tasks.map(lambda tasks: name in tasks)])}"
        )
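    # Concatenate the even-indexed turns of conversation_a (the user side in
    # an alternating conversation) and truncate to 12,500 characters to bound
    # the prompt length.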
    not_labeled["prompt"] = not_labeled.conversation_a.map(
        lambda convo: "\n".join([convo[i]["content"] for i in range(0, len(convo), 2)])
    )
    not_labeled["prompt"] = not_labeled.prompt.map(lambda x: x[:12500])

    with concurrent.futures.ThreadPoolExecutor(
        max_workers=config["parallel"]
    ) as executor:
        futures = []
        for index, row in tqdm.tqdm(not_labeled.iterrows(), total=len(not_labeled)):
            future = executor.submit(
                get_answer,
                row,
                config["model_name"],
                config["max_token"],
                config["temperature"],
                config["output_file"],
                get_endpoint(config["endpoints"]),
                [
                    category
                    for category in categories
                    if category.name_tag in row["required_tasks"]
                ],
                args.testing,
            )
            futures.append(future)

        for future in tqdm.tqdm(
            concurrent.futures.as_completed(futures), total=len(futures)
        ):
            future.result()
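    # Optional final pass: fold the JSONL results back into the input data
    # and emit a single pretty-printed JSON file.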
    if config["convert_to_json"]:
        # Merge the generated category tags back into the input data
        # (existing tags win, then cache, then fresh output)
        merge_columns = [category.name_tag for category in categories]
        print(f"Columns to be merged:\n{merge_columns}")

        input_data["uid"] = input_data.question_id.map(str) + input_data.tstamp.map(str)
        assert len(input_data) == len(input_data.uid.unique())

        # Reload the freshly written output; dict lookup is the fastest way to merge
        assert os.path.isfile(config["output_file"])
        print("reading output file...")
        temp = pd.read_json(config["output_file"], lines=True)
        temp["uid"] = temp.question_id.map(str) + temp.tstamp.map(str)
        assert len(temp) == len(temp.uid.unique())
        assert "category_tag" in temp.columns

        output_dict = temp[["uid", "category_tag"]].set_index("uid")
        print("finalizing output_dict (should take less than 30 sec)")
        OUTPUT_DICT = output_dict.to_dict("index")

        print("begin merging (should take around 1 minute or less on a large dataset)")
        input_data["category_tag"] = input_data.apply(category_merge, axis=1)
        print("merge completed")

        final_data = input_data.drop(
            columns=["prompt", "uid", "required_tasks"], errors="ignore"
        )
        # Drop the trailing character of the output path (e.g. ".jsonl" -> ".json")
        final_data.to_json(
            config["output_file"][:-1], orient="records", indent=4, force_ascii=False
        )