import jax
_ = jax.device_count()  # workaround: touch the devices early to keep TPU communication from locking up or racing later

from typing import Tuple, Optional
import os
from argparse import ArgumentParser

from flax_trainer import FlaxTrainerUNetPseudo3D
from dataset import load_dataset

def train(
        dataset_path: str,
        model_path: str,
        output_dir: str,
        dataset_cache_dir: Optional[str] = None,
        from_pt: bool = True,
        convert2d: bool = False,
        only_temporal: bool = True,
        sample_size: Tuple[int, int] = (64, 64),
        lr: float = 5e-5,
        batch_size: int = 1,
        num_frames: int = 24,
        epochs: int = 10,
        warmup: float = 0.1,
        decay: float = 0.0,
        weight_decay: float = 1e-2,
        log_every_step: int = 50,
        save_every_epoch: int = 1,
        sample_every_epoch: int = 1,
        seed: int = 0,
        dtype: str = 'bfloat16',
        param_dtype: str = 'float32',
        use_memory_efficient_attention: bool = True,
        verbose: bool = True,
        use_wandb: bool = False
) -> None:
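    """Fine-tune a pseudo-3D UNet with Flax.

    Builds the trainer and the dataloader, then runs the training loop,
    periodically logging, sampling, and saving checkpoints under output_dir.
    """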
    def log(msg: str) -> None:
        if verbose:
            print(msg)
    log('\n----------------')
    log('Init trainer')
    trainer = FlaxTrainerUNetPseudo3D(
            model_path = model_path,
            from_pt = from_pt,
            convert2d = convert2d,
            sample_size = sample_size,
            seed = seed,
            dtype = dtype,
            param_dtype = param_dtype,
            use_memory_efficient_attention = use_memory_efficient_attention,
            verbose = verbose,
            only_temporal = only_temporal
    )
    log('\n----------------')
    log('Init dataset')
    dataloader = load_dataset(
            dataset_path = dataset_path,
            model_path = model_path,
            cache_dir = dataset_cache_dir,
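            # the global batch size is the per-device batch size
            # multiplied by the number of available devices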
            batch_size = batch_size * trainer.num_devices,
            num_frames = num_frames,
            # os.cpu_count() may return None; guard it and keep at least one worker
            num_workers = min(trainer.num_devices * 2, max(1, (os.cpu_count() or 2) - 1)),
            as_numpy = True,
            shuffle = True
    )
    log('\n----------------')
    log('Train')
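    # optionally stream training metrics to Weights & Biases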
    if use_wandb:
        trainer.enable_wandb()
    trainer.train(
            dataloader = dataloader,
            epochs = epochs,
            num_frames = num_frames,
            log_every_step = log_every_step,
            save_every_epoch = save_every_epoch,
            sample_every_epoch = sample_every_epoch,
            lr = lr,
            warmup = warmup,
            decay = decay,
            weight_decay = weight_decay,
            output_dir = output_dir
    )
    log('\n----------------')
    log('Done')


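# Example invocation (script name and paths are placeholders):
#   python train.py \
#       --dataset_path /path/to/dataset \
#       --model_path /path/to/model \
#       --output_dir ./output \
#       --num_frames 24 --batch_size 1 --epochs 2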
if __name__ == '__main__':
    parser = ArgumentParser()
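    # argparse's type=bool would treat any non-empty string as True,
    # so boolean flags are parsed explicitly from common truthy spellings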
    def bool_type(x: str) -> bool:
        return x.lower() in ['true', '1', 'yes']
    parser.add_argument('-v', '--verbose', type = bool_type, default = True)
    parser.add_argument('-d', '--dataset_path', required = True)
    parser.add_argument('-m', '--model_path', required = True)
    parser.add_argument('-o', '--output_dir', required = True)
    parser.add_argument('-b', '--batch_size', type = int, default = 1)
    parser.add_argument('-f', '--num_frames', type = int, default = 24)
    parser.add_argument('-e', '--epochs', type = int, default = 2)
    parser.add_argument('--only_temporal', type = bool_type, default = True)
    parser.add_argument('--dataset_cache_dir', type = str, default = None)
    parser.add_argument('--from_pt', type = bool_type, default = True)
    parser.add_argument('--convert2d', type = bool_type, default = False)
    parser.add_argument('--lr', type = float, default = 1e-4)
    parser.add_argument('--warmup', type = float, default = 0.1)
    parser.add_argument('--decay', type = float, default = 0.0)
    parser.add_argument('--weight_decay', type = float, default = 1e-2)
    parser.add_argument('--sample_size', type = int, nargs = 2, default = [64, 64])
    parser.add_argument('--log_every_step', type = int, default = 250)
    parser.add_argument('--save_every_epoch', type = int, default = 1)
    parser.add_argument('--sample_every_epoch', type = int, default = 1)
    parser.add_argument('--seed', type = int, default = 0)
    parser.add_argument('--use_memory_efficient_attention', type = bool_type, default = True)
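    # compute dtype vs. parameter storage dtype; bfloat16 compute with
    # float32 params is a common mixed-precision setup on TPUs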
    parser.add_argument('--dtype', choices = ['float32', 'bfloat16', 'float16'], default = 'bfloat16')
    parser.add_argument('--param_dtype', choices = ['float32', 'bfloat16', 'float16'], default = 'float32')
    parser.add_argument('--wandb', type = bool_type, default = False)
    args = parser.parse_args()
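    # nargs=2 yields a list; convert to the Tuple[int, int] the trainer expects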
    args.sample_size = tuple(args.sample_size)
    if args.verbose:
        print(args)
    train(
            dataset_path = args.dataset_path,
            model_path = args.model_path,
            from_pt = args.from_pt,
            convert2d = args.convert2d,
            only_temporal = args.only_temporal,
            output_dir = args.output_dir,
            dataset_cache_dir = args.dataset_cache_dir,
            batch_size = args.batch_size,
            num_frames = args.num_frames,
            epochs = args.epochs,
            lr = args.lr,
            warmup = args.warmup,
            decay = args.decay,
            weight_decay = args.weight_decay,
            sample_size = args.sample_size,
            seed = args.seed,
            dtype = args.dtype,
            param_dtype = args.param_dtype,
            use_memory_efficient_attention = args.use_memory_efficient_attention,
            log_every_step = args.log_every_step,
            save_every_epoch = args.save_every_epoch,
            sample_every_epoch = args.sample_every_epoch,
            verbose = args.verbose,
            use_wandb = args.wandb
    )