# NOTE(review): the lines previously here ("Spaces: Sleeping", a file-size
# banner, and a pasted line-number gutter) were web-page scraping residue,
# not Python source; they have been replaced with this note so the file parses.
import math
import random
import torch
import numpy as np
from icecream import ic
def print_rank_0(message):
    """Print *message*; when torch.distributed is initialized, only rank 0 prints."""
    if not torch.distributed.is_initialized():
        # Single-process run: always print.
        print(message, flush=True)
        return
    if torch.distributed.get_rank() == 0:
        print(message, flush=True)
# Process-wide argument namespace, installed once at startup via set_args().
ARGS = None


def set_args(args):
    """Install *args* as the module-global argument namespace."""
    global ARGS
    ARGS = args


def get_args():
    """Return the argument namespace previously installed with set_args()."""
    return ARGS
# Process-wide tokenizer handle, installed once at startup via set_tokenizer().
TOKENIZER = None


def set_tokenizer(tokenizer):
    """Install *tokenizer* as the module-global tokenizer."""
    global TOKENIZER
    TOKENIZER = tokenizer


def get_tokenizer():
    """Return the tokenizer previously installed with set_tokenizer()."""
    return TOKENIZER
from torch import distributed as dist
class worker_init:
    """Picklable DataLoader worker seeder bound to a specific epoch.

    An instance's ``_worker_init_fn`` is meant to be passed as the
    ``worker_init_fn`` argument of ``torch.utils.data.DataLoader``.
    """

    def __init__(self, epoch_id):
        # Remember which epoch this seeder belongs to.
        self.epoch_id = epoch_id

    def _worker_init_fn(self, worker_id):
        # Mix worker id, epoch and distributed rank so that every
        # (worker, epoch, rank) combination draws a distinct random stream.
        seed = worker_id + self.epoch_id * 1e4 + dist.get_rank() * 1e8
        random.seed(seed)
def batchify(batch):
    """Collate a list of sample dicts into a single model-ready batch.

    Each sample is expected to provide:
      - "video": frame tensor whose dim 0 is the frame count, or None
        when the sample carries no video
      - "text": dict with "input_ids", "non_padding_mask",
        "non_media_mask" and "prompt_mask" integer sequences of equal length
      - "videopath", "caption": opaque metadata passed through unchanged

    Returns a dict of batch tensors; "video_pixel_values" is None when no
    sample in the batch has a video, and "pixel_values" (images) is always
    None for this video-only pipeline.
    """
    videos = [data["video"] for data in batch]
    if all(v is None for v in videos):
        video = None
    else:
        # Concatenate frames of every present video along dim 0;
        # num_videos below records how many frames belong to each sample.
        video = torch.cat([v for v in videos if v is not None], dim=0)
    num_videos_per_sample = torch.LongTensor(
        [v.size(0) if v is not None else 0 for v in videos])
    # Video-only dataset: per-sample image count is always zero.
    num_images_per_sample = torch.LongTensor([0 for _ in batch])

    def _stack(key):
        # Stack one per-sample text field into a (batch, seq_len) LongTensor.
        return torch.stack(
            [torch.LongTensor(data["text"][key]) for data in batch], dim=0)

    text = _stack('input_ids')
    non_padding_mask = _stack('non_padding_mask')
    non_media_mask = _stack('non_media_mask')
    prompt_mask = _stack('prompt_mask')

    output_batch = {
        "pixel_values": None,
        "video_pixel_values": video,
        "input_ids": text.long(),
        # Labels start as a copy of the inputs; loss masking happens downstream.
        "labels": text.long().clone(),
        "num_images": num_images_per_sample.long(),
        "num_videos": num_videos_per_sample.long(),
        "non_padding_mask": non_padding_mask.long(),
        "non_media_mask": non_media_mask.long(),
        "prompt_mask": prompt_mask.long(),
        "videopaths": [data["videopath"] for data in batch],
        "captions": [data["caption"] for data in batch],
    }
    return output_batch
def get_param_groups(modules,
                     no_weight_decay_cond,
                     scale_lr_cond,
                     lr_mult):
    """Bucket trainable parameters into optimizer param groups.

    Parameters are split along two axes: whether they receive weight decay
    (by default biases and 1-D params such as Norm weights do not) and
    whether they use a scaled learning rate (``lr_mult * args.lr``, used
    during finetuning for e.g. a network head).  Custom predicates
    ``no_weight_decay_cond(name, param)`` and ``scale_lr_cond(name, param)``
    override the defaults when provided.

    Returns a list of dicts with ``params``, ``wd_mult`` and ``lr_mult``
    keys; empty buckets are omitted.
    """
    # Buckets keyed by (uses_weight_decay, uses_scaled_lr).
    buckets = {
        (True, False): [],
        (True, True): [],
        (False, False): [],
        (False, True): [],
    }
    for module in modules:
        for name, param in module.named_parameters():
            if not param.requires_grad:
                continue  # frozen parameters never enter the optimizer
            if no_weight_decay_cond is not None:
                no_wd = no_weight_decay_cond(name, param)
            else:
                # Default: do not regularize biases nor Norm parameters.
                no_wd = name.endswith(".bias") or len(param.shape) == 1
            scale_lr = (scale_lr_cond(name, param)
                        if scale_lr_cond is not None else False)
            buckets[(not no_wd, bool(scale_lr))].append(param)

    # Emit groups in a fixed order, skipping empty buckets.
    group_specs = [
        ((True, False), 1.0, 1.0),
        ((True, True), 1.0, lr_mult),
        ((False, False), 0.0, 1.0),
        ((False, True), 0.0, lr_mult),
    ]
    param_groups = []
    for key, wd_mult, group_lr_mult in group_specs:
        params = buckets[key]
        if params:
            param_groups.append(
                {'params': params, 'wd_mult': wd_mult, 'lr_mult': group_lr_mult})
    return param_groups
def get_cosine_schedule_with_warmup(
    optimizer, lr, min_lr, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
):
    """
    Create a schedule with a learning rate that decreases following the values of the cosine function between the
    initial lr set in the optimizer down to min_lr, after a warmup period during which it increases linearly
    between min_lr and the initial lr set in the optimizer.
    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        lr (`float`):
            The peak learning rate (the base lr configured in the optimizer).
        min_lr (`float`):
            The floor learning rate: warmup starts here and the cosine decays back to it.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`float`, *optional*, defaults to 0.5):
            The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
            following a half-cosine).
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.
    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    # Fraction of the lr span the schedule sweeps; the multiplier floor is
    # min_lr / lr == 1 - delta_min_lr (e.g. 0.95 when min_lr = 0.05 * lr).
    delta_min_lr = (lr - min_lr) / lr

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            # Linear warmup from min_lr/lr up to 1.
            return (1 - delta_min_lr) + delta_min_lr * float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / \
            float(max(1, num_training_steps - num_warmup_steps))
        # Cosine decay from 1 down to min_lr/lr.
        # BUGFIX: the original returned delta_min_lr + (1-delta_min_lr)*cosine,
        # which bottoms out at (lr-min_lr)/lr (nearly the peak when min_lr is
        # small) instead of the intended min_lr/lr floor used by the warmup.
        return (1 - delta_min_lr) + delta_min_lr * max(0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

    from torch.optim.lr_scheduler import LambdaLR
    return LambdaLR(optimizer, lr_lambda, last_epoch)