File size: 6,630 Bytes
6360d19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import argparse


def get_parser(parser=None):
    if parser is None:
        parser = argparse.ArgumentParser()

    # Model
    #model_arg = parser.add_argument_group('Model')
    parser.add_argument('--n_layer',
                           type=int, default=12,
                           help='Mamba number of layers')
    parser.add_argument('--n_embd',
                           type=int, default=768,
                           help='Latent vector dimensionality')
    parser.add_argument('--dt_rank',
                           type=str, default='auto')
    parser.add_argument('--d_state',
                           type=int, default=16)
    parser.add_argument('--expand_factor',
                           type=int, default=2)
    parser.add_argument('--d_conv',
                           type=int, default=4)
    parser.add_argument('--dt_min',
                           type=float, default=0.001)
    parser.add_argument('--dt_max',
                           type=float, default=0.1)
    parser.add_argument('--dt_init',
                           type=str, default='random')
    parser.add_argument('--dt_scale',
                           type=float, default=1.0)
    parser.add_argument('--dt_init_floor',
                           type=float, default=1e-4)
    parser.add_argument('--bias',
                           type=int, default=0)
    parser.add_argument('--conv_bias',
                           type=int, default=1)


    # Train
    #train_arg = parser.add_argument_group('Train')
    parser.add_argument('--n_batch',
                           type=int, default=512,
                           help='Batch size')
    parser.add_argument('--checkpoint_every',
                           type=int, default=1000,
                           help='save checkpoint every x iterations')
    parser.add_argument('--clip_grad',
                           type=int, default=50,
                           help='Clip gradients to this value')
    parser.add_argument('--lr_start',
                           type=float, default=3 * 1e-4,
                           help='Initial lr value')
    parser.add_argument('--lr_end',
                           type=float, default=3 * 1e-4,
                           help='Maximum lr weight value')
    parser.add_argument('--lr_multiplier',
                           type=int, default=1,
                           help='lr weight multiplier')
    parser.add_argument('--device',
                        type=str, default='cuda',
                        help='Device to run: "cpu" or "cuda:<device number>"')
    parser.add_argument('--seed',
                        type=int, default=12345,
                        help='Seed')
    parser.add_argument('--lr_decoder',
                        type=float, default=1e-4,
                        help='Learning rate for decoder part')
    parser.add_argument('--dist-backend', default='nccl', type=str, help='distributed backend')
    parser.add_argument('--save_checkpoint_path', default='/data', help='checkpoint saving path')
    parser.add_argument('--load_checkpoint_path', default='', help='checkpoint loading path')

    #common_arg = parser.add_argument_group('Common')
    parser.add_argument('--vocab_load',
                            type=str, required=False,
                            help='Where to load the vocab')
    parser.add_argument('--n_samples',
                            type=int, required=False,
                            help='Number of samples to sample')
    parser.add_argument('--gen_save',
                            type=str, required=False,
                            help='Where to save the gen molecules')
    parser.add_argument("--max_len",
                            type=int, default=100,
                            help="Max of length of SMILES")
    parser.add_argument('--train_load',
                            type=str, required=False,
                            help='Where to load the model')
    parser.add_argument('--val_load',
                            type=str, required=False,
                            help='Where to load the model')
    parser.add_argument('--n_workers',
                            type=int, required=False, default=1,
                            help='Where to load the model')
    parser.add_argument('--max_epochs',
                            type=int, required=False, default=1,
                            help='max number of epochs')

    # debug() FINE TUNEING
    # parser.add_argument('--save_dir', type=str, required=True)
    parser.add_argument('--mode',
                           type=str, default='cls',
                           help='type of pooling to use')
    parser.add_argument("--dataset_length", type=int, default=None, required=False)
    parser.add_argument("--num_workers", type=int, default=0, required=False)
    parser.add_argument("--dropout", type=float, default=0.1, required=False)
    #parser.add_argument("--dims", type=int, nargs="*", default="", required=False)
    parser.add_argument(
        "--smiles_embedding",
        type=str,
        default="/dccstor/medscan7/smallmolecule/runs/ba-predictor/small-data/embeddings/protein/ba_embeddings_tanh_512_2986138_2.pt",
    )
    # parser.add_argument("--train_pct", type=str, required=False, default="95")
    #parser.add_argument("--aug", type=int, required=True)
    parser.add_argument("--dataset_name", type=str, required=False, default="sol")
    parser.add_argument("--measure_name", type=str, required=False, default="measure")
    #parser.add_argument("--emb_type", type=str, required=True)
    #parser.add_argument("--checkpoints_folder", type=str, required=True)
    #parser.add_argument("--results_dir", type=str, required=True)
    #parser.add_argument("--patience_epochs", type=int, required=True)

    parser.add_argument(
        "--data_root",
        type=str,
        required=False,
        default="/dccstor/medscan7/smallmolecule/runs/ba-predictor/small-data/affinity",
    )
    # parser.add_argument("--use_bn", type=int, default=0)
    parser.add_argument("--use_linear", type=int, default=0)

    parser.add_argument("--lr", type=float, default=0.001)
    # parser.add_argument("--weight_decay", type=float, default=5e-4)
    # parser.add_argument("--val_check_interval", type=float, default=1.0)
    parser.add_argument("--batch_size", type=int, default=64)

    return parser
def parse_args():
    parser = get_parser()
    args = parser.parse_args()
    return args