""" Main command to run diverse genomic embedding benchmarks (DGEB) on a model. example command to run DGEB: python run_dgeb.py -m facebook/esm2_t6_8M_UR50D """ import argparse import logging import os import dgeb logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) ALL_TASK_NAMES = dgeb.get_all_task_names() ALL_MODEL_NAMES = dgeb.get_all_model_names() def main(): parser = argparse.ArgumentParser() parser.add_argument( "-m", "--model", type=str, default=None, help=f"Model to evaluate. Choose from {ALL_MODEL_NAMES}", ) parser.add_argument( "-t", "--tasks", type=lambda s: [item for item in s.split(",")], default=None, help=f"Comma separated tasks to evaluate on. Choose from {ALL_TASK_NAMES} or do not specify to evaluate on all tasks", ) parser.add_argument( "-l", "--layers", type=str, default=None, help="Layer to evaluate. Comma separated list of integers or 'mid' and 'last'. Default is 'mid,last'", ) parser.add_argument( "--devices", type=str, default="0", help="Comma separated list of GPU device ids to use. Default is 0 (if GPUs are detected).", ) parser.add_argument( "--output_folder", type=str, default=None, help="Output directory for results. Will default to results/model_name if not set.", ) parser.add_argument( "-v", "--verbosity", type=int, default=2, help="Verbosity level" ) parser.add_argument( "-b", "--batch_size", type=int, default=64, help="Batch size for evaluation" ) parser.add_argument( "--max_seq_len", type=int, default=1024, help="Maximum sequence length for model, default is 1024.", ) parser.add_argument( "--pool_type", type=str, default="mean", help="Pooling type for model, choose from mean, max, cls, last. Default is mean.", ) args = parser.parse_args() # set logging based on verbosity level if args.verbosity == 0: logging.getLogger("geb").setLevel(logging.CRITICAL) elif args.verbosity == 1: logging.getLogger("geb").setLevel(logging.WARNING) elif args.verbosity == 2: logging.getLogger("geb").setLevel(logging.INFO) elif args.verbosity == 3: logging.getLogger("geb").setLevel(logging.DEBUG) if args.model is None: raise ValueError("Please specify a model using the -m or --model argument") # make sure that devices are comma separated list of integers try: devices = [int(device) for device in args.devices.split(",")] except ValueError: raise ValueError("Devices must be comma separated list of integers") layers = args.layers if layers: if layers not in ["mid", "last"]: # Layers should be list of integers. try: layers = [int(layer) for layer in layers.split(",")] except ValueError: raise ValueError("Layers must be a list of integers.") model_name = args.model.split("/")[-1] output_folder = args.output_folder if output_folder is None: output_folder = os.path.join("results", model_name) # create output folder if it does not exist if not os.path.exists(output_folder): os.makedirs(output_folder) logger.info(f"Results will be saved to {output_folder}") # Load the model by name. model = dgeb.get_model( model_name=args.model, layers=layers, devices=devices, max_seq_length=args.max_seq_len, batch_size=args.batch_size, pool_type=args.pool_type, ) all_tasks_for_modality = dgeb.get_tasks_by_modality(model.modality) if args.tasks: task_list = dgeb.get_tasks_by_name(args.tasks) if not all([task.metadata.modality == model.modality for task in task_list]): raise ValueError(f"Tasks must be one of {all_tasks_for_modality}") else: task_list = all_tasks_for_modality evaluation = dgeb.DGEB(tasks=task_list) _ = evaluation.run(model) if __name__ == "__main__": main()