import os
import random
import shutil
import time

import numpy as np
import torch
# from torch.utils.tensorboard import SummaryWriter
# NOTE: the star import is kept ahead of the loguru import so that name
# resolution (in particular ``logger``) stays exactly as it was originally.
from utils.visualization import *
from loguru import logger


# def get_tensorboard_logger_from_args(tensorboard_dir, reset_version=False):
#     if reset_version:
#         shutil.rmtree(os.path.join(tensorboard_dir))
#     return SummaryWriter(log_dir=tensorboard_dir)


def get_optimizer_from_args(model, lr, weight_decay, **kwargs) -> torch.optim.Optimizer:
    """Build an AdamW optimizer over the trainable parameters of ``model``.

    Args:
        model: Object exposing ``parameters()``; only parameters whose
            ``requires_grad`` flag is true are handed to the optimizer.
        lr: Learning rate forwarded to ``torch.optim.AdamW``.
        weight_decay: Weight-decay coefficient forwarded to ``torch.optim.AdamW``.
        **kwargs: Accepted and ignored, so callers can pass a full argument dict.

    Returns:
        The configured ``torch.optim.AdamW`` instance.
    """
    trainable_params = (p for p in model.parameters() if p.requires_grad)
    return torch.optim.AdamW(trainable_params, lr=lr, weight_decay=weight_decay)


def get_lr_schedule(optimizer):
    """Return an ``ExponentialLR`` scheduler (``gamma=0.95``) for ``optimizer``."""
    return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)


def setup_seed(seed):
    """Seed torch (CPU and all CUDA devices), NumPy and ``random`` with ``seed``.

    Also sets ``torch.backends.cudnn.deterministic`` to ``True``.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


def get_dir_from_args(root_dir, class_name, **kwargs):
    """Create the output directory layout for an experiment and attach a log file.

    The experiment name is ``"<dataset>-k-<k_shot>"``. Under ``root_dir`` the
    following directories are created if missing:

    * ``csv/``
    * ``<exp_name>/models/``
    * ``<exp_name>/imgs/``
    * ``<exp_name>/logger/<class_name>/``

    A timestamped ``log_*.log`` file inside the logger directory is registered
    as a loguru sink.

    Args:
        root_dir: Base directory for all outputs.
        class_name: Name used for the logger sub-directory and as the model name.
        **kwargs: Must contain ``dataset``, ``k_shot`` and ``experiment_indx``.

    Returns:
        Tuple ``(model_dir, img_dir, logger_dir, model_name, csv_path)``.

    Raises:
        KeyError: If ``dataset``, ``k_shot`` or ``experiment_indx`` is missing
            from ``kwargs``.
    """
    exp_name = f"{kwargs['dataset']}-k-{kwargs['k_shot']}"

    csv_dir = os.path.join(root_dir, 'csv')
    csv_path = os.path.join(csv_dir, f"{exp_name}-indx-{kwargs['experiment_indx']}.csv")

    model_dir = os.path.join(root_dir, exp_name, 'models')
    img_dir = os.path.join(root_dir, exp_name, 'imgs')
    logger_dir = os.path.join(root_dir, exp_name, 'logger', class_name)

    # Hour-Minute-Second: the previous format used %I (12-hour clock hour) in
    # the minutes slot, which repeated the hour and dropped the minutes.
    timestamp = time.strftime("%Y-%m-%d-%H-%M-%S")
    log_file_name = os.path.join(logger_dir, f'log_{timestamp}.log')

    model_name = f'{class_name}'

    for directory in (model_dir, img_dir, logger_dir, csv_dir):
        os.makedirs(directory, exist_ok=True)

    # ``logger.add`` is the non-deprecated name for what ``logger.start`` did.
    logger.add(log_file_name)

    logger.info(f"===> Root dir for this experiment: {logger_dir}")

    return model_dir, img_dir, logger_dir, model_name, csv_path