Spaces:
Runtime error
Runtime error
File size: 4,091 Bytes
1f4e6d7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
'''
author: wayn391@mastertones
'''
import datetime
import os
import time
import matplotlib.pyplot as plt
import torch
import yaml
from torch.utils.tensorboard import SummaryWriter
class Saver(object):
    '''Bookkeeping helper for a training run.

    Keeps a global step counter and wall-clock timers, appends text messages
    to ``<expdir>/log_info.txt``, forwards scalars / figures / audio to a
    TensorBoard ``SummaryWriter`` rooted at ``<expdir>/logs`` and saves or
    deletes ``.pt`` checkpoints inside ``expdir``.
    '''

    def __init__(
            self,
            args,
            initial_global_step=-1):
        '''Prepare the experiment directory, the writer and the config dump.

        Args:
            args: configuration object. ``args.env.expdir`` and
                ``args.data.sampling_rate`` are read, and ``dict(args)`` is
                dumped to ``<expdir>/config.yaml``.
            initial_global_step: starting value of the step counter.
        '''
        self.expdir = args.env.expdir
        self.sample_rate = args.data.sampling_rate

        # cold start: both timers begin at the same instant
        self.global_step = initial_global_step
        now = time.time()
        self.init_time = now
        self.last_time = now

        # log file, checkpoints, config and TensorBoard events all live here
        os.makedirs(self.expdir, exist_ok=True)

        # path of the plain-text log
        self.path_log_info = os.path.join(self.expdir, 'log_info.txt')

        # writer
        self.writer = SummaryWriter(os.path.join(self.expdir, 'logs'))

        # save config
        path_config = os.path.join(self.expdir, 'config.yaml')
        with open(path_config, "w") as out_config:
            yaml.dump(dict(args), out_config)

    @staticmethod
    def _format_entry(key, value):
        '''Render one ``key: value`` line; integers get thousands separators.'''
        # bool is a subclass of int, and '{:,}' would print True as '1'
        if isinstance(value, int) and not isinstance(value, bool):
            return '{}: {:,}'.format(key, value)
        return '{}: {}'.format(key, value)

    def _checkpoint_path(self, name, postfix):
        '''Return ``<expdir>/<name>[_<postfix>].pt``.'''
        suffix = '_' + postfix if postfix else ''
        return os.path.join(self.expdir, name + suffix + '.pt')

    def log_info(self, msg):
        '''Print a message and append it to the text log file.

        Args:
            msg: a dict is rendered as one ``key: value`` pair per line;
                any other object is converted with ``str()``.
        '''
        if isinstance(msg, dict):
            msg_str = '\n'.join(
                self._format_entry(k, v) for k, v in msg.items())
        else:
            # str() so that non-string messages can still be written below
            msg_str = str(msg)

        # display
        print(msg_str)

        # save
        with open(self.path_log_info, 'a') as fp:
            fp.write(msg_str + '\n')

    def log_value(self, dict):
        '''Write every ``tag -> value`` pair as a scalar at the current step.

        The parameter name shadows the builtin ``dict``; it is kept so that
        existing keyword callers continue to work.
        '''
        for tag, value in dict.items():
            self.writer.add_scalar(tag, value, self.global_step)

    def log_spec(self, name, spec, spec_out, vmin=-14, vmax=3.5):
        '''Plot a comparison of two spectrogram tensors and log it as a figure.

        Three panels are concatenated along the last axis: the absolute
        difference shifted by ``vmin``, ``spec`` and ``spec_out``. Only index
        0 of the first dimension is plotted.

        Args:
            name: tag of the figure in the writer.
            spec: reference tensor.
            spec_out: tensor to compare against ``spec``; presumably the
                model output with the same shape -- TODO confirm with caller.
            vmin: lower colour limit, also the offset added to the
                difference panel so that zero error maps to the colour floor.
            vmax: upper colour limit.
        '''
        spec_cat = torch.cat(
            [(spec_out - spec).abs() + vmin, spec, spec_out], -1)
        # torch.cat always yields a Tensor; detach so that inputs which still
        # require grad can be converted to numpy
        image = spec_cat[0].detach().cpu().numpy()

        fig = plt.figure(figsize=(12, 9))
        plt.pcolor(image.T, vmin=vmin, vmax=vmax)
        plt.tight_layout()
        # add_figure closes the figure after rendering by default (close=True)
        self.writer.add_figure(name, fig, self.global_step)

    def log_audio(self, dict):
        '''Write every ``tag -> audio`` pair at the current step.

        Uses the sample rate read from the configuration in ``__init__``.
        The parameter name shadows the builtin ``dict``; it is kept so that
        existing keyword callers continue to work.
        '''
        for tag, audio in dict.items():
            self.writer.add_audio(
                tag,
                audio,
                global_step=self.global_step,
                sample_rate=self.sample_rate)

    def get_interval_time(self, update=True):
        '''Return the seconds elapsed since the last interval mark.

        Args:
            update: when true, move the mark to now so that the next call
                measures from this call.
        '''
        cur_time = time.time()
        time_interval = cur_time - self.last_time
        if update:
            self.last_time = cur_time
        return time_interval

    def get_total_time(self, to_str=True):
        '''Return the time elapsed since this object was created.

        Args:
            to_str: when true, return an ``H:MM:SS.f`` string with one
                fractional digit; otherwise return the seconds as a float.
        '''
        total_time = time.time() - self.init_time
        if not to_str:
            return total_time

        text = str(datetime.timedelta(seconds=total_time))
        # str(timedelta) drops the '.UUUUUU' part when microseconds are zero;
        # slicing blindly would then cut into the seconds field
        if '.' in text:
            return text[:-5]
        return text + '.0'

    def save_model(
            self,
            model,
            optimizer,
            name='model',
            postfix='',
            to_json=False):
        '''Save a checkpoint to ``<expdir>/<name>[_<postfix>].pt``.

        The checkpoint is a dict holding ``global_step``, the model
        ``state_dict`` and, when an optimizer is given, its ``state_dict``.

        Args:
            model: object exposing ``state_dict()``.
            optimizer: object exposing ``state_dict()``, or ``None`` to
                leave the optimizer state out.
            name: base file name.
            postfix: optional suffix appended as ``_<postfix>``.
            to_json: accepted for backward compatibility; currently unused.
        '''
        path_pt = self._checkpoint_path(name, postfix)

        checkpoint = {
            'global_step': self.global_step,
            'model': model.state_dict(),
        }
        if optimizer is not None:
            checkpoint['optimizer'] = optimizer.state_dict()
        torch.save(checkpoint, path_pt)

        # report only after the write succeeded
        print(' [*] model checkpoint saved: {}'.format(path_pt))

    def delete_model(self, name='model', postfix=''):
        '''Delete ``<expdir>/<name>[_<postfix>].pt`` if it exists.'''
        path_pt = self._checkpoint_path(name, postfix)
        if os.path.exists(path_pt):
            os.remove(path_pt)
            print(' [*] model checkpoint deleted: {}'.format(path_pt))

    def global_step_increment(self):
        '''Advance the global step counter by one.'''
        self.global_step += 1
|