# DGEB: scripts/eval_all_models.py
# Last commit: e284167 ("Initial commit") by Joshua Kravitz.
"""Script to replicate results from the DGEB paper."""
import torch
import dgeb
from functools import partial
# Every CUDA device visible to this process, by index; models use all of them
# unless a call overrides ``devices``.
# NOTE(review): this is an empty list on a host with no CUDA devices — confirm
# dgeb tolerates that, or run this script on a GPU machine as intended.
_NUM_VISIBLE_GPUS = torch.cuda.device_count()
ALL_DEVICES = list(range(_NUM_VISIBLE_GPUS))

DEFAULT_BATCH_SIZE = 64
DEFAULT_SEQ_LEN = 1024

# Keyword defaults applied to every model built in this script. Keywords given
# at the call site take precedence over those stored in a ``functools.partial``,
# so individual models may override any of these (e.g. a smaller batch_size).
_GET_MODEL_DEFAULTS = {
    "devices": ALL_DEVICES,
    "batch_size": DEFAULT_BATCH_SIZE,
    "max_seq_length": DEFAULT_SEQ_LEN,
}
get_model = partial(dgeb.get_model, **_GET_MODEL_DEFAULTS)
def main():
    """Run the DGEB benchmark for every protein and DNA model listed below.

    Protein models are evaluated on all tasks returned by
    ``dgeb.get_tasks_by_modality(dgeb.Modality.PROTEIN)`` and DNA models on all
    ``dgeb.Modality.DNA`` tasks. Each model is built with the module-level
    ``get_model`` partial (default devices / batch size / max sequence length);
    the per-model dicts below override those defaults, mostly shrinking the
    batch size for larger checkpoints (presumably to fit GPU memory).
    Models are run sequentially in the order listed.
    """
    ######################### Protein Models #########################
    protein_tasks = dgeb.get_tasks_by_modality(dgeb.Modality.PROTEIN)
    protein_evaluation = dgeb.DGEB(tasks=protein_tasks)
    # (model name, keyword overrides for get_model), evaluated in this order.
    protein_models = [
        # ESM models.
        ("facebook/esm2_t6_8M_UR50D", {}),
        ("facebook/esm2_t12_35M_UR50D", {}),
        ("facebook/esm2_t30_150M_UR50D", {}),
        ("facebook/esm2_t33_650M_UR50D", {"batch_size": 32}),
        ("facebook/esm2_t36_3B_UR50D", {"batch_size": 1}),
        # ESM3 models (pinned to GPU 0).
        ("esm3_sm_open_v1", {"batch_size": 1, "devices": [0]}),
        # ProtT5 models.
        ("Rostlab/prot_t5_xl_uniref50", {"batch_size": 32}),
        ("Rostlab/prot_t5_xl_bfd", {"batch_size": 32}),
        # ProGen2 models.
        ("hugohrban/progen2-small", {}),
        ("hugohrban/progen2-medium", {"batch_size": 32}),
        ("hugohrban/progen2-large", {"batch_size": 1}),
        ("hugohrban/progen2-xlarge", {"batch_size": 1}),
    ]
    for model_name, overrides in protein_models:
        protein_evaluation.run(get_model(model_name, **overrides))

    ######################### DNA Models #########################
    dna_tasks = dgeb.get_tasks_by_modality(dgeb.Modality.DNA)
    dna_evaluation = dgeb.DGEB(tasks=dna_tasks)
    # 131k will OOM so we use half this length.
    evo_131k_max_seq_len = 131072 // 2
    dna_models = [
        # Evo models. The sequence-length override must use ``max_seq_length``,
        # the same keyword the module-level ``get_model`` partial sets, so that
        # it actually replaces the partial's DEFAULT_SEQ_LEN (1024) default.
        (
            "togethercomputer/evo-1-8k-base",
            {"batch_size": 1, "max_seq_length": 8192, "devices": [0]},
        ),
        (
            "togethercomputer/evo-1-131k-base",
            {
                "batch_size": 1,
                "max_seq_length": evo_131k_max_seq_len,
                "devices": [0],
            },
        ),
        # Nucleotide Transformer models.
        ("InstaDeepAI/nucleotide-transformer-v2-50m-multi-species", {}),
        ("InstaDeepAI/nucleotide-transformer-v2-100m-multi-species", {}),
        ("InstaDeepAI/nucleotide-transformer-v2-250m-multi-species", {}),
        ("InstaDeepAI/nucleotide-transformer-v2-500m-multi-species", {}),
        ("InstaDeepAI/nucleotide-transformer-2.5b-multi-species", {"batch_size": 1}),
    ]
    for model_name, overrides in dna_models:
        dna_evaluation.run(get_model(model_name, **overrides))
if __name__ == "__main__":
    # Kick off the full evaluation only when run as a script, not when imported.
    main()