tjysdsg committed
Commit 4defacc · 1 Parent(s): 677f9a8

Try to implement s2st

Files changed (4):
  1. .gitignore +3 -0
  2. app.py +58 -2
  3. output.wav +0 -0
  4. s2st_inference.py +121 -0
.gitignore CHANGED
@@ -1,3 +1,6 @@
+*.wav
+model
+vocoder
 data

 # Created by https://www.toptal.com/developers/gitignore/api/linux,windows,macos,jetbrains+all,visualstudiocode,python,jupyternotebooks
app.py CHANGED
@@ -4,10 +4,23 @@ import numpy as np
 import torch
 import torchaudio
 from typing import Tuple, Optional
+import soundfile as sf
+from s2st_inference import s2st_inference

 SAMPLE_RATE = 16000
 MAX_INPUT_LENGTH = 60  # seconds

+S2UT_TAG = 'espnet/jiyang_tang_cvss-c_es-en_discrete_unit'
+S2UT_DIR = 'model'
+VOCODER_TAG = 'espnet/cvss-c_en_wavegan_hubert_vocoder'
+VOCODER_DIR = 'vocoder'
+
+
+def download_model(tag: str, out_dir: str):
+    from huggingface_hub import snapshot_download
+
+    return snapshot_download(repo_id=tag, local_dir=out_dir)
+

 def s2st(
     audio_source: str,
@@ -32,9 +45,52 @@ def s2st(

     wav = wav[0]  # mono

-    # TODO: translate wav
+    # Download models
+    os.makedirs(S2UT_DIR, exist_ok=True)
+    os.makedirs(VOCODER_DIR, exist_ok=True)
+    s2ut_path = download_model(S2UT_TAG, S2UT_DIR)
+    vocoder_path = download_model(VOCODER_TAG, VOCODER_DIR)
+
+    # Temporarily change cwd to the model dir so that it loads correctly
+    cwd = os.getcwd()
+    os.chdir(s2ut_path)
+
+    # Translate wav
+    out_wav = s2st_inference(
+        wav,
+        train_config=os.path.join(
+            s2ut_path,
+            'exp',
+            's2st_train_s2st_discrete_unit_raw_fbank_es_en',
+            'config.yaml',
+        ),
+        model_file=os.path.join(
+            s2ut_path,
+            'exp',
+            's2st_train_s2st_discrete_unit_raw_fbank_es_en',
+            '500epoch.pth',
+        ),
+        vocoder_file=os.path.join(
+            vocoder_path,
+            'checkpoint-400000steps.pkl',
+        ),
+        vocoder_config=os.path.join(
+            vocoder_path,
+            'config.yml',
+        ),
+    )
+
+    # Restore working directory
+    os.chdir(cwd)
+
+    # Save result
     output_path = 'output.wav'
-    torchaudio.save(output_path, wav.unsqueeze(0), SAMPLE_RATE)
+    sf.write(
+        output_path,
+        out_wav,
+        16000,
+        "PCM_16",
+    )

     return output_path, f'Source: {audio_source}'
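Review note: the `os.chdir` workaround in `s2st` is not exception-safe; if `s2st_inference` raises, the working directory is never restored for later requests in the same process. A minimal sketch of a safer variant (the `pushd` helper below is hypothetical, not part of this commit; on Python 3.11+ the stdlib `contextlib.chdir` does the same job):

    import os
    from contextlib import contextmanager

    @contextmanager
    def pushd(path: str):
        # Temporarily switch the working directory; restore it even on error.
        cwd = os.getcwd()
        os.chdir(path)
        try:
            yield
        finally:
            os.chdir(cwd)

    # Inside s2st():
    #     with pushd(s2ut_path):
    #         out_wav = s2st_inference(wav, ...)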
output.wav DELETED
Binary file (136 kB)
 
s2st_inference.py ADDED
@@ -0,0 +1,121 @@
+import argparse
+import logging
+import shutil
+import sys
+import time
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+import numpy as np
+import soundfile as sf
+import torch
+from typeguard import check_argument_types
+from espnet2.torch_utils.set_all_random_seed import set_all_random_seed
+from espnet2.bin.s2st_inference import Speech2Speech
+
+
+def s2st_inference(
+    speech: torch.Tensor,
+    ngpu: int = 0,
+    seed: int = 2023,
+    log_level: Union[int, str] = 'INFO',
+    train_config: Optional[str] = None,
+    model_file: Optional[str] = None,
+    threshold: float = 0.5,
+    minlenratio: float = 0,
+    maxlenratio: float = 10.0,
+    st_subtask_minlenratio: float = 0,
+    st_subtask_maxlenratio: float = 1.5,
+    use_teacher_forcing: bool = False,
+    use_att_constraint: bool = False,
+    backward_window: int = 1,
+    forward_window: int = 3,
+    always_fix_seed: bool = False,
+    beam_size: int = 5,
+    penalty: float = 0,
+    st_subtask_beam_size: int = 5,
+    st_subtask_penalty: float = 0,
+    st_subtask_token_type: Optional[str] = None,
+    st_subtask_bpemodel: Optional[str] = None,
+    vocoder_config: Optional[str] = None,
+    vocoder_file: Optional[str] = None,
+    vocoder_tag: Optional[str] = None,
+):
+    """Run speech-to-speech translation inference."""
+    assert check_argument_types()
+    if ngpu > 1:
+        raise NotImplementedError("only single GPU decoding is supported")
+    logging.basicConfig(
+        level=log_level,
+        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+    )
+
+    if ngpu >= 1:
+        device = "cuda"
+    else:
+        device = "cpu"
+
+    # 1. Set random-seed
+    set_all_random_seed(seed)
+
+    # 2. Build model
+    speech2speech_kwargs = dict(
+        train_config=train_config,
+        model_file=model_file,
+        threshold=threshold,
+        maxlenratio=maxlenratio,
+        minlenratio=minlenratio,
+        st_subtask_maxlenratio=st_subtask_maxlenratio,
+        st_subtask_minlenratio=st_subtask_minlenratio,
+        use_teacher_forcing=use_teacher_forcing,
+        use_att_constraint=use_att_constraint,
+        backward_window=backward_window,
+        forward_window=forward_window,
+        beam_size=beam_size,
+        penalty=penalty,
+        st_subtask_beam_size=st_subtask_beam_size,
+        st_subtask_penalty=st_subtask_penalty,
+        st_subtask_token_type=st_subtask_token_type,
+        st_subtask_bpemodel=st_subtask_bpemodel,
+        vocoder_config=vocoder_config,
+        vocoder_file=vocoder_file,
+        device=device,
+        seed=seed,
+        always_fix_seed=always_fix_seed,
+    )
+    speech2speech = Speech2Speech.from_pretrained(
+        vocoder_tag=vocoder_tag,
+        **speech2speech_kwargs,
+    )
+
+    start_time = time.perf_counter()
+
+    speech_lengths = torch.as_tensor([speech.shape[0]])
+    output_dict = speech2speech(speech.unsqueeze(0), speech_lengths)  # add batch dim
+
+    insize = speech.size(0) + 1
+    # standard speech2mel model case
+    feat_gen = output_dict["feat_gen"]
+    logging.info(
+        f"inference speed = {int(feat_gen.size(0)) / (time.perf_counter() - start_time):.1f} frames / sec."
+    )
+    logging.info(f"(size:{insize}->{feat_gen.size(0)})")
+    if feat_gen.size(0) == insize * maxlenratio:
+        logging.warning("output length reaches maximum length.")
+
+    feat_gen = output_dict["feat_gen"].cpu().numpy()
+    if output_dict.get("feat_gen_denorm") is not None:
+        feat_gen_denorm = output_dict["feat_gen_denorm"].cpu().numpy()
+
+    assert 'wav' in output_dict
+    wav = output_dict["wav"].cpu().numpy()
+    logging.info(f"wav {len(wav)}")
+
+    return wav
+
+    # if output_dict.get("st_subtask_token") is not None:
+    #     writer["token"][key] = " ".join(output_dict["st_subtask_token"])
+    #     writer["token_int"][key] = " ".join(
+    #         map(str, output_dict["st_subtask_token_int"])
+    #     )
+    # if output_dict.get("st_subtask_text") is not None:
+    #     writer["text"][key] = output_dict["st_subtask_text"]
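Review note: a minimal usage sketch of the new `s2st_inference` helper, assuming the S2UT model and vocoder are already downloaded to `model/` and `vocoder/` as in app.py, and that `input_es.wav` is a hypothetical Spanish recording:

    import torchaudio
    from s2st_inference import s2st_inference

    wav, sr = torchaudio.load('input_es.wav')                 # (channels, time)
    wav = torchaudio.functional.resample(wav, sr, 16000)[0]   # mono 1-D tensor at 16 kHz

    out_wav = s2st_inference(
        wav,
        train_config='model/exp/s2st_train_s2st_discrete_unit_raw_fbank_es_en/config.yaml',
        model_file='model/exp/s2st_train_s2st_discrete_unit_raw_fbank_es_en/500epoch.pth',
        vocoder_file='vocoder/checkpoint-400000steps.pkl',
        vocoder_config='vocoder/config.yml',
    )

Note that app.py changes into the model directory before calling this, since the training config may reference files by relative path.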