from functools import partial

from jax import random
import jax.numpy as np
from jax.scipy.linalg import block_diag
import wandb

from .train_helpers import create_train_state, reduce_lr_on_plateau, \
    linear_warmup, cosine_annealing, constant_lr, train_epoch, validate
from .dataloading import Datasets
from .seq_model import BatchClassificationModel, RetrievalModel
from .ssm import init_S5SSM
from .ssm_init import make_DPLR_HiPPO


def train(args):
    """
    Main function to train over a certain number of epochs.
    """

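    # Best test-set metrics seen so far; updated whenever validation accuracy improves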
    best_test_loss = 100000000
    best_test_acc = -10000.0

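    # Initialize wandb logging; run in offline mode if wandb is disabled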
    if args.USE_WANDB:
        wandb.init(project=args.wandb_project, job_type='model_training',
                   config=vars(args), entity=args.wandb_entity)
    else:
        wandb.init(mode='offline')

    ssm_size = args.ssm_size_base
    ssm_lr = args.ssm_lr_base

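    # Determine the size of each initial block of the block-diagonal state matrix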
    block_size = int(ssm_size / args.blocks)
    wandb.log({"block_size": block_size})

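    # Set the global learning rate for non-SSM parameters as a multiple of the SSM learning rate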
    lr = args.lr_factor * ssm_lr

    print("[*] Setting Randomness...")
    key = random.PRNGKey(args.jax_seed)
    init_rng, train_rng = random.split(key, num=2)

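    # Look up the dataset creation function for the requested dataset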
    create_dataset_fn = Datasets[args.dataset]

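    # Dataset-dependent logic: the text classification tasks use padded sequences,
    # and the AAN task uses the retrieval (document matching) head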
    if args.dataset in ["imdb-classification", "listops-classification", "aan-classification"]:
        padded = True
        if args.dataset in ["aan-classification"]:
            retrieval = True
            print("Using retrieval model for document matching")
        else:
            retrieval = False

    else:
        padded = False
        retrieval = False

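    # The speech task is additionally evaluated at a second, halved resolution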
    if args.dataset in ["speech35-classification"]:
        speech = True
        print("Will evaluate on both resolutions for speech task")
    else:
        speech = False

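    # Create the dataloaders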
    init_rng, key = random.split(init_rng, num=2)
    trainloader, valloader, testloader, aux_dataloaders, n_classes, seq_len, in_dim, train_size = \
        create_dataset_fn(args.dir_name, seed=args.jax_seed, bsz=args.bsz)

    print(f"[*] Starting S5 Training on `{args.dataset}` =>> Initializing...")

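    # Initialize the state matrix A using the DPLR approximation to the HiPPO-LegS matrix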
    Lambda, _, B, V, B_orig = make_DPLR_HiPPO(block_size)

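    # With conjugate symmetry, only half of the state dimensions need to be represented explicitly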
    if args.conj_sym:
        block_size = block_size // 2
        ssm_size = ssm_size // 2

    Lambda = Lambda[:block_size]
    V = V[:, :block_size]
    Vc = V.conj().T

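    # Put the HiPPO approximation on each block of the block-diagonal state matrix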
    Lambda = (Lambda * np.ones((args.blocks, block_size))).ravel()
    V = block_diag(*([V] * args.blocks))
    Vinv = block_diag(*([Vc] * args.blocks))

    print("Lambda.shape={}".format(Lambda.shape))
    print("V.shape={}".format(V.shape))
    print("Vinv.shape={}".format(Vinv.shape))

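    # Build the SSM initialization function passed to the sequence model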
    ssm_init_fn = init_S5SSM(H=args.d_model,
                             P=ssm_size,
                             Lambda_re_init=Lambda.real,
                             Lambda_im_init=Lambda.imag,
                             V=V,
                             Vinv=Vinv,
                             C_init=args.C_init,
                             discretization=args.discretization,
                             dt_min=args.dt_min,
                             dt_max=args.dt_max,
                             conj_sym=args.conj_sym,
                             clip_eigs=args.clip_eigs,
                             bidirectional=args.bidirectional)

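    # Choose the model head: retrieval head for document matching, otherwise a standard classification head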
    if retrieval:
        print("Using Retrieval head for {} task".format(args.dataset))
        model_cls = partial(
            RetrievalModel,
            ssm=ssm_init_fn,
            d_output=n_classes,
            d_model=args.d_model,
            n_layers=args.n_layers,
            padded=padded,
            activation=args.activation_fn,
            dropout=args.p_dropout,
            prenorm=args.prenorm,
            batchnorm=args.batchnorm,
            bn_momentum=args.bn_momentum,
        )

    else:
        model_cls = partial(
            BatchClassificationModel,
            ssm=ssm_init_fn,
            d_output=n_classes,
            d_model=args.d_model,
            n_layers=args.n_layers,
            padded=padded,
            activation=args.activation_fn,
            dropout=args.p_dropout,
            mode=args.mode,
            prenorm=args.prenorm,
            batchnorm=args.batchnorm,
            bn_momentum=args.bn_momentum,
        )

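    # Initialize the training state (model parameters and optimizer state)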
    state = create_train_state(model_cls,
                               init_rng,
                               padded,
                               retrieval,
                               in_dim=in_dim,
                               bsz=args.bsz,
                               seq_len=seq_len,
                               weight_decay=args.weight_decay,
                               batchnorm=args.batchnorm,
                               opt_config=args.opt_config,
                               ssm_lr=ssm_lr,
                               lr=lr,
                               dt_global=args.dt_global)

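    # Training loop over epochs; the counters below track early stopping and
    # learning-rate reduction on plateau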
    best_loss, best_acc, best_epoch = 100000000, -100000000.0, 0
    count, best_val_loss = 0, 100000000
    lr_count, opt_acc = 0, -100000000.0
    step = 0
    steps_per_epoch = int(train_size / args.bsz)
    for epoch in range(args.epochs):
        print(f"[*] Starting Training Epoch {epoch + 1}...")

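        # Select the learning-rate schedule for this epoch: linear warmup first,
        # then cosine annealing or a constant rate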
        if epoch < args.warmup_end:
            print("using linear warmup for epoch {}".format(epoch + 1))
            decay_function = linear_warmup
            end_step = steps_per_epoch * args.warmup_end

        elif args.cosine_anneal:
            print("using cosine annealing for epoch {}".format(epoch + 1))
            decay_function = cosine_annealing
            end_step = steps_per_epoch * args.epochs - (steps_per_epoch * args.warmup_end)
        else:
            print("using constant lr for epoch {}".format(epoch + 1))
            decay_function = constant_lr
            end_step = None

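        # Pack the schedule parameters; the learning rates are updated per step inside train_epoch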
        lr_params = (decay_function, ssm_lr, lr, step, end_step, args.opt_config, args.lr_min)

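        # Train for one epoch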
        train_rng, skey = random.split(train_rng)
        state, train_loss, step = train_epoch(state,
                                              skey,
                                              model_cls,
                                              trainloader,
                                              seq_len,
                                              in_dim,
                                              args.batchnorm,
                                              lr_params)

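        # Evaluate on the validation set if one exists; otherwise the test set is used as the validation set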
        if valloader is not None:
            print(f"[*] Running Epoch {epoch + 1} Validation...")
            val_loss, val_acc = validate(state,
                                         model_cls,
                                         valloader,
                                         seq_len,
                                         in_dim,
                                         args.batchnorm)

            print(f"[*] Running Epoch {epoch + 1} Test...")
            test_loss, test_acc = validate(state,
                                           model_cls,
                                           testloader,
                                           seq_len,
                                           in_dim,
                                           args.batchnorm)

            print(f"\n=>> Epoch {epoch + 1} Metrics ===")
            print(
                f"\tTrain Loss: {train_loss:.5f} -- Val Loss: {val_loss:.5f} -- Test Loss: {test_loss:.5f} --"
                f" Val Accuracy: {val_acc:.4f}"
                f" Test Accuracy: {test_acc:.4f}"
            )

        else:
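            # No separate validation set, so the test set doubles as the validation set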
            print(f"[*] Running Epoch {epoch + 1} Test...")
            val_loss, val_acc = validate(state,
                                         model_cls,
                                         testloader,
                                         seq_len,
                                         in_dim,
                                         args.batchnorm)

            print(f"\n=>> Epoch {epoch + 1} Metrics ===")
            print(
                f"\tTrain Loss: {train_loss:.5f} -- Test Loss: {val_loss:.5f} --"
                f" Test Accuracy: {val_acc:.4f}"
            )

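        # Early-stopping bookkeeping: reset the counter when validation loss improves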
        if val_loss < best_val_loss:
            count = 0
            best_val_loss = val_loss
        else:
            count += 1

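        # Track the best validation accuracy and the corresponding test metrics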
        if val_acc > best_acc:
            count = 0
            best_loss, best_acc, best_epoch = val_loss, val_acc, epoch
            if valloader is not None:
                best_test_loss, best_test_acc = test_loss, test_acc
            else:
                best_test_loss, best_test_acc = best_loss, best_acc

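        # For the speech task, additionally evaluate at half the original resolution (step_rescale=2.0)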
        if speech:
            print(f"[*] Running Epoch {epoch + 1} Res 2 Validation...")
            val2_loss, val2_acc = validate(state,
                                           model_cls,
                                           aux_dataloaders['valloader2'],
                                           int(seq_len // 2),
                                           in_dim,
                                           args.batchnorm,
                                           step_rescale=2.0)

            print(f"[*] Running Epoch {epoch + 1} Res 2 Test...")
            test2_loss, test2_acc = validate(state,
                                             model_cls,
                                             aux_dataloaders['testloader2'],
                                             int(seq_len // 2),
                                             in_dim,
                                             args.batchnorm,
                                             step_rescale=2.0)

            print(f"\n=>> Epoch {epoch + 1} Res 2 Metrics ===")
            print(
                f"\tVal2 Loss: {val2_loss:.5f} -- Test2 Loss: {test2_loss:.5f} --"
                f" Val2 Accuracy: {val2_acc:.4f}"
                f" Test2 Accuracy: {test2_acc:.4f}"
            )

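        # Reduce the learning rates when validation accuracy plateaus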
        input = lr, ssm_lr, lr_count, val_acc, opt_acc
        lr, ssm_lr, lr_count, opt_acc = reduce_lr_on_plateau(input,
                                                             factor=args.reduce_factor,
                                                             patience=args.lr_patience,
                                                             lr_min=args.lr_min)

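        # Print the best metrics seen so far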
        print(
            f"\tBest Val Loss: {best_loss:.5f} -- Best Val Accuracy:"
            f" {best_acc:.4f} at Epoch {best_epoch + 1}\n"
            f"\tBest Test Loss: {best_test_loss:.5f} -- Best Test Accuracy:"
            f" {best_test_acc:.4f} at Epoch {best_epoch + 1}\n"
        )

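        # Log epoch metrics to wandb (including second-resolution metrics for the speech task)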
        if valloader is not None:
            if speech:
                wandb.log(
                    {
                        "Training Loss": train_loss,
                        "Val loss": val_loss,
                        "Val Accuracy": val_acc,
                        "Test Loss": test_loss,
                        "Test Accuracy": test_acc,
                        "Val2 loss": val2_loss,
                        "Val2 Accuracy": val2_acc,
                        "Test2 Loss": test2_loss,
                        "Test2 Accuracy": test2_acc,
                        "count": count,
                        "Learning rate count": lr_count,
                        "Opt acc": opt_acc,
                        "lr": state.opt_state.inner_states['regular'].inner_state.hyperparams['learning_rate'],
                        "ssm_lr": state.opt_state.inner_states['ssm'].inner_state.hyperparams['learning_rate']
                    }
                )
            else:
                wandb.log(
                    {
                        "Training Loss": train_loss,
                        "Val loss": val_loss,
                        "Val Accuracy": val_acc,
                        "Test Loss": test_loss,
                        "Test Accuracy": test_acc,
                        "count": count,
                        "Learning rate count": lr_count,
                        "Opt acc": opt_acc,
                        "lr": state.opt_state.inner_states['regular'].inner_state.hyperparams['learning_rate'],
                        "ssm_lr": state.opt_state.inner_states['ssm'].inner_state.hyperparams['learning_rate']
                    }
                )

        else:
            wandb.log(
                {
                    "Training Loss": train_loss,
                    "Val loss": val_loss,
                    "Val Accuracy": val_acc,
                    "count": count,
                    "Learning rate count": lr_count,
                    "Opt acc": opt_acc,
                    "lr": state.opt_state.inner_states['regular'].inner_state.hyperparams['learning_rate'],
                    "ssm_lr": state.opt_state.inner_states['ssm'].inner_state.hyperparams['learning_rate']
                }
            )

        wandb.run.summary["Best Val Loss"] = best_loss
        wandb.run.summary["Best Val Accuracy"] = best_acc
        wandb.run.summary["Best Epoch"] = best_epoch
        wandb.run.summary["Best Test Loss"] = best_test_loss
        wandb.run.summary["Best Test Accuracy"] = best_test_acc

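        # Stop early if neither validation loss nor accuracy has improved for early_stop_patience epochs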
        if count > args.early_stop_patience:
            break