smjain committed on
Commit
ef5c9b8
1 Parent(s): f8726c0

Upload 5 files

Browse files
Files changed (5) hide show
  1. lib/data_utils.py +7 -12
  2. lib/losses.py +1 -0
  3. lib/mel_processing.py +4 -6
  4. lib/process_ckpt.py +113 -126
  5. lib/utils.py +33 -40
lib/data_utils.py CHANGED
@@ -1,15 +1,10 @@
1
- import os
2
- import traceback
3
- import logging
4
-
5
- logger = logging.getLogger(__name__)
6
-
7
  import numpy as np
8
  import torch
9
  import torch.utils.data
10
 
11
- from infer.lib.train.mel_processing import spectrogram_torch
12
- from infer.lib.train.utils import load_filepaths_and_text, load_wav_to_torch
13
 
14
 
15
  class TextAudioLoaderMultiNSFsid(torch.utils.data.Dataset):
@@ -43,7 +38,7 @@ class TextAudioLoaderMultiNSFsid(torch.utils.data.Dataset):
43
  for audiopath, text, pitch, pitchf, dv in self.audiopaths_and_text:
44
  if self.min_text_len <= len(text) and len(text) <= self.max_text_len:
45
  audiopaths_and_text_new.append([audiopath, text, pitch, pitchf, dv])
46
- lengths.append(os.path.getsize(audiopath) // (3 * self.hop_length))
47
  self.audiopaths_and_text = audiopaths_and_text_new
48
  self.lengths = lengths
49
 
@@ -113,7 +108,7 @@ class TextAudioLoaderMultiNSFsid(torch.utils.data.Dataset):
113
  try:
114
  spec = torch.load(spec_filename)
115
  except:
116
- logger.warning("%s %s", spec_filename, traceback.format_exc())
117
  spec = spectrogram_torch(
118
  audio_norm,
119
  self.filter_length,
@@ -251,7 +246,7 @@ class TextAudioLoader(torch.utils.data.Dataset):
251
  for audiopath, text, dv in self.audiopaths_and_text:
252
  if self.min_text_len <= len(text) and len(text) <= self.max_text_len:
253
  audiopaths_and_text_new.append([audiopath, text, dv])
254
- lengths.append(os.path.getsize(audiopath) // (3 * self.hop_length))
255
  self.audiopaths_and_text = audiopaths_and_text_new
256
  self.lengths = lengths
257
 
@@ -305,7 +300,7 @@ class TextAudioLoader(torch.utils.data.Dataset):
305
  try:
306
  spec = torch.load(spec_filename)
307
  except:
308
- logger.warning("%s %s", spec_filename, traceback.format_exc())
309
  spec = spectrogram_torch(
310
  audio_norm,
311
  self.filter_length,
 
1
+ import os, traceback
 
 
 
 
 
2
  import numpy as np
3
  import torch
4
  import torch.utils.data
5
 
6
+ from mel_processing import spectrogram_torch
7
+ from utils import load_wav_to_torch, load_filepaths_and_text
8
 
9
 
10
  class TextAudioLoaderMultiNSFsid(torch.utils.data.Dataset):
 
38
  for audiopath, text, pitch, pitchf, dv in self.audiopaths_and_text:
39
  if self.min_text_len <= len(text) and len(text) <= self.max_text_len:
40
  audiopaths_and_text_new.append([audiopath, text, pitch, pitchf, dv])
41
+ lengths.append(os.path.getsize(audiopath) // (2 * self.hop_length))
42
  self.audiopaths_and_text = audiopaths_and_text_new
43
  self.lengths = lengths
44
 
 
108
  try:
109
  spec = torch.load(spec_filename)
110
  except:
111
+ print(spec_filename, traceback.format_exc())
112
  spec = spectrogram_torch(
113
  audio_norm,
114
  self.filter_length,
 
246
  for audiopath, text, dv in self.audiopaths_and_text:
247
  if self.min_text_len <= len(text) and len(text) <= self.max_text_len:
248
  audiopaths_and_text_new.append([audiopath, text, dv])
249
+ lengths.append(os.path.getsize(audiopath) // (2 * self.hop_length))
250
  self.audiopaths_and_text = audiopaths_and_text_new
251
  self.lengths = lengths
252
 
 
300
  try:
301
  spec = torch.load(spec_filename)
302
  except:
303
+ print(spec_filename, traceback.format_exc())
304
  spec = spectrogram_torch(
305
  audio_norm,
306
  self.filter_length,
lib/losses.py CHANGED
@@ -1,4 +1,5 @@
1
  import torch
 
2
 
3
 
4
  def feature_loss(fmap_r, fmap_g):
 
1
  import torch
2
+ from torch.nn import functional as F
3
 
4
 
5
  def feature_loss(fmap_r, fmap_g):
lib/mel_processing.py CHANGED
@@ -1,9 +1,7 @@
1
  import torch
2
  import torch.utils.data
3
  from librosa.filters import mel as librosa_mel_fn
4
- import logging
5
 
6
- logger = logging.getLogger(__name__)
7
 
8
  MAX_WAV_VALUE = 32768.0
9
 
@@ -53,10 +51,10 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False)
53
  :: (B, Freq, Frame) - Linear-frequency Linear-amplitude spectrogram
54
  """
55
  # Validation
56
- if torch.min(y) < -1.07:
57
- logger.debug("min value is %s", str(torch.min(y)))
58
- if torch.max(y) > 1.07:
59
- logger.debug("max value is %s", str(torch.max(y)))
60
 
61
  # Window - Cache if needed
62
  global hann_window
 
1
  import torch
2
  import torch.utils.data
3
  from librosa.filters import mel as librosa_mel_fn
 
4
 
 
5
 
6
  MAX_WAV_VALUE = 32768.0
7
 
 
51
  :: (B, Freq, Frame) - Linear-frequency Linear-amplitude spectrogram
52
  """
53
  # Validation
54
+ if torch.min(y) < -1.0:
55
+ print("min value is ", torch.min(y))
56
+ if torch.max(y) > 1.0:
57
+ print("max value is ", torch.max(y))
58
 
59
  # Window - Cache if needed
60
  global hann_window
lib/process_ckpt.py CHANGED
@@ -1,16 +1,8 @@
1
- import os
2
- import sys
3
- import traceback
4
  from collections import OrderedDict
5
 
6
- import torch
7
 
8
- from i18n.i18n import I18nAuto
9
-
10
- i18n = I18nAuto()
11
-
12
-
13
- def savee(ckpt, sr, if_f0, name, epoch, version, hps):
14
  try:
15
  opt = OrderedDict()
16
  opt["weight"] = {}
@@ -18,31 +10,73 @@ def savee(ckpt, sr, if_f0, name, epoch, version, hps):
18
  if "enc_q" in key:
19
  continue
20
  opt["weight"][key] = ckpt[key].half()
21
- opt["config"] = [
22
- hps.data.filter_length // 2 + 1,
23
- 32,
24
- hps.model.inter_channels,
25
- hps.model.hidden_channels,
26
- hps.model.filter_channels,
27
- hps.model.n_heads,
28
- hps.model.n_layers,
29
- hps.model.kernel_size,
30
- hps.model.p_dropout,
31
- hps.model.resblock,
32
- hps.model.resblock_kernel_sizes,
33
- hps.model.resblock_dilation_sizes,
34
- hps.model.upsample_rates,
35
- hps.model.upsample_initial_channel,
36
- hps.model.upsample_kernel_sizes,
37
- hps.model.spk_embed_dim,
38
- hps.model.gin_channels,
39
- hps.data.sampling_rate,
40
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  opt["info"] = "%sepoch" % epoch
42
  opt["sr"] = sr
43
  opt["f0"] = if_f0
44
- opt["version"] = version
45
- torch.save(opt, "assets/weights/%s.pth" % name)
46
  return "Success."
47
  except:
48
  return traceback.format_exc()
@@ -51,17 +85,16 @@ def savee(ckpt, sr, if_f0, name, epoch, version, hps):
51
  def show_info(path):
52
  try:
53
  a = torch.load(path, map_location="cpu")
54
- return "模型信息:%s\n采样率:%s\n模型是否输入音高引导:%s\n版本:%s" % (
55
  a.get("info", "None"),
56
  a.get("sr", "None"),
57
  a.get("f0", "None"),
58
- a.get("version", "None"),
59
  )
60
  except:
61
  return traceback.format_exc()
62
 
63
 
64
- def extract_small_model(path, name, sr, if_f0, info, version):
65
  try:
66
  ckpt = torch.load(path, map_location="cpu")
67
  if "model" in ckpt:
@@ -94,98 +127,53 @@ def extract_small_model(path, name, sr, if_f0, info, version):
94
  40000,
95
  ]
96
  elif sr == "48k":
97
- if version == "v1":
98
- opt["config"] = [
99
- 1025,
100
- 32,
101
- 192,
102
- 192,
103
- 768,
104
- 2,
105
- 6,
106
- 3,
107
- 0,
108
- "1",
109
- [3, 7, 11],
110
- [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
111
- [10, 6, 2, 2, 2],
112
- 512,
113
- [16, 16, 4, 4, 4],
114
- 109,
115
- 256,
116
- 48000,
117
- ]
118
- else:
119
- opt["config"] = [
120
- 1025,
121
- 32,
122
- 192,
123
- 192,
124
- 768,
125
- 2,
126
- 6,
127
- 3,
128
- 0,
129
- "1",
130
- [3, 7, 11],
131
- [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
132
- [12, 10, 2, 2],
133
- 512,
134
- [24, 20, 4, 4],
135
- 109,
136
- 256,
137
- 48000,
138
- ]
139
  elif sr == "32k":
140
- if version == "v1":
141
- opt["config"] = [
142
- 513,
143
- 32,
144
- 192,
145
- 192,
146
- 768,
147
- 2,
148
- 6,
149
- 3,
150
- 0,
151
- "1",
152
- [3, 7, 11],
153
- [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
154
- [10, 4, 2, 2, 2],
155
- 512,
156
- [16, 16, 4, 4, 4],
157
- 109,
158
- 256,
159
- 32000,
160
- ]
161
- else:
162
- opt["config"] = [
163
- 513,
164
- 32,
165
- 192,
166
- 192,
167
- 768,
168
- 2,
169
- 6,
170
- 3,
171
- 0,
172
- "1",
173
- [3, 7, 11],
174
- [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
175
- [10, 8, 2, 2],
176
- 512,
177
- [20, 16, 4, 4],
178
- 109,
179
- 256,
180
- 32000,
181
- ]
182
  if info == "":
183
  info = "Extracted model."
184
  opt["info"] = info
185
- opt["version"] = version
186
  opt["sr"] = sr
187
  opt["f0"] = int(if_f0)
188
- torch.save(opt, "assets/weights/%s.pth" % name)
189
  return "Success."
190
  except:
191
  return traceback.format_exc()
@@ -197,13 +185,13 @@ def change_info(path, info, name):
197
  ckpt["info"] = info
198
  if name == "":
199
  name = os.path.basename(path)
200
- torch.save(ckpt, "assets/weights/%s" % name)
201
  return "Success."
202
  except:
203
  return traceback.format_exc()
204
 
205
 
206
- def merge(path1, path2, alpha1, sr, f0, info, name, version):
207
  try:
208
 
209
  def extract(ckpt):
@@ -252,10 +240,9 @@ def merge(path1, path2, alpha1, sr, f0, info, name, version):
252
  elif(sr=="32k"):opt["config"] = [513, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10, 4, 2, 2, 2], 512, [16, 16, 4, 4,4], 109, 256, 32000]
253
  """
254
  opt["sr"] = sr
255
- opt["f0"] = 1 if f0 == i18n("是") else 0
256
- opt["version"] = version
257
  opt["info"] = info
258
- torch.save(opt, "assets/weights/%s.pth" % name)
259
  return "Success."
260
  except:
261
  return traceback.format_exc()
 
1
+ import torch, traceback, os, pdb
 
 
2
  from collections import OrderedDict
3
 
 
4
 
5
+ def savee(ckpt, sr, if_f0, name, epoch):
 
 
 
 
 
6
  try:
7
  opt = OrderedDict()
8
  opt["weight"] = {}
 
10
  if "enc_q" in key:
11
  continue
12
  opt["weight"][key] = ckpt[key].half()
13
+ if sr == "40k":
14
+ opt["config"] = [
15
+ 1025,
16
+ 32,
17
+ 192,
18
+ 192,
19
+ 768,
20
+ 2,
21
+ 6,
22
+ 3,
23
+ 0,
24
+ "1",
25
+ [3, 7, 11],
26
+ [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
27
+ [10, 10, 2, 2],
28
+ 512,
29
+ [16, 16, 4, 4],
30
+ 109,
31
+ 256,
32
+ 40000,
33
+ ]
34
+ elif sr == "48k":
35
+ opt["config"] = [
36
+ 1025,
37
+ 32,
38
+ 192,
39
+ 192,
40
+ 768,
41
+ 2,
42
+ 6,
43
+ 3,
44
+ 0,
45
+ "1",
46
+ [3, 7, 11],
47
+ [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
48
+ [10, 6, 2, 2, 2],
49
+ 512,
50
+ [16, 16, 4, 4, 4],
51
+ 109,
52
+ 256,
53
+ 48000,
54
+ ]
55
+ elif sr == "32k":
56
+ opt["config"] = [
57
+ 513,
58
+ 32,
59
+ 192,
60
+ 192,
61
+ 768,
62
+ 2,
63
+ 6,
64
+ 3,
65
+ 0,
66
+ "1",
67
+ [3, 7, 11],
68
+ [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
69
+ [10, 4, 2, 2, 2],
70
+ 512,
71
+ [16, 16, 4, 4, 4],
72
+ 109,
73
+ 256,
74
+ 32000,
75
+ ]
76
  opt["info"] = "%sepoch" % epoch
77
  opt["sr"] = sr
78
  opt["f0"] = if_f0
79
+ torch.save(opt, "weights/%s.pth" % name)
 
80
  return "Success."
81
  except:
82
  return traceback.format_exc()
 
85
  def show_info(path):
86
  try:
87
  a = torch.load(path, map_location="cpu")
88
+ return "模型信息:%s\n采样率:%s\n模型是否输入音高引导:%s" % (
89
  a.get("info", "None"),
90
  a.get("sr", "None"),
91
  a.get("f0", "None"),
 
92
  )
93
  except:
94
  return traceback.format_exc()
95
 
96
 
97
+ def extract_small_model(path, name, sr, if_f0, info):
98
  try:
99
  ckpt = torch.load(path, map_location="cpu")
100
  if "model" in ckpt:
 
127
  40000,
128
  ]
129
  elif sr == "48k":
130
+ opt["config"] = [
131
+ 1025,
132
+ 32,
133
+ 192,
134
+ 192,
135
+ 768,
136
+ 2,
137
+ 6,
138
+ 3,
139
+ 0,
140
+ "1",
141
+ [3, 7, 11],
142
+ [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
143
+ [10, 6, 2, 2, 2],
144
+ 512,
145
+ [16, 16, 4, 4, 4],
146
+ 109,
147
+ 256,
148
+ 48000,
149
+ ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  elif sr == "32k":
151
+ opt["config"] = [
152
+ 513,
153
+ 32,
154
+ 192,
155
+ 192,
156
+ 768,
157
+ 2,
158
+ 6,
159
+ 3,
160
+ 0,
161
+ "1",
162
+ [3, 7, 11],
163
+ [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
164
+ [10, 4, 2, 2, 2],
165
+ 512,
166
+ [16, 16, 4, 4, 4],
167
+ 109,
168
+ 256,
169
+ 32000,
170
+ ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  if info == "":
172
  info = "Extracted model."
173
  opt["info"] = info
 
174
  opt["sr"] = sr
175
  opt["f0"] = int(if_f0)
176
+ torch.save(opt, "weights/%s.pth" % name)
177
  return "Success."
178
  except:
179
  return traceback.format_exc()
 
185
  ckpt["info"] = info
186
  if name == "":
187
  name = os.path.basename(path)
188
+ torch.save(ckpt, "weights/%s" % name)
189
  return "Success."
190
  except:
191
  return traceback.format_exc()
192
 
193
 
194
+ def merge(path1, path2, alpha1, sr, f0, info, name):
195
  try:
196
 
197
  def extract(ckpt):
 
240
  elif(sr=="32k"):opt["config"] = [513, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10, 4, 2, 2, 2], 512, [16, 16, 4, 4,4], 109, 256, 32000]
241
  """
242
  opt["sr"] = sr
243
+ opt["f0"] = 1 if f0 == "是" else 0
 
244
  opt["info"] = info
245
+ torch.save(opt, "weights/%s.pth" % name)
246
  return "Success."
247
  except:
248
  return traceback.format_exc()
lib/utils.py CHANGED
@@ -1,15 +1,13 @@
1
- import argparse
2
  import glob
3
- import json
 
4
  import logging
5
- import os
6
  import subprocess
7
- import sys
8
- import shutil
9
-
10
  import numpy as np
11
- import torch
12
  from scipy.io.wavfile import read
 
13
 
14
  MATPLOTLIB_FLAG = False
15
 
@@ -33,25 +31,22 @@ def load_checkpoint_d(checkpoint_path, combd, sbd, optimizer=None, load_opt=1):
33
  try:
34
  new_state_dict[k] = saved_state_dict[k]
35
  if saved_state_dict[k].shape != state_dict[k].shape:
36
- logger.warning(
37
- "shape-%s-mismatch. need: %s, get: %s",
38
- k,
39
- state_dict[k].shape,
40
- saved_state_dict[k].shape,
41
  ) #
42
  raise KeyError
43
  except:
44
  # logger.info(traceback.format_exc())
45
- logger.info("%s is not in the checkpoint", k) # pretrain缺失的
46
  new_state_dict[k] = v # 模型自带的随机值
47
  if hasattr(model, "module"):
48
  model.module.load_state_dict(new_state_dict, strict=False)
49
  else:
50
  model.load_state_dict(new_state_dict, strict=False)
51
- return model
52
 
53
  go(combd, "combd")
54
- model = go(sbd, "sbd")
55
  #############
56
  logger.info("Loaded model weights")
57
 
@@ -111,16 +106,14 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
111
  try:
112
  new_state_dict[k] = saved_state_dict[k]
113
  if saved_state_dict[k].shape != state_dict[k].shape:
114
- logger.warning(
115
- "shape-%s-mismatch|need-%s|get-%s",
116
- k,
117
- state_dict[k].shape,
118
- saved_state_dict[k].shape,
119
  ) #
120
  raise KeyError
121
  except:
122
  # logger.info(traceback.format_exc())
123
- logger.info("%s is not in the checkpoint", k) # pretrain缺失的
124
  new_state_dict[k] = v # 模型自带的随机值
125
  if hasattr(model, "module"):
126
  model.module.load_state_dict(new_state_dict, strict=False)
@@ -211,7 +204,7 @@ def latest_checkpoint_path(dir_path, regex="G_*.pth"):
211
  f_list = glob.glob(os.path.join(dir_path, regex))
212
  f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
213
  x = f_list[-1]
214
- logger.debug(x)
215
  return x
216
 
217
 
@@ -291,8 +284,8 @@ def get_hparams(init=True):
291
  bs done
292
  pretrainG、pretrainD done
293
  卡号:os.en["CUDA_VISIBLE_DEVICES"] done
294
- if_latest done
295
- 模型:if_f0 done
296
  采样率:自动选择config done
297
  是否缓存数据集进GPU:if_cache_data_in_gpu done
298
 
@@ -301,6 +294,7 @@ def get_hparams(init=True):
301
  -c不要了
302
  """
303
  parser = argparse.ArgumentParser()
 
304
  parser.add_argument(
305
  "-se",
306
  "--save_every_epoch",
@@ -327,16 +321,6 @@ def get_hparams(init=True):
327
  parser.add_argument(
328
  "-sr", "--sample_rate", type=str, required=True, help="sample rate, 32k/40k/48k"
329
  )
330
- parser.add_argument(
331
- "-sw",
332
- "--save_every_weights",
333
- type=str,
334
- default="0",
335
- help="save the extracted model in weights directory when saving checkpoints",
336
- )
337
- parser.add_argument(
338
- "-v", "--version", type=str, required=True, help="model version"
339
- )
340
  parser.add_argument(
341
  "-f0",
342
  "--if_f0",
@@ -363,9 +347,20 @@ def get_hparams(init=True):
363
  name = args.experiment_dir
364
  experiment_dir = os.path.join("./logs", args.experiment_dir)
365
 
 
 
 
 
366
  config_save_path = os.path.join(experiment_dir, "config.json")
367
- with open(config_save_path, "r") as f:
368
- config = json.load(f)
 
 
 
 
 
 
 
369
 
370
  hparams = HParams(**config)
371
  hparams.model_dir = hparams.experiment_dir = experiment_dir
@@ -374,13 +369,11 @@ def get_hparams(init=True):
374
  hparams.total_epoch = args.total_epoch
375
  hparams.pretrainG = args.pretrainG
376
  hparams.pretrainD = args.pretrainD
377
- hparams.version = args.version
378
  hparams.gpus = args.gpus
379
  hparams.train.batch_size = args.batch_size
380
  hparams.sample_rate = args.sample_rate
381
  hparams.if_f0 = args.if_f0
382
  hparams.if_latest = args.if_latest
383
- hparams.save_every_weights = args.save_every_weights
384
  hparams.if_cache_data_in_gpu = args.if_cache_data_in_gpu
385
  hparams.data.training_files = "%s/filelist.txt" % experiment_dir
386
  return hparams
@@ -409,7 +402,7 @@ def get_hparams_from_file(config_path):
409
  def check_git_hash(model_dir):
410
  source_dir = os.path.dirname(os.path.realpath(__file__))
411
  if not os.path.exists(os.path.join(source_dir, ".git")):
412
- logger.warning(
413
  "{} is not a git repository, therefore hash value comparison will be ignored.".format(
414
  source_dir
415
  )
@@ -422,7 +415,7 @@ def check_git_hash(model_dir):
422
  if os.path.exists(path):
423
  saved_hash = open(path).read()
424
  if saved_hash != cur_hash:
425
- logger.warning(
426
  "git hash values are different. {}(saved) != {}(current)".format(
427
  saved_hash[:8], cur_hash[:8]
428
  )
 
1
+ import os, traceback
2
  import glob
3
+ import sys
4
+ import argparse
5
  import logging
6
+ import json
7
  import subprocess
 
 
 
8
  import numpy as np
 
9
  from scipy.io.wavfile import read
10
+ import torch
11
 
12
  MATPLOTLIB_FLAG = False
13
 
 
31
  try:
32
  new_state_dict[k] = saved_state_dict[k]
33
  if saved_state_dict[k].shape != state_dict[k].shape:
34
+ print(
35
+ "shape-%s-mismatch|need-%s|get-%s"
36
+ % (k, state_dict[k].shape, saved_state_dict[k].shape)
 
 
37
  ) #
38
  raise KeyError
39
  except:
40
  # logger.info(traceback.format_exc())
41
+ logger.info("%s is not in the checkpoint" % k) # pretrain缺失的
42
  new_state_dict[k] = v # 模型自带的随机值
43
  if hasattr(model, "module"):
44
  model.module.load_state_dict(new_state_dict, strict=False)
45
  else:
46
  model.load_state_dict(new_state_dict, strict=False)
 
47
 
48
  go(combd, "combd")
49
+ go(sbd, "sbd")
50
  #############
51
  logger.info("Loaded model weights")
52
 
 
106
  try:
107
  new_state_dict[k] = saved_state_dict[k]
108
  if saved_state_dict[k].shape != state_dict[k].shape:
109
+ print(
110
+ "shape-%s-mismatch|need-%s|get-%s"
111
+ % (k, state_dict[k].shape, saved_state_dict[k].shape)
 
 
112
  ) #
113
  raise KeyError
114
  except:
115
  # logger.info(traceback.format_exc())
116
+ logger.info("%s is not in the checkpoint" % k) # pretrain缺失的
117
  new_state_dict[k] = v # 模型自带的随机值
118
  if hasattr(model, "module"):
119
  model.module.load_state_dict(new_state_dict, strict=False)
 
204
  f_list = glob.glob(os.path.join(dir_path, regex))
205
  f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
206
  x = f_list[-1]
207
+ print(x)
208
  return x
209
 
210
 
 
284
  bs done
285
  pretrainG、pretrainD done
286
  卡号:os.en["CUDA_VISIBLE_DEVICES"] done
287
+ if_latest todo
288
+ 模型:if_f0 todo
289
  采样率:自动选择config done
290
  是否缓存数据集进GPU:if_cache_data_in_gpu done
291
 
 
294
  -c不要了
295
  """
296
  parser = argparse.ArgumentParser()
297
+ # parser.add_argument('-c', '--config', type=str, default="configs/40k.json",help='JSON file for configuration')
298
  parser.add_argument(
299
  "-se",
300
  "--save_every_epoch",
 
321
  parser.add_argument(
322
  "-sr", "--sample_rate", type=str, required=True, help="sample rate, 32k/40k/48k"
323
  )
 
 
 
 
 
 
 
 
 
 
324
  parser.add_argument(
325
  "-f0",
326
  "--if_f0",
 
347
  name = args.experiment_dir
348
  experiment_dir = os.path.join("./logs", args.experiment_dir)
349
 
350
+ if not os.path.exists(experiment_dir):
351
+ os.makedirs(experiment_dir)
352
+
353
+ config_path = "configs/%s.json" % args.sample_rate
354
  config_save_path = os.path.join(experiment_dir, "config.json")
355
+ if init:
356
+ with open(config_path, "r") as f:
357
+ data = f.read()
358
+ with open(config_save_path, "w") as f:
359
+ f.write(data)
360
+ else:
361
+ with open(config_save_path, "r") as f:
362
+ data = f.read()
363
+ config = json.loads(data)
364
 
365
  hparams = HParams(**config)
366
  hparams.model_dir = hparams.experiment_dir = experiment_dir
 
369
  hparams.total_epoch = args.total_epoch
370
  hparams.pretrainG = args.pretrainG
371
  hparams.pretrainD = args.pretrainD
 
372
  hparams.gpus = args.gpus
373
  hparams.train.batch_size = args.batch_size
374
  hparams.sample_rate = args.sample_rate
375
  hparams.if_f0 = args.if_f0
376
  hparams.if_latest = args.if_latest
 
377
  hparams.if_cache_data_in_gpu = args.if_cache_data_in_gpu
378
  hparams.data.training_files = "%s/filelist.txt" % experiment_dir
379
  return hparams
 
402
  def check_git_hash(model_dir):
403
  source_dir = os.path.dirname(os.path.realpath(__file__))
404
  if not os.path.exists(os.path.join(source_dir, ".git")):
405
+ logger.warn(
406
  "{} is not a git repository, therefore hash value comparison will be ignored.".format(
407
  source_dir
408
  )
 
415
  if os.path.exists(path):
416
  saved_hash = open(path).read()
417
  if saved_hash != cur_hash:
418
+ logger.warn(
419
  "git hash values are different. {}(saved) != {}(current)".format(
420
  saved_hash[:8], cur_hash[:8]
421
  )