Plachta commited on
Commit
daa5921
·
1 Parent(s): fcb799f

Delete Libtorch C++ Infer

Browse files
Libtorch C++ Infer/VITS-LibTorch.cpp DELETED
@@ -1,121 +0,0 @@
1
- #include <iostream>
2
- #include <torch/torch.h>
3
- #include <torch/script.h>
4
- #include <string>
5
- #include <vector>
6
- #include <locale>
7
- #include <codecvt>
8
- #include <direct.h>
9
- #include <fstream>
10
- typedef int64_t int64;
11
- namespace Shirakana {
12
-
13
- struct WavHead {
14
- char RIFF[4];
15
- long int size0;
16
- char WAVE[4];
17
- char FMT[4];
18
- long int size1;
19
- short int fmttag;
20
- short int channel;
21
- long int samplespersec;
22
- long int bytepersec;
23
- short int blockalign;
24
- short int bitpersamples;
25
- char DATA[4];
26
- long int size2;
27
- };
28
-
29
- int conArr2Wav(int64 size, int16_t* input, const char* filename) {
30
- WavHead head = { {'R','I','F','F'},0,{'W','A','V','E'},{'f','m','t',' '},16,
31
- 1,1,22050,22050 * 2,2,16,{'d','a','t','a'},
32
- 0 };
33
- head.size0 = size * 2 + 36;
34
- head.size2 = size * 2;
35
- std::ofstream ocout;
36
- char* outputData = (char*)input;
37
- ocout.open(filename, std::ios::out | std::ios::binary);
38
- ocout.write((char*)&head, 44);
39
- ocout.write(outputData, (int32_t)(size * 2));
40
- ocout.close();
41
- return 0;
42
- }
43
-
44
- inline std::wstring to_wide_string(const std::string& input)
45
- {
46
- std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
47
- return converter.from_bytes(input);
48
- }
49
-
50
- inline std::string to_byte_string(const std::wstring& input)
51
- {
52
- std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
53
- return converter.to_bytes(input);
54
- }
55
- }
56
-
57
- #define val const auto
58
- int main()
59
- {
60
- torch::jit::Module Vits;
61
- std::string buffer;
62
- std::vector<int64> text;
63
- std::vector<int16_t> data;
64
- while(true)
65
- {
66
- while (true)
67
- {
68
- std::cin >> buffer;
69
- if (buffer == "end")
70
- return 0;
71
- if(buffer == "model")
72
- {
73
- std::cin >> buffer;
74
- Vits = torch::jit::load(buffer);
75
- continue;
76
- }
77
- if (buffer == "endinfer")
78
- {
79
- Shirakana::conArr2Wav(data.size(), data.data(), "temp\\tmp.wav");
80
- data.clear();
81
- std::cout << "endofinfe";
82
- continue;
83
- }
84
- if (buffer == "line")
85
- {
86
- std::cin >> buffer;
87
- while (buffer.find("endline")==std::string::npos)
88
- {
89
- text.push_back(std::atoi(buffer.c_str()));
90
- std::cin >> buffer;
91
- }
92
- val InputTensor = torch::from_blob(text.data(), { 1,static_cast<int64>(text.size()) }, torch::kInt64);
93
- std::array<int64, 1> TextLength{ static_cast<int64>(text.size()) };
94
- val InputTensor_length = torch::from_blob(TextLength.data(), { 1 }, torch::kInt64);
95
- std::vector<torch::IValue> inputs;
96
- inputs.push_back(InputTensor);
97
- inputs.push_back(InputTensor_length);
98
- if (buffer.length() > 7)
99
- {
100
- std::array<int64, 1> speakerIndex{ (int64)atoi(buffer.substr(7).c_str()) };
101
- inputs.push_back(torch::from_blob(speakerIndex.data(), { 1 }, torch::kLong));
102
- }
103
- val output = Vits.forward(inputs).toTuple()->elements()[0].toTensor().multiply(32276.0F);
104
- val outputSize = output.sizes().at(2);
105
- val floatOutput = output.data_ptr<float>();
106
- int16_t* outputTmp = (int16_t*)malloc(sizeof(float) * outputSize);
107
- if (outputTmp == nullptr) {
108
- throw std::exception("内存不足");
109
- }
110
- for (int i = 0; i < outputSize; i++) {
111
- *(outputTmp + i) = (int16_t) * (floatOutput + i);
112
- }
113
- data.insert(data.end(), outputTmp, outputTmp+outputSize);
114
- free(outputTmp);
115
- text.clear();
116
- std::cout << "endofline";
117
- }
118
- }
119
- }
120
- //model S:\VSGIT\ShirakanaTTSUI\build\x64\Release\Mods\AtriVITS\AtriVITS_LJS.pt
121
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
Libtorch C++ Infer/toLibTorch.ipynb DELETED
@@ -1,142 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": null,
6
- "metadata": {},
7
- "outputs": [],
8
- "source": [
9
- "%matplotlib inline\n",
10
- "import matplotlib.pyplot as plt\n",
11
- "import IPython.display as ipd\n",
12
- "\n",
13
- "import os\n",
14
- "import json\n",
15
- "import math\n",
16
- "import torch\n",
17
- "from torch import nn\n",
18
- "from torch.nn import functional as F\n",
19
- "from torch.utils.data import DataLoader\n",
20
- "\n",
21
- "import ../commons\n",
22
- "import ../utils\n",
23
- "from ../data_utils import TextAudioLoader, TextAudioCollate, TextAudioSpeakerLoader, TextAudioSpeakerCollate\n",
24
- "from ../models import SynthesizerTrn\n",
25
- "from ../text.symbols import symbols\n",
26
- "from ../text import text_to_sequence\n",
27
- "\n",
28
- "from scipy.io.wavfile import write\n",
29
- "\n",
30
- "\n",
31
- "def get_text(text, hps):\n",
32
- " text_norm = text_to_sequence(text, hps.data.text_cleaners)\n",
33
- " if hps.data.add_blank:\n",
34
- " text_norm = commons.intersperse(text_norm, 0)\n",
35
- " text_norm = torch.LongTensor(text_norm)\n",
36
- " return text_norm"
37
- ]
38
- },
39
- {
40
- "cell_type": "code",
41
- "execution_count": null,
42
- "metadata": {},
43
- "outputs": [],
44
- "source": [
45
- "#############################################################\n",
46
- "# #\n",
47
- "# Single Speakers #\n",
48
- "# #\n",
49
- "#############################################################"
50
- ]
51
- },
52
- {
53
- "cell_type": "code",
54
- "execution_count": null,
55
- "metadata": {},
56
- "outputs": [],
57
- "source": [
58
- "hps = utils.get_hparams_from_file(\"configs/XXX.json\") #将\"\"内的内容修改为你的模型路径与config路径\n",
59
- "net_g = SynthesizerTrn(\n",
60
- " len(symbols),\n",
61
- " hps.data.filter_length // 2 + 1,\n",
62
- " hps.train.segment_size // hps.data.hop_length,\n",
63
- " **hps.model).cuda()\n",
64
- "_ = net_g.eval()\n",
65
- "\n",
66
- "_ = utils.load_checkpoint(\"/path/to/model.pth\", net_g, None)"
67
- ]
68
- },
69
- {
70
- "cell_type": "code",
71
- "execution_count": null,
72
- "metadata": {},
73
- "outputs": [],
74
- "source": [
75
- "stn_tst = get_text(\"こんにちは\", hps)\n",
76
- "with torch.no_grad():\n",
77
- " x_tst = stn_tst.cuda().unsqueeze(0)\n",
78
- " x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).cuda()\n",
79
- " traced_mod = torch.jit.trace(net_g,(x_tst, x_tst_lengths,sid))\n",
80
- " torch.jit.save(traced_mod,\"OUTPUTLIBTORCHMODEL.pt\")\n",
81
- " audio = net_g.infer(x_tst, x_tst_lengths, noise_scale=.667, noise_scale_w=0.8, length_scale=1)[0][0,0].data.cpu().float().numpy()\n",
82
- "ipd.display(ipd.Audio(audio, rate=hps.data.sampling_rate, normalize=False))"
83
- ]
84
- },
85
- {
86
- "cell_type": "code",
87
- "execution_count": null,
88
- "metadata": {},
89
- "outputs": [],
90
- "source": [
91
- "#############################################################\n",
92
- "# #\n",
93
- "# Multiple Speakers #\n",
94
- "# #\n",
95
- "#############################################################"
96
- ]
97
- },
98
- {
99
- "cell_type": "code",
100
- "execution_count": null,
101
- "metadata": {},
102
- "outputs": [],
103
- "source": [
104
- "hps = utils.get_hparams_from_file(\"./configs/XXX.json\") #将\"\"内的内容修改为你的模型路径与config路径\n",
105
- "net_g = SynthesizerTrn(\n",
106
- " len(symbols),\n",
107
- " hps.data.filter_length // 2 + 1,\n",
108
- " hps.train.segment_size // hps.data.hop_length,\n",
109
- " n_speakers=hps.data.n_speakers,\n",
110
- " **hps.model).cuda()\n",
111
- "_ = net_g.eval()\n",
112
- "\n",
113
- "_ = utils.load_checkpoint(\"/path/to/model.pth\", net_g, None)"
114
- ]
115
- },
116
- {
117
- "cell_type": "code",
118
- "execution_count": null,
119
- "metadata": {},
120
- "outputs": [],
121
- "source": [
122
- "stn_tst = get_text(\"こんにちは\", hps)\n",
123
- "with torch.no_grad():\n",
124
- " x_tst = stn_tst.cuda().unsqueeze(0)\n",
125
- " x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).cuda()\n",
126
- " sid = torch.LongTensor([4]).cuda()\n",
127
- " traced_mod = torch.jit.trace(net_g,(x_tst, x_tst_lengths,sid))\n",
128
- " torch.jit.save(traced_mod,\"OUTPUTLIBTORCHMODEL.pt\")\n",
129
- " audio = net_g.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=.667, noise_scale_w=0.8, length_scale=1)[0][0,0].data.cpu().float().numpy()\n",
130
- "ipd.display(ipd.Audio(audio, rate=hps.data.sampling_rate, normalize=False))"
131
- ]
132
- }
133
- ],
134
- "metadata": {
135
- "language_info": {
136
- "name": "python"
137
- },
138
- "orig_nbformat": 4
139
- },
140
- "nbformat": 4,
141
- "nbformat_minor": 2
142
- }