File size: 6,630 Bytes
6360d19 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import argparse
def get_parser(parser=None):
if parser is None:
parser = argparse.ArgumentParser()
# Model
#model_arg = parser.add_argument_group('Model')
parser.add_argument('--n_layer',
type=int, default=12,
help='Mamba number of layers')
parser.add_argument('--n_embd',
type=int, default=768,
help='Latent vector dimensionality')
parser.add_argument('--dt_rank',
type=str, default='auto')
parser.add_argument('--d_state',
type=int, default=16)
parser.add_argument('--expand_factor',
type=int, default=2)
parser.add_argument('--d_conv',
type=int, default=4)
parser.add_argument('--dt_min',
type=float, default=0.001)
parser.add_argument('--dt_max',
type=float, default=0.1)
parser.add_argument('--dt_init',
type=str, default='random')
parser.add_argument('--dt_scale',
type=float, default=1.0)
parser.add_argument('--dt_init_floor',
type=float, default=1e-4)
parser.add_argument('--bias',
type=int, default=0)
parser.add_argument('--conv_bias',
type=int, default=1)
# Train
#train_arg = parser.add_argument_group('Train')
parser.add_argument('--n_batch',
type=int, default=512,
help='Batch size')
parser.add_argument('--checkpoint_every',
type=int, default=1000,
help='save checkpoint every x iterations')
parser.add_argument('--clip_grad',
type=int, default=50,
help='Clip gradients to this value')
parser.add_argument('--lr_start',
type=float, default=3 * 1e-4,
help='Initial lr value')
parser.add_argument('--lr_end',
type=float, default=3 * 1e-4,
help='Maximum lr weight value')
parser.add_argument('--lr_multiplier',
type=int, default=1,
help='lr weight multiplier')
parser.add_argument('--device',
type=str, default='cuda',
help='Device to run: "cpu" or "cuda:<device number>"')
parser.add_argument('--seed',
type=int, default=12345,
help='Seed')
parser.add_argument('--lr_decoder',
type=float, default=1e-4,
help='Learning rate for decoder part')
parser.add_argument('--dist-backend', default='nccl', type=str, help='distributed backend')
parser.add_argument('--save_checkpoint_path', default='/data', help='checkpoint saving path')
parser.add_argument('--load_checkpoint_path', default='', help='checkpoint loading path')
#common_arg = parser.add_argument_group('Common')
parser.add_argument('--vocab_load',
type=str, required=False,
help='Where to load the vocab')
parser.add_argument('--n_samples',
type=int, required=False,
help='Number of samples to sample')
parser.add_argument('--gen_save',
type=str, required=False,
help='Where to save the gen molecules')
parser.add_argument("--max_len",
type=int, default=100,
help="Max of length of SMILES")
parser.add_argument('--train_load',
type=str, required=False,
help='Where to load the model')
parser.add_argument('--val_load',
type=str, required=False,
help='Where to load the model')
parser.add_argument('--n_workers',
type=int, required=False, default=1,
help='Where to load the model')
parser.add_argument('--max_epochs',
type=int, required=False, default=1,
help='max number of epochs')
# debug() FINE TUNEING
# parser.add_argument('--save_dir', type=str, required=True)
parser.add_argument('--mode',
type=str, default='cls',
help='type of pooling to use')
parser.add_argument("--dataset_length", type=int, default=None, required=False)
parser.add_argument("--num_workers", type=int, default=0, required=False)
parser.add_argument("--dropout", type=float, default=0.1, required=False)
#parser.add_argument("--dims", type=int, nargs="*", default="", required=False)
parser.add_argument(
"--smiles_embedding",
type=str,
default="/dccstor/medscan7/smallmolecule/runs/ba-predictor/small-data/embeddings/protein/ba_embeddings_tanh_512_2986138_2.pt",
)
# parser.add_argument("--train_pct", type=str, required=False, default="95")
#parser.add_argument("--aug", type=int, required=True)
parser.add_argument("--dataset_name", type=str, required=False, default="sol")
parser.add_argument("--measure_name", type=str, required=False, default="measure")
#parser.add_argument("--emb_type", type=str, required=True)
#parser.add_argument("--checkpoints_folder", type=str, required=True)
#parser.add_argument("--results_dir", type=str, required=True)
#parser.add_argument("--patience_epochs", type=int, required=True)
parser.add_argument(
"--data_root",
type=str,
required=False,
default="/dccstor/medscan7/smallmolecule/runs/ba-predictor/small-data/affinity",
)
# parser.add_argument("--use_bn", type=int, default=0)
parser.add_argument("--use_linear", type=int, default=0)
parser.add_argument("--lr", type=float, default=0.001)
# parser.add_argument("--weight_decay", type=float, default=5e-4)
# parser.add_argument("--val_check_interval", type=float, default=1.0)
parser.add_argument("--batch_size", type=int, default=64)
return parser
def parse_args():
parser = get_parser()
args = parser.parse_args()
return args
|