|
import torch
|
|
def get_config_phase1():
    """Build the settings dictionary used for phase 1.

    A brand-new dict is constructed on every call, so a caller may mutate
    the returned mapping without affecting later calls.

    Returns:
        dict: paths (``data_dir``, ``ckpts``), Hugging Face model ids
        (``clip_model_name``, ``phi2_model_name``), batch sizes, the
        ``torch.device`` to run on, ``epochs``, ``max_tokens``, the two
        embedding widths (``clip_embed``, ``phi_embed``) and ``num_workers``.
    """
    # Pick the device when the config is requested (not at import time):
    # CUDA if torch reports it available, otherwise CPU.
    has_gpu = torch.cuda.is_available()
    run_device = torch.device("cuda") if has_gpu else torch.device("cpu")

    return dict(
        data_dir="./data",
        clip_model_name="openai/clip-vit-base-patch16",
        phi2_model_name="microsoft/phi-2",
        train_batch_size=2,
        val_batch_size=1,
        device=run_device,
        epochs=2,
        max_tokens=20,
        # Embedding widths; presumably these must match the hidden sizes of
        # the two models named above -- NOTE(review): confirm if models change.
        clip_embed=768,
        phi_embed=2560,
        num_workers=4,
        ckpts="./ckpts",
    )
|
|
|
|
def get_config_phase2():
    """Build the settings dictionary used for phase 2.

    Same layout as the phase-1 settings, with different batch size, epoch,
    token and worker values, plus an extra ``vocab_size`` entry. A fresh
    dict is returned on every call.

    Returns:
        dict: paths (``data_dir``, ``ckpts``), Hugging Face model ids,
        batch sizes, the ``torch.device`` to run on, ``epochs``,
        ``max_tokens``, embedding widths, ``num_workers`` and ``vocab_size``.
    """
    # Device is chosen per call: CUDA when available, CPU otherwise.
    has_gpu = torch.cuda.is_available()
    run_device = torch.device("cuda") if has_gpu else torch.device("cpu")

    return dict(
        data_dir="./data",
        clip_model_name="openai/clip-vit-base-patch16",
        phi2_model_name="microsoft/phi-2",
        train_batch_size=1,
        val_batch_size=1,
        device=run_device,
        epochs=10,
        max_tokens=100,
        # Presumably tied to the hidden sizes of the models named above --
        # NOTE(review): confirm against their configs if the models change.
        clip_embed=768,
        phi_embed=2560,
        num_workers=0,
        ckpts="./ckpts",
        # Presumably the phi-2 tokenizer/output vocabulary size -- confirm.
        vocab_size=51200,
    )