Spaces:
Running
Running
"""
Main command to run diverse genomic embedding benchmarks (DGEB) on a model.

Example command to run DGEB:
    python run_dgeb.py -m facebook/esm2_t6_8M_UR50D
"""
import argparse | |
import logging | |
import os | |
import dgeb | |
# Configure the root logger at INFO and create this module's own logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Task and model names as reported by dgeb; they are interpolated into the
# command line help text.
ALL_TASK_NAMES = dgeb.get_all_task_names()
ALL_MODEL_NAMES = dgeb.get_all_model_names()
# Maps the -v/--verbosity value to the level applied to the library loggers.
# Values outside this table leave the logger levels untouched.
_VERBOSITY_TO_LEVEL = {
    0: logging.CRITICAL,
    1: logging.WARNING,
    2: logging.INFO,
    3: logging.DEBUG,
}

# NOTE(review): the original code only configured a logger named "geb", while
# the package is imported as `dgeb`. Presumably the library logs under "dgeb" -
# confirm. Both names are configured so the flag takes effect either way.
_LIBRARY_LOGGER_NAMES = ("geb", "dgeb")


def _split_csv(value):
    """Split a comma separated string into stripped, non-empty items."""
    return [item.strip() for item in value.split(",") if item.strip()]


def _build_parser():
    """Create the argument parser for the DGEB command line."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--model",
        type=str,
        default=None,
        help=f"Model to evaluate. Choose from {ALL_MODEL_NAMES}",
    )
    parser.add_argument(
        "-t",
        "--tasks",
        type=_split_csv,
        default=None,
        help=f"Comma separated tasks to evaluate on. Choose from {ALL_TASK_NAMES} or do not specify to evaluate on all tasks",
    )
    parser.add_argument(
        "-l",
        "--layers",
        type=str,
        default=None,
        help="Layer to evaluate. Comma separated list of integers or 'mid' and 'last'. Default is 'mid,last'",
    )
    parser.add_argument(
        "--devices",
        type=str,
        default="0",
        help="Comma separated list of GPU device ids to use. Default is 0 (if GPUs are detected).",
    )
    parser.add_argument(
        "--output_folder",
        type=str,
        default=None,
        help="Output directory for results. Will default to results/model_name if not set.",
    )
    parser.add_argument(
        "-v", "--verbosity", type=int, default=2, help="Verbosity level"
    )
    parser.add_argument(
        "-b", "--batch_size", type=int, default=64, help="Batch size for evaluation"
    )
    parser.add_argument(
        "--max_seq_len",
        type=int,
        default=1024,
        help="Maximum sequence length for model, default is 1024.",
    )
    parser.add_argument(
        "--pool_type",
        type=str,
        default="mean",
        help="Pooling type for model, choose from mean, max, cls, last. Default is mean.",
    )
    return parser


def _apply_verbosity(verbosity):
    """Set the library logger levels according to the verbosity flag."""
    level = _VERBOSITY_TO_LEVEL.get(verbosity)
    if level is None:
        # Unknown verbosity values are ignored, as in the original if/elif chain.
        return
    for name in _LIBRARY_LOGGER_NAMES:
        logging.getLogger(name).setLevel(level)


def _parse_devices(raw):
    """Parse the --devices string into a list of integer device ids.

    Raises:
        ValueError: if any comma separated item is not an integer.
    """
    try:
        return [int(device) for device in raw.split(",")]
    except ValueError as err:
        raise ValueError("Devices must be comma separated list of integers") from err


def _parse_layers(raw):
    """Parse the --layers string into the value handed to ``dgeb.get_model``.

    Returns the input unchanged when it is unset or empty, the string
    ``"mid"`` or ``"last"`` for those keywords, ``None`` for ``"mid,last"``,
    and otherwise a list of integers.

    Raises:
        ValueError: if the value is neither a keyword nor a list of integers.
    """
    if not raw:
        # Unset (None) or empty: pass through unchanged.
        return raw
    tokens = [token.strip() for token in raw.split(",")]
    if tokens in (["mid"], ["last"]):
        return tokens[0]
    if tokens == ["mid", "last"]:
        # The --layers help text documents 'mid,last' as the default, so it is
        # forwarded as None (the unset value) instead of failing int parsing.
        # NOTE(review): confirm dgeb.get_model treats layers=None as mid+last.
        return None
    try:
        return [int(token) for token in tokens]
    except ValueError as err:
        raise ValueError("Layers must be a list of integers.") from err


def main():
    """Parse the command line, load the requested model and run DGEB on it.

    Steps: parse and validate the arguments, create the output folder, load
    the model through ``dgeb.get_model``, select the tasks (the requested ones,
    or every task returned for the model's modality) and run the evaluation.

    Raises:
        ValueError: if no model is given, if devices or layers cannot be
            parsed, or if a requested task's modality differs from the
            model's modality.
    """
    args = _build_parser().parse_args()

    _apply_verbosity(args.verbosity)

    if args.model is None:
        raise ValueError("Please specify a model using the -m or --model argument")

    devices = _parse_devices(args.devices)
    layers = _parse_layers(args.layers)

    output_folder = args.output_folder
    if output_folder is None:
        # Default to results/<last path component of the model name>.
        model_name = args.model.split("/")[-1]
        output_folder = os.path.join("results", model_name)
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_folder, exist_ok=True)
    logger.info("Results will be saved to %s", output_folder)
    # NOTE(review): output_folder is created and logged but is not passed to
    # evaluation.run() below - verify that results actually end up there.

    # Load the model by name.
    model = dgeb.get_model(
        model_name=args.model,
        layers=layers,
        devices=devices,
        max_seq_length=args.max_seq_len,
        batch_size=args.batch_size,
        pool_type=args.pool_type,
    )

    all_tasks_for_modality = dgeb.get_tasks_by_modality(model.modality)
    if args.tasks:
        task_list = dgeb.get_tasks_by_name(args.tasks)
        # Every requested task must share the model's modality.
        if not all(task.metadata.modality == model.modality for task in task_list):
            raise ValueError(f"Tasks must be one of {all_tasks_for_modality}")
    else:
        task_list = all_tasks_for_modality

    evaluation = dgeb.DGEB(tasks=task_list)
    evaluation.run(model)
# Run the command line entry point only when executed as a script.
if __name__ == "__main__":
    main()