MarcusSu1216 committed on
Commit
cf3daeb
1 Parent(s): b80e41a

Update onnx_export.py

Browse files
Files changed (1) hide show
  1. onnx_export.py +50 -10
onnx_export.py CHANGED
@@ -1,9 +1,51 @@
1
  import torch
 
 
2
  from onnxexport.model_onnx import SynthesizerTrn
3
  import utils
4
 
5
- def main(NetExport):
 
 
 
 
 
 
 
 
 
 
 
 
6
  path = "SoVits4.0"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  if NetExport:
8
  device = torch.device("cpu")
9
  hps = utils.get_hparams_from_file(f"checkpoints/{path}/config.json")
@@ -15,17 +57,15 @@ def main(NetExport):
15
  _ = SVCVITS.eval().to(device)
16
  for i in SVCVITS.parameters():
17
  i.requires_grad = False
18
-
19
- n_frame = 10
20
- test_hidden_unit = torch.rand(1, n_frame, 256)
21
- test_pitch = torch.rand(1, n_frame)
22
- test_mel2ph = torch.arange(0, n_frame, dtype=torch.int64)[None] # torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]).unsqueeze(0)
23
- test_uv = torch.ones(1, n_frame, dtype=torch.float32)
24
- test_noise = torch.randn(1, 192, n_frame)
25
  test_sid = torch.LongTensor([0])
26
  input_names = ["c", "f0", "mel2ph", "uv", "noise", "sid"]
27
  output_names = ["audio", ]
28
-
29
  torch.onnx.export(SVCVITS,
30
  (
31
  test_hidden_unit.to(device),
@@ -51,4 +91,4 @@ def main(NetExport):
51
 
52
 
53
  if __name__ == '__main__':
54
- main(True)
 
1
  import torch
2
+ from torchaudio.models.wav2vec2.utils import import_fairseq_model
3
+ from fairseq import checkpoint_utils
4
  from onnxexport.model_onnx import SynthesizerTrn
5
  import utils
6
 
7
+ def get_hubert_model():
8
+ vec_path = "hubert/checkpoint_best_legacy_500.pt"
9
+ print("load model(s) from {}".format(vec_path))
10
+ models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
11
+ [vec_path],
12
+ suffix="",
13
+ )
14
+ model = models[0]
15
+ model.eval()
16
+ return model
17
+
18
+
19
+ def main(HubertExport, NetExport):
20
  path = "SoVits4.0"
21
+
22
+ '''if HubertExport:
23
+ device = torch.device("cpu")
24
+ vec_path = "hubert/checkpoint_best_legacy_500.pt"
25
+ models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
26
+ [vec_path],
27
+ suffix="",
28
+ )
29
+ original = models[0]
30
+ original.eval()
31
+ model = original
32
+ test_input = torch.rand(1, 1, 16000)
33
+ model(test_input)
34
+ torch.onnx.export(model,
35
+ test_input,
36
+ "hubert4.0.onnx",
37
+ export_params=True,
38
+ opset_version=16,
39
+ do_constant_folding=True,
40
+ input_names=['source'],
41
+ output_names=['embed'],
42
+ dynamic_axes={
43
+ 'source':
44
+ {
45
+ 2: "sample_length"
46
+ },
47
+ }
48
+ )'''
49
  if NetExport:
50
  device = torch.device("cpu")
51
  hps = utils.get_hparams_from_file(f"checkpoints/{path}/config.json")
 
57
  _ = SVCVITS.eval().to(device)
58
  for i in SVCVITS.parameters():
59
  i.requires_grad = False
60
+ test_hidden_unit = torch.rand(1, 10, 256)
61
+ test_pitch = torch.rand(1, 10)
62
+ test_mel2ph = torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]).unsqueeze(0)
63
+ test_uv = torch.ones(1, 10, dtype=torch.float32)
64
+ test_noise = torch.randn(1, 192, 10)
 
 
65
  test_sid = torch.LongTensor([0])
66
  input_names = ["c", "f0", "mel2ph", "uv", "noise", "sid"]
67
  output_names = ["audio", ]
68
+ SVCVITS.eval()
69
  torch.onnx.export(SVCVITS,
70
  (
71
  test_hidden_unit.to(device),
 
91
 
92
 
93
  if __name__ == '__main__':
94
+ main(False, True)