# Source: jimmycv07, commit 1de8821 ("first commit"), 1.15 kB
import os
import numpy as np
import sys
import json
def read_text_lines(filepath):
    """Return the lines of a text file with trailing whitespace removed.

    Each line goes through ``str.rstrip()``, which drops the line terminator
    together with any trailing spaces/tabs; leading whitespace is kept.
    The file is opened in text mode with the platform default encoding.
    """
    with open(filepath, 'r') as handle:
        return [line.rstrip() for line in handle]
def check_path(path):
    """Create directory ``path`` (and any missing parents) if it does not exist.

    An empty ``path`` denotes the current working directory (the base that
    ``os.path.join('', name)`` resolves against), so it is treated as a no-op
    instead of letting ``os.makedirs('')`` raise ``FileNotFoundError``.

    If ``path`` already exists (as a directory or as a regular file), nothing
    is done.
    """
    if path in ('', b''):
        return
    if not os.path.exists(path):
        # exist_ok=True: another process may create the directory between the
        # exists() check and makedirs() (e.g. when multi-processing).
        os.makedirs(path, exist_ok=True)
def save_command(save_path, filename='command_train.txt'):
    """Append the current process's command line to ``save_path/filename``.

    The directory is created first if needed. The entries of ``sys.argv`` are
    joined with single spaces (no shell quoting) and followed by a blank line.
    Append mode keeps the commands of earlier invocations, e.g. when training
    is resumed.
    """
    check_path(save_path)
    log_file = os.path.join(save_path, filename)
    with open(log_file, 'a') as out:
        out.write(' '.join(sys.argv) + '\n\n')
def save_args(args, filename='args.json'):
    """Append the attributes of ``args`` as pretty-printed JSON to a file.

    The target is ``os.path.join(args.checkpoint_dir, filename)``; the
    directory is created if needed. The file is opened in append mode, so each
    call (e.g. each resumed training run) adds another JSON object followed by
    a blank line. After more than one call the file is therefore a sequence of
    JSON objects, not a single JSON document.

    Args:
        args: Object supporting ``vars()`` that has a ``checkpoint_dir``
            attribute (presumably an ``argparse.Namespace`` — verify against
            callers).
        filename: Name of the file inside ``args.checkpoint_dir``.

    Raises:
        TypeError: If an attribute value is not JSON-serializable. Nothing is
            written to the file in that case.
    """
    args_dict = vars(args)
    # Serialize up front: json.dump() writes incrementally, so a value that
    # fails to serialize would leave a truncated object appended to the file.
    serialized = json.dumps(args_dict, indent=4, sort_keys=False)
    check_path(args.checkpoint_dir)
    save_path = os.path.join(args.checkpoint_dir, filename)
    with open(save_path, 'a') as f:
        f.write(serialized)
        f.write('\n\n')
def int_list(s, sep=','):
    """Convert a delimited string of integers into a list of ints.

    Example: ``int_list("0,1,2") == [0, 1, 2]``.

    Args:
        s: Input string. Whitespace around each item is allowed, because
            ``int()`` ignores it. An empty or whitespace-only string yields an
            empty list instead of failing on ``int('')``.
        sep: Item separator; defaults to ``','``.

    Returns:
        list[int]: The parsed integers, in input order.

    Raises:
        ValueError: If any item is not a valid integer literal (this includes
            empty items such as those produced by ``"1,,2"`` or ``"1,2,"``).
    """
    if not s.strip():
        return []
    return [int(x) for x in s.split(sep)]