import os
import argparse
import warnings

import torch.distributed as dist
import torch.multiprocessing as mp
import wandb
from transformers import set_seed

from src.scripts.mytokenizers import Tokenizer
from src.scripts.mydatasets import get_dataloader, Lang2molDataset_train
from src.improved_diffusion import gaussian_diffusion as gd
from src.improved_diffusion import dist_util
from src.improved_diffusion.respace import SpacedDiffusion
from src.improved_diffusion.transformer_model import TransformerNetModel
from src.improved_diffusion.resample import create_named_schedule_sampler
from src.improved_diffusion.script_util import (
    model_and_diffusion_defaults,
    add_dict_to_argparser,
)
from src.improved_diffusion.train_util import TrainLoop


def main_worker(rank, world_size):
    args = create_argparser().parse_args()
    set_seed(42)

    # Experiment tracking.
    wandb.login(key=args.wandb_token)
    wandb.init(
        project="ACL_Lang2Mol",
        config=args.__dict__,
    )

    # Initialize the distributed process group for this worker.
    dist_util.setup_dist(rank, world_size)

    tokenizer = Tokenizer()
    model = TransformerNetModel(
        in_channels=args.model_in_channels,
        model_channels=args.model_model_channels,
        dropout=args.model_dropout,
        vocab_size=len(tokenizer),
        hidden_size=args.model_hidden_size,
        num_attention_heads=args.model_num_attention_heads,
        num_hidden_layers=args.model_num_hidden_layers,
    )
    # Optionally warm-start from an existing checkpoint.
    if args.model_path != "":
        model.load_state_dict(
            dist_util.load_state_dict(args.model_path, map_location="cpu")
        )
    model.train()

    print("Total params:", sum(p.numel() for p in model.parameters()))
    print(
        "Total trainable params:",
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    )
    print("Tokenizer vocab length:", len(tokenizer))

    # End-to-end diffusion with a sqrt noise schedule; all timesteps are used.
    diffusion = SpacedDiffusion(
        use_timesteps=list(range(args.diffusion_steps)),
        betas=gd.get_named_beta_schedule("sqrt", args.diffusion_steps),
        model_mean_type=gd.ModelMeanType.START_X,
        model_var_type=gd.ModelVarType.FIXED_LARGE,
        loss_type=gd.LossType.E2E_MSE,
        rescale_timesteps=True,
        model_arch="transformer",
        training_mode="e2e",
    )
    schedule_sampler = create_named_schedule_sampler("uniform", diffusion)

    print("Loading data...")
    train_dataset = Lang2molDataset_train(
        dir=args.dataset_path,
        tokenizer=tokenizer,
        split="train",
        corrupt_prob=0.0,
        token_max_length=512,
        dataset_name=args.dataset_name,
    )
    dataloader = get_dataloader(train_dataset, args.batch_size, rank, world_size)
    print("Finish loading data")

    TrainLoop(
        model=model,
        diffusion=diffusion,
        data=dataloader,
        batch_size=args.batch_size,
        microbatch=args.microbatch,
        lr=args.lr,
        ema_rate=args.ema_rate,
        log_interval=args.log_interval,
        save_interval=args.save_interval,
        resume_checkpoint=args.resume_checkpoint,
        use_fp16=args.use_fp16,
        fp16_scale_growth=args.fp16_scale_growth,
        schedule_sampler=schedule_sampler,
        weight_decay=args.weight_decay,
        lr_anneal_steps=args.lr_anneal_steps,
        checkpoint_path=args.checkpoint_path,
        gradient_clipping=args.gradient_clipping,
        eval_data=None,
        eval_interval=args.eval_interval,
    ).run_loop()

    dist.destroy_process_group()


def create_argparser():
    defaults = dict()
    text_defaults = dict(
        wandb_token="",
        batch_size=16,
        cache_mode="no",
        checkpoint_path="checkpoints",
        class_cond=False,
        config="ll",
        config_name="QizhiPei/biot5-base-text2mol",
        dataset_path="dataset",
        diffusion_steps=2000,
        dropout=0.01,
        e2e_train="",
        ema_rate="0.9999",
        emb_scale_factor=1.0,
        eval_interval=2000,
        experiment="random",
        experiment_mode="lm",
        fp16_scale_growth=0.001,
        gradient_clipping=2.4,
        image_size=8,
        in_channel=16,
        learn_sigma=False,
        log_interval=1000,
        logits_mode=1,
        lr=0.00005,
        lr_anneal_steps=500000,
        microbatch=-1,
        modality="e2e-tgt",
        model_arch="transformer",
        noise_level=0.0,
        noise_schedule="sqrt",
        num_channels=128,
        num_heads=4,
        num_heads_upsample=-1,
        num_res_blocks=2,
        out_channel=16,
        padding_mode="pad",
        predict_xstart=True,
        preprocessing_num_workers=1,
        rescale_learned_sigmas=True,
        rescale_timesteps=True,
        resume_checkpoint="",
        save_interval=50000,
        schedule_sampler="uniform",
        seed=42,
        timestep_respacing="",
        training_mode="e2e",
        use_bert_tokenizer="no",
        use_checkpoint=False,
        use_fp16=False,
        use_kl=False,
        use_scale_shift_norm=True,
        weight_decay=0.0,
        model_in_channels=32,
        model_model_channels=128,
        model_dropout=0.01,
        model_hidden_size=1024,
        model_num_attention_heads=16,
        model_num_hidden_layers=12,
        dataset_name="",
        model_path="",
    )
    defaults.update(model_and_diffusion_defaults())
    defaults.update(text_defaults)
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, defaults)
    return parser


if __name__ == "__main__":
    # Single-process launch; increase world_size to spawn one worker per GPU.
    world_size = 1
    mp.spawn(main_worker, args=(world_size,), nprocs=world_size, join=True)
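

# Example launch (illustrative only; the script filename and flag values below
# are assumptions, not taken from the repository). Every key in text_defaults /
# model_and_diffusion_defaults() becomes a --flag via add_dict_to_argparser, so
# any default can be overridden on the command line, e.g.:
#
#   python train.py \
#       --dataset_path dataset \
#       --dataset_name <your-dataset-name> \
#       --checkpoint_path checkpoints \
#       --batch_size 16 \
#       --wandb_token <your-wandb-api-key>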