import collections
import http.client
import json
import os
import random
import time
from argparse import ArgumentParser
from datetime import datetime as dt

import bittensor as bt
import lm_eval.tasks
from tqdm import tqdm
parser = ArgumentParser()
parser.add_argument("--validator", required=True, type=str, help="validator name", choices=["opentensor_foundation", "taostats"])
args = parser.parse_args()
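# Example invocation (the script name is illustrative):
#   BITAPAI_KEY=<key> python eval_validators.py --validator taostats
# Ensure the per-validator output directory used by the result dumps below
# exists up front (an added convenience; the original assumed it was present).
os.makedirs(f"_results/few-shot/{args.validator}", exist_ok=True)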
default_prompt = '''
You are Chattensor.
Chattensor is a research project by Opentensor Cortex.
Chattensor is designed to be able to assist with a wide range of tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. As a language model, Chattensor is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.
'''
if args.validator == "taostats":
    # taostats prompts are routed through the BitAPAI HTTP endpoint.
    try:
        bitapai_key = os.environ["BITAPAI_KEY"]
    except KeyError:
        raise RuntimeError("BITAPAI_KEY is not set but the chosen validator is taostats. Set your BitAPAI key with: export BITAPAI_KEY=<key>") from None
    conn = http.client.HTTPSConnection("dashboard.bitapai.io")
    headers = {
        "Content-Type": "application/json",
        "X-API-KEY": bitapai_key,
    }
def get_response(prompt):
    """Send a prompt to the chosen validator and return the response text."""
    if args.validator == "taostats":
        payload = json.dumps({
            "system": default_prompt,
            "user": prompt,
        })
        conn.request("POST", "/api/v1/prompt", payload, headers)
        res = conn.getresponse()
        data = res.read()
        print(data)
        time.sleep(1)  # crude rate limit between BitAPAI requests
        return data.decode("utf-8")
    else:
        return bt.prompt(prompt)
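# Helper factored out of the previously duplicated hellaswag / arc_challenge
# branches in the main loop below; it reproduces their original prompt format
# exactly (numbered choices, few-shot answers, response requested as a single number).
def build_numbered_prompt(fewshot_docs, query, choices_list, number_list):
    prompt = ""
    for doc in fewshot_docs:
        doc_query = doc["query"]
        doc_choices = "\n".join(str(n + 1) + ". " + choice for n, choice in enumerate(doc["choices"]))
        doc_numbers = ",".join(str(n) for n in range(1, len(doc["choices"]) + 1))
        doc_gold = doc["gold"] + 1
        prompt += f"""{doc_query}...\n{doc_choices}\nRespond with just one number only: {doc_numbers}.\n{doc_gold}\n\n"""
    prompt += f"""{query}...\n{choices_list}\nRespond with just one number only: {number_list}. """
    return prompt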
# Load all the LM Evaluation Harness tasks: HellaSwag, ARC-Challenge,
# TruthfulQA (MC), and the 57 MMLU (hendrycksTest) subjects.
mmlu_subjects = [
    "abstract_algebra", "anatomy", "astronomy", "business_ethics",
    "clinical_knowledge", "college_biology", "college_chemistry",
    "college_computer_science", "college_mathematics", "college_medicine",
    "college_physics", "computer_security", "conceptual_physics",
    "econometrics", "electrical_engineering", "elementary_mathematics",
    "formal_logic", "global_facts", "high_school_biology",
    "high_school_chemistry", "high_school_computer_science",
    "high_school_european_history", "high_school_geography",
    "high_school_government_and_politics", "high_school_macroeconomics",
    "high_school_mathematics", "high_school_microeconomics",
    "high_school_physics", "high_school_psychology", "high_school_statistics",
    "high_school_us_history", "high_school_world_history", "human_aging",
    "human_sexuality", "international_law", "jurisprudence",
    "logical_fallacies", "machine_learning", "management", "marketing",
    "medical_genetics", "miscellaneous", "moral_disputes", "moral_scenarios",
    "nutrition", "philosophy", "prehistory", "professional_accounting",
    "professional_law", "professional_medicine", "professional_psychology",
    "public_relations", "security_studies", "sociology", "us_foreign_policy",
    "virology", "world_religions",
]
tasks = ["hellaswag", "arc_challenge", "truthfulqa_mc"] + [f"hendrycksTest-{subject}" for subject in mmlu_subjects]
task_dict = lm_eval.tasks.get_task_dict(tasks)
task_dict_items = [
    (name, task)
    for name, task in task_dict.items()
    if (task.has_validation_docs() or task.has_test_docs())
]
versions = collections.defaultdict(dict)
# Record each task's version, then run its docs through the validator
for task_name, task in task_dict_items:
    versions[task_name] = task.VERSION
    # default to test doc, fall back to val doc if validation unavailable
    # TODO: the test-fallback-to-val system isn't final, we should revisit it at some point
    if task.has_test_docs():
        task_doc_func = task.test_docs
        task_set = "test"  # Required for caching in the decontamination
    elif task.has_validation_docs():
        task_set = "val"  # Required for caching in the decontamination
        task_doc_func = task.validation_docs
    else:
        raise RuntimeError("Task has neither test_docs nor validation_docs")
    # deterministically shuffle docs because sometimes they come in some kind of order
    task_docs = list(task_doc_func())
    rnd = random.Random()
    rnd.seed(42)
    rnd.shuffle(task_docs)
    i = 0
    for task_doc in tqdm(task_docs):
        print(task_name)
        print(task_doc)
if ("result" in task_doc) and ("inference_time" in task_doc) and ("prompt" in task_doc) and ("result" in task_doc) and (task_doc['result'] != ""): | |
continue | |
query = task_doc["query"] if "query" in task_doc else "" | |
choices_list = "\n".join([str(number+1) + ". " + choice for number, choice in enumerate(task_doc["choices"])]) if "choices" in task_doc else "" | |
number_list = ",".join([str(number) for number in range(1,len(task_doc["choices"])+1)]) if "choices" in task_doc else "" | |
if (task_name == "hellaswag") : | |
prompt = "" | |
prompt_list = list(task.training_docs())[:10] | |
for prompt_item in prompt_list: | |
prompt_item_query = prompt_item["query"] | |
prompt_item_choices_list = "\n".join([str(number+1) + ". " + choice for number, choice in enumerate(prompt_item["choices"])]) | |
prompt_item_number_list = ",".join([str(number) for number in range(1,len(prompt_item["choices"])+1)]) | |
prompt_item_gold = prompt_item["gold"]+1 | |
prompt += f"""{prompt_item_query}...\n{prompt_item_choices_list}\nRespond with just one number only: {prompt_item_number_list}.\n{prompt_item_gold}\n\n""" | |
prompt += f"""{query}...\n{choices_list}\nRespond with just one number only: {number_list}. """ | |
elif (task_name == "arc_challenge"): | |
prompt = "" | |
prompt_list = list(task.training_docs())[:25] | |
for prompt_item in prompt_list: | |
prompt_item_query = prompt_item["query"] | |
prompt_item_choices_list = "\n".join([str(number+1) + ". " + choice for number, choice in enumerate(prompt_item["choices"])]) | |
prompt_item_number_list = ",".join([str(number) for number in range(1,len(prompt_item["choices"])+1)]) | |
prompt_item_gold = prompt_item["gold"]+1 | |
prompt += f"""{prompt_item_query}...\n{prompt_item_choices_list}\nRespond with just one number only: {prompt_item_number_list}.\n{prompt_item_gold}\n\n""" | |
prompt += f"""{query}...\n{choices_list}\nRespond with just one number only: {number_list}. """ | |
elif (task_name == "truthfulqa_mc"): | |
continue | |
prompt = "" | |
elif ("hendrycksTest" in task_name): | |
prompt = "" | |
prompt_list = list(task.test_docs())[:5] | |
for prompt_item in prompt_list: | |
prompt_item_query = prompt_item["query"] | |
prompt += f"""{prompt_item_query.replace("Answer:", "Respond with just one letter only: A, B, C, D:")}\n{["A", "B", "C", "D"][prompt_item["gold"]]}\n\n""" | |
prompt += query.replace("Answer:", "Respond with just one letter only: A, B, C, D:") | |
        # Query the validator and record the result with timing metadata
        start = time.time()
        task_doc["result"] = get_response(prompt)
        end = time.time()
        task_doc["inference_time"] = end - start
        task_doc["prompt"] = prompt
        task_doc["datetime"] = dt.now().strftime("%Y-%m-%d %H:%M:%S")
        print(task_doc["result"])
        i = i + 1
        # Checkpoint partial results every 100 docs
        if i % 100 == 0:
            with open(f"_results/few-shot/{args.validator}/{task_name}_results.json", "w") as final:
                json.dump(task_docs, final)
with open(f"""_results/few-shot/{args.validator}/{task_name}_results.json""", "w") as final: | |
json.dump(task_docs, final) | |