from tasks.train.instruction_data import *

# ========================= data ==========================
# train_corpus = "videochat2_instruction"
train_corpus = "videochat2_instruction_full"
train_file = "${available_corpus[${train_corpus}]}"  # lazily evaluated when the config is loaded
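# Illustrative sketch (kept commented out so the config stays inert): the
# "${...}" strings in this file are placeholders resolved at config-load time.
# A minimal resolver for the simple, non-nested case might look like the
# hypothetical helper below; the real loader also has to handle nested
# lookups such as "${available_corpus[${train_corpus}]}".
#
#   import re
#   def resolve(value, scope):
#       # replace each ${name} with str(scope[name])
#       return re.sub(r"\$\{(\w+)\}", lambda m: str(scope[m.group(1)]), value)
#
#   # resolve("${train_corpus}", globals())  # -> "videochat2_instruction_full"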
test_file = dict()
test_types = []
num_workers = 8
save_steps = 10000
ckpt_steps = 1000
stop_key = None
deepspeed = False
highres = None
# ========================= input ==========================
num_frames = 16
num_frames_test = 1
batch_size = 1
gradient_accumulation_steps = 16
max_txt_l = 512
max_train_steps = None
pre_text = False
gradient_checkpointing = False  # NOTE: overridden to True further below
inputs = dict(
    image_res=336,
    video_input=dict(
        num_frames="${num_frames}",
        sample_type="rand",
        num_frames_test="${num_frames_test}",
        sample_type_test="middle",
        random_aug=False,
    ),
    max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
    batch_size=dict(image="${batch_size}", video="${batch_size}"),
    batch_size_test=dict(image="${batch_size}", video="${batch_size}"),
)
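# Sketch of the two sampling modes referenced above ("rand" for training,
# "middle" for testing). Hypothetical helper, not the repo's actual loader:
#
#   import random
#   def sample_frame_indices(total_frames, n, mode="rand"):
#       seg = total_frames / n                 # split the video into n segments
#       if mode == "rand":                     # one random frame per segment
#           return [int((i + random.random()) * seg) for i in range(n)]
#       return [int((i + 0.5) * seg) for i in range(n)]   # segment midpoints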
model = dict(
    repo_id="llava-hf/llava-1.5-7b-hf",
    pretrained_path=None,
    load_from_origin=False,
    origin_vision="",
    origin_llm="",
    vision_encoder=dict(
        name="vit_l14",  # the dataset reads this to pick the pretrained model's normalization mean/std
    ),
    torch_dtype="bfloat16",
    freeze_projector=False,
    freeze_lm=True,
    freeze_vision_tower=True,
    lora_target_modules=["q_proj", "v_proj"],  # for llama/mistral/gemma
    use_lora=True,
    lora_r=128,
    lora_alpha=32,
    lora_dropout=0.05,
    num_frames="${num_frames}",
    pooling_method="avg",
    use_pooling=True,
    frame_shape=(24, 24),      # 24x24 = 576 ViT-L/14 patches per 336px frame
    pooling_shape=(16, 8, 8),  # pooled to 16*8*8 = 1024 video tokens
)
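# Sketch of how the LoRA fields above could map onto a peft config, assuming
# the trainer wraps the HF checkpoint with peft (names here are illustrative):
#
#   from peft import LoraConfig, get_peft_model
#   lora_config = LoraConfig(
#       r=128,
#       lora_alpha=32,
#       lora_dropout=0.05,
#       target_modules=["q_proj", "v_proj"],  # llama/mistral/gemma attention
#       task_type="CAUSAL_LM",
#   )
#   # model = get_peft_model(hf_model, lora_config)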
preprocess = dict(
    system="",
    mm_alone=True,
    image_token_index=64002,
    random_shuffle=True,
    add_second_msg=True,
    roles=["<|im_start|>user\n", "<|im_start|>assistant\n"],
    end_signal=("<|im_end|>\n", "<|im_end|>\n"),
    begin_signal="",
    dataset_image_placeholder="<Image></Image>",
    dataset_video_placeholder="<Video></Video>",
    max_txt_l="${max_txt_l}",
    ignore_index=-100,  # default ignore_index of torch.nn.CrossEntropyLoss
    center_pad=False,
    longest_edge=762,
    shortest_edge=336,
    clip_transform=False,
    num_frames="${num_frames}",
)
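# Sketch of how the role/signal fields above compose one chat turn; the real
# template lives in the dataset code, so this is only an assumed reading:
#
#   def format_turn(role, text, end_signal):
#       return role + text + end_signal
#
#   # format_turn('<|im_start|>user\n', '<Video></Video> What happens?', '<|im_end|>\n')
#   # -> '<|im_start|>user\n<Video></Video> What happens?<|im_end|>\n'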
optimizer = dict(
    opt="adamW",
    lr=2e-5,
    opt_betas=[0.9, 0.999],  # default
    weight_decay=0.02,
    max_grad_norm=-1,  # gradient-clipping threshold; set a positive float to enable, -1 to disable
    # use a different lr for some modules, e.g., a larger lr for newly added modules
    different_lr=dict(enable=False, module_names=[], lr=1e-3),
)
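# Sketch of the param-group split implied by `different_lr` (hypothetical
# helper; the repo's optimizer factory may differ):
#
#   import torch
#   def build_optimizer(model, base_lr=2e-5, module_names=(), module_lr=1e-3):
#       special, regular = [], []
#       for name, p in model.named_parameters():
#           if p.requires_grad:
#               (special if any(m in name for m in module_names) else regular).append(p)
#       return torch.optim.AdamW(
#           [{"params": regular, "lr": base_lr},
#            {"params": special, "lr": module_lr}],
#           betas=(0.9, 0.999), weight_decay=0.02)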
# scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.25, warmup_epochs=0.6)
scheduler = dict(
    is_videochat2_custom=False,
    sched="cosine",
    epochs=2,
    warmup_ratio=0.2,
    min_lr_multi=0.25,
)
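# Sketch of the schedule these fields describe: linear warmup over the first
# 20% of steps, then cosine decay to 0.25x the base lr (assumed reading of
# warmup_ratio and min_lr_multi):
#
#   import math
#   def lr_multiplier(step, total_steps, warmup_ratio=0.2, min_lr_multi=0.25):
#       warmup = int(total_steps * warmup_ratio)
#       if step < warmup:
#           return step / max(1, warmup)
#       progress = (step - warmup) / max(1, total_steps - warmup)
#       return min_lr_multi + (1 - min_lr_multi) * 0.5 * (1 + math.cos(math.pi * progress))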
evaluate = False
deep_fusion = False
evaluation = dict(
    eval_frame_ensemble="concat",  # [concat, max, mean, lse]
    eval_x_only=False,
    k_test=128,
    eval_offload=True,  # offload GPU tensors to CPU to save memory
)
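# Sketch of how the max/mean/lse reductions above could combine per-frame
# match scores (assumed semantics; "concat" presumably fuses frame features
# before scoring rather than reducing scores):
#
#   import torch
#   def ensemble(scores, mode="lse"):       # scores: tensor of shape (num_frames,)
#       if mode == "max":
#           return scores.max()
#       if mode == "mean":
#           return scores.mean()
#       if mode == "lse":                   # log-sum-exp: a smooth maximum
#           return torch.logsumexp(scores, dim=0)
#       raise ValueError(mode)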
fp16 = True
gradient_checkpointing = True  # takes effect over the earlier False in the input section
# ========================= wandb ==========================
wandb = dict(
    enable=False,
    entity="user",  # username or team name that stores the runs, see https://docs.wandb.ai/ref/python/init
    project="videochat2",  # set this in your command line
)
dist_url = "env://"
device = "cuda"
mode = "it"
# ========================= others ==========================
output_dir = None  # output dir
resume = False  # if True, also load optimizer and scheduler states
debug = False
log_freq = 5
metric_window_size = 10  # window size for smoothed metric logging
seed = 42
report_to = "tensorboard"
save_latest = True
auto_resume = True
pretrained_path = ""  # path to pretrained model weights; appears to be used only when resuming