Syed Abdul Gaffar Shakhadri commited on
Commit
dbbc0af
·
unverified ·
1 Parent(s): 6059ecb

added get_inference_config() for inferencing

Browse files
Files changed (1) hide show
  1. config.py +26 -5
config.py CHANGED
@@ -19,7 +19,7 @@ _C.BASE = ['']
19
  # -----------------------------------------------------------------------------
20
  _C.DATA = CN()
21
  # Batch size for a single GPU, could be overwritten by command line argument
22
- _C.DATA.BATCH_SIZE = 128
23
  # Path to dataset, could be overwritten by command line argument
24
  _C.DATA.DATA_PATH = ''
25
  # Dataset name
@@ -37,7 +37,7 @@ _C.DATA.CACHE_MODE = 'part'
37
  # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
38
  _C.DATA.PIN_MEMORY = True
39
  # Number of data loading threads
40
- _C.DATA.NUM_WORKERS = 8
41
  # hdfs data dir
42
  _C.DATA.TRAIN_PATH = None
43
  _C.DATA.VAL_PATH = None
@@ -63,7 +63,7 @@ _C.MODEL.NAME = ''
63
  # Checkpoint to resume, could be overwritten by command line argument
64
  _C.MODEL.RESUME = ''
65
  # Number of classes, overwritten in data preparation
66
- _C.MODEL.NUM_CLASSES = 1000
67
  # Dropout rate
68
  _C.MODEL.DROP_RATE = 0.0
69
  # Drop path rate
@@ -89,9 +89,9 @@ _C.TRAIN.START_EPOCH = 0
89
  _C.TRAIN.EPOCHS = 300
90
  _C.TRAIN.WARMUP_EPOCHS = 20
91
  _C.TRAIN.WEIGHT_DECAY = 0.05
92
- _C.TRAIN.BASE_LR = 5e-4
93
  _C.TRAIN.WARMUP_LR = 5e-7
94
- _C.TRAIN.MIN_LR = 5e-6
95
  # Clip gradient norm
96
  _C.TRAIN.CLIP_GRAD = 5.0
97
  # Auto resume from latest checkpoint
@@ -271,3 +271,24 @@ def get_config(args):
271
  update_config(config, args)
272
 
273
  return config
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  # -----------------------------------------------------------------------------
20
  _C.DATA = CN()
21
  # Batch size for a single GPU, could be overwritten by command line argument
22
+ _C.DATA.BATCH_SIZE = 32
23
  # Path to dataset, could be overwritten by command line argument
24
  _C.DATA.DATA_PATH = ''
25
  # Dataset name
 
37
  # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
38
  _C.DATA.PIN_MEMORY = True
39
  # Number of data loading threads
40
+ _C.DATA.NUM_WORKERS = 4
41
  # hdfs data dir
42
  _C.DATA.TRAIN_PATH = None
43
  _C.DATA.VAL_PATH = None
 
63
  # Checkpoint to resume, could be overwritten by command line argument
64
  _C.MODEL.RESUME = ''
65
  # Number of classes, overwritten in data preparation
66
+ _C.MODEL.NUM_CLASSES = 200#1000
67
  # Dropout rate
68
  _C.MODEL.DROP_RATE = 0.0
69
  # Drop path rate
 
89
  _C.TRAIN.EPOCHS = 300
90
  _C.TRAIN.WARMUP_EPOCHS = 20
91
  _C.TRAIN.WEIGHT_DECAY = 0.05
92
+ _C.TRAIN.BASE_LR = 1e-4 # 5e-4
93
  _C.TRAIN.WARMUP_LR = 5e-7
94
+ _C.TRAIN.MIN_LR = 1e-5 # 5e-6
95
  # Clip gradient norm
96
  _C.TRAIN.CLIP_GRAD = 5.0
97
  # Auto resume from latest checkpoint
 
271
  update_config(config, args)
272
 
273
  return config
274
+
275
+
276
+ ################### For Inferencing ####################
277
+ def update_inference_config(config, args):
278
+ _update_config_from_file(config, args.cfg)
279
+
280
+ config.defrost()
281
+
282
+ config.freeze()
283
+
284
+
285
+ def get_inference_config(cfg_path):
286
+ """Get a yacs CfgNode object with default values."""
287
+ # Return a clone so that the defaults will not be altered
288
+ # This is for the "local variable" use pattern
289
+ config = _C.clone()
290
+ update_inference_config(config, cfg_path)
291
+
292
+ return config
293
+
294
+ ################### For Inferencing ####################