fffiloni commited on
Commit
9de4480
·
verified ·
1 Parent(s): 2ae6665

split model setup and task execution

Browse files
Files changed (1) hide show
  1. main.py +14 -4
main.py CHANGED
@@ -14,7 +14,8 @@ from rewards import get_reward_losses
14
  from training import LatentNoiseTrainer, get_optimizer
15
 
16
 
17
- def main(args, progress_callback=None):
 
18
  seed_everything(args.seed)
19
  bf.makedirs(f"{args.save_dir}/logs/{args.task}")
20
  # Set up logging and name settings
@@ -92,6 +93,10 @@ def main(args, progress_callback=None):
92
  )
93
  enable_grad = not args.no_optim
94
 
 
 
 
 
95
  if args.task == "single":
96
  init_latents = torch.randn(shape, device=device, dtype=dtype)
97
  latents = torch.nn.Parameter(init_latents, requires_grad=enable_grad)
@@ -269,7 +274,12 @@ def main(args, progress_callback=None):
269
  # log total rewards
270
  logging.info(f"Mean initial rewards: {total_init_rewards}")
271
  logging.info(f"Mean best rewards: {total_best_rewards}")
272
-
273
- if __name__ == "__main__":
274
  args = parse_args()
275
- main(args)
 
 
 
 
 
 
14
  from training import LatentNoiseTrainer, get_optimizer
15
 
16
 
17
+ def setup(args):
18
+ #args = parse_args()
19
  seed_everything(args.seed)
20
  bf.makedirs(f"{args.save_dir}/logs/{args.task}")
21
  # Set up logging and name settings
 
93
  )
94
  enable_grad = not args.no_optim
95
 
96
+ return args, trainer, device, dtype, shape, enable_grad, settings
97
+
98
+ def execute_task(args, trainer, device, dtype, shape, enable_grad, settings, progress_callback=None):
99
+ #args = parse_args()
100
  if args.task == "single":
101
  init_latents = torch.randn(shape, device=device, dtype=dtype)
102
  latents = torch.nn.Parameter(init_latents, requires_grad=enable_grad)
 
274
  # log total rewards
275
  logging.info(f"Mean initial rewards: {total_init_rewards}")
276
  logging.info(f"Mean best rewards: {total_best_rewards}")
277
+
278
+ def main():
279
  args = parse_args()
280
+ args, trainer, device, dtype, shape, enable_grad, settings = setup(args)
281
+ execute_task(args, trainer, device, dtype, shape, enable_grad, settings)
282
+
283
+
284
+ if __name__ == "__main__":
285
+ main()