"""
Run the Diverse Genomic Embedding Benchmark (DGEB) on a model.

Example command:
    python run_dgeb.py -m facebook/esm2_t6_8M_UR50D
"""

import argparse
import logging
import os

import dgeb

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

ALL_TASK_NAMES = dgeb.get_all_task_names()
ALL_MODEL_NAMES = dgeb.get_all_model_names()


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--model",
        type=str,
        default=None,
        help=f"Model to evaluate. Choose from {ALL_MODEL_NAMES}",
    )
    parser.add_argument(
        "-t",
        "--tasks",
        type=lambda s: s.split(","),
        default=None,
        help=f"Comma-separated list of tasks to evaluate on. Choose from {ALL_TASK_NAMES}. If not set, all tasks matching the model's modality are evaluated.",
    )
    parser.add_argument(
        "-l",
        "--layers",
        type=str,
        default=None,
        help="Layers to evaluate: 'mid', 'last', or a comma-separated list of integer layer indices. If not set, both the middle and last layers are evaluated.",
    )
    parser.add_argument(
        "--devices",
        type=str,
        default="0",
        help="Comma-separated list of GPU device IDs to use. Default is 0 (used only if GPUs are detected).",
    )
    parser.add_argument(
        "--output_folder",
        type=str,
        default=None,
        help="Output directory for results. Will default to results/model_name if not set.",
    )
    parser.add_argument(
        "-v",
        "--verbosity",
        type=int,
        default=2,
        help="Verbosity level: 0 (critical), 1 (warning), 2 (info), 3 (debug). Default is 2.",
    )
    parser.add_argument(
        "-b", "--batch_size", type=int, default=64, help="Batch size for evaluation"
    )
    parser.add_argument(
        "--max_seq_len",
        type=int,
        default=1024,
        help="Maximum sequence length for model, default is 1024.",
    )
    parser.add_argument(
        "--pool_type",
        type=str,
        default="mean",
        help="Pooling type for model, choose from mean, max, cls, last. Default is mean.",
    )

    args = parser.parse_args()

    # Set the log level of the dgeb package based on the verbosity flag.
    if args.verbosity == 0:
        logging.getLogger("dgeb").setLevel(logging.CRITICAL)
    elif args.verbosity == 1:
        logging.getLogger("dgeb").setLevel(logging.WARNING)
    elif args.verbosity == 2:
        logging.getLogger("dgeb").setLevel(logging.INFO)
    elif args.verbosity == 3:
        logging.getLogger("dgeb").setLevel(logging.DEBUG)

    if args.model is None:
        raise ValueError("Please specify a model using the -m or --model argument")

    # Devices must be a comma-separated list of integers.
    try:
        devices = [int(device) for device in args.devices.split(",")]
    except ValueError as err:
        raise ValueError("Devices must be a comma-separated list of integers.") from err

    layers = args.layers
    if layers and layers not in ["mid", "last"]:
        # Anything other than 'mid' or 'last' must be a list of integers.
        try:
            layers = [int(layer) for layer in layers.split(",")]
        except ValueError as err:
            raise ValueError(
                "Layers must be 'mid', 'last', or a comma-separated list of integers."
            ) from err

    model_name = args.model.split("/")[-1]
    output_folder = args.output_folder
    if output_folder is None:
        output_folder = os.path.join("results", model_name)
        # Create the output folder if it does not exist.
        os.makedirs(output_folder, exist_ok=True)
        logger.info(f"Results will be saved to {output_folder}")

    # Load the model by name.
    model = dgeb.get_model(
        model_name=args.model,
        layers=layers,
        devices=devices,
        max_seq_length=args.max_seq_len,
        batch_size=args.batch_size,
        pool_type=args.pool_type,
    )

    all_tasks_for_modality = dgeb.get_tasks_by_modality(model.modality)

    if args.tasks:
        task_list = dgeb.get_tasks_by_name(args.tasks)
        if not all(task.metadata.modality == model.modality for task in task_list):
            raise ValueError(
                f"All tasks must match the model's modality ({model.modality}). "
                f"Choose from {all_tasks_for_modality}"
            )
    else:
        task_list = all_tasks_for_modality
    evaluation = dgeb.DGEB(tasks=task_list)
    _ = evaluation.run(model)

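
# Illustrative sketch (not called by main()): the --layers handling above,
# factored into a pure function so it can be unit-tested in isolation.
# `parse_layers` is a hypothetical helper name, not part of the dgeb API.
# It returns None (fall back to the default layers), the string "mid" or
# "last", or a list of integer layer indices.
def parse_layers(value):
    if not value:
        return None
    if value in ("mid", "last"):
        return value
    try:
        return [int(layer) for layer in value.split(",")]
    except ValueError as err:
        raise ValueError(
            "Layers must be 'mid', 'last', or a comma-separated list of integers."
        ) from err
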

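# Illustrative sketch (not called by main()): the verbosity-to-log-level
# ladder in main() written as a lookup table. `VERBOSITY_LEVELS` and
# `verbosity_to_level` are hypothetical names. Unknown values fall back to
# INFO here, which is roughly what main() does: an unset dgeb logger inherits
# the INFO level configured by logging.basicConfig at the top of this file.
VERBOSITY_LEVELS = {
    0: logging.CRITICAL,
    1: logging.WARNING,
    2: logging.INFO,
    3: logging.DEBUG,
}


def verbosity_to_level(verbosity):
    return VERBOSITY_LEVELS.get(verbosity, logging.INFO)

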
if __name__ == "__main__":
    main()