CosyVoice commited on
Commit
122df8c
·
1 Parent(s): bcda6d8

set onnx to false as last chunk rtf unstable

Browse files
.github/workflows/lint.yml CHANGED
@@ -2,6 +2,7 @@ name: Lint
2
 
3
  on:
4
  pull_request:
 
5
 
6
  jobs:
7
  quick-checks:
 
2
 
3
  on:
4
  pull_request:
5
+ push:
6
 
7
  jobs:
8
  quick-checks:
cosyvoice/cli/cosyvoice.py CHANGED
@@ -23,7 +23,7 @@ from cosyvoice.utils.file_utils import logging
23
 
24
  class CosyVoice:
25
 
26
- def __init__(self, model_dir, load_jit=True, load_onnx=True):
27
  instruct = True if '-Instruct' in model_dir else False
28
  self.model_dir = model_dir
29
  if not os.path.exists(model_dir):
 
23
 
24
  class CosyVoice:
25
 
26
+ def __init__(self, model_dir, load_jit=True, load_onnx=False):
27
  instruct = True if '-Instruct' in model_dir else False
28
  self.model_dir = model_dir
29
  if not os.path.exists(model_dir):
cosyvoice/cli/model.py CHANGED
@@ -43,7 +43,6 @@ class CosyVoiceModel:
43
  self.stream_scale_factor = 1
44
  assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
45
  self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
46
- self.flow_hift_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
47
  self.lock = threading.Lock()
48
  # dict used to store session related variable
49
  self.tts_speech_token_dict = {}
@@ -93,32 +92,31 @@ class CosyVoiceModel:
93
  self.llm_end_dict[uuid] = True
94
 
95
  def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False):
96
- with self.flow_hift_context:
97
- tts_mel = self.flow.inference(token=token.to(self.device),
98
- token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
99
- prompt_token=prompt_token.to(self.device),
100
- prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
101
- prompt_feat=prompt_feat.to(self.device),
102
- prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
103
- embedding=embedding.to(self.device))
104
- # mel overlap fade in out
105
- if self.mel_overlap_dict[uuid] is not None:
106
- tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
107
- # append hift cache
108
- if self.hift_cache_dict[uuid] is not None:
109
- hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
110
- tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
111
- else:
112
- hift_cache_source = torch.zeros(1, 1, 0)
113
- # keep overlap mel and hift cache
114
- if finalize is False:
115
- self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
116
- tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
117
- tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
118
- self.hift_cache_dict[uuid] = {'source': tts_source[:, :, -self.source_cache_len:], 'mel': tts_mel[:, :, -self.mel_cache_len:]}
119
- tts_speech = tts_speech[:, :-self.source_cache_len]
120
- else:
121
- tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
122
  return tts_speech
123
 
124
  def inference(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
@@ -139,13 +137,12 @@ class CosyVoiceModel:
139
  time.sleep(0.1)
140
  if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
141
  this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len], dim=1)
142
- with self.flow_hift_context:
143
- this_tts_speech = self.token2wav(token=this_tts_speech_token,
144
- prompt_token=flow_prompt_speech_token,
145
- prompt_feat=prompt_speech_feat,
146
- embedding=flow_embedding,
147
- uuid=this_uuid,
148
- finalize=False)
149
  yield {'tts_speech': this_tts_speech.cpu()}
150
  with self.lock:
151
  self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
@@ -156,30 +153,26 @@ class CosyVoiceModel:
156
  p.join()
157
  # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
158
  this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
159
- with self.flow_hift_context:
160
- this_tts_speech = self.token2wav(token=this_tts_speech_token,
161
- prompt_token=flow_prompt_speech_token,
162
- prompt_feat=prompt_speech_feat,
163
- embedding=flow_embedding,
164
- uuid=this_uuid,
165
- finalize=True)
166
  yield {'tts_speech': this_tts_speech.cpu()}
167
  else:
168
  # deal with all tokens
169
  p.join()
170
  this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
171
- with self.flow_hift_context:
172
- this_tts_speech = self.token2wav(token=this_tts_speech_token,
173
- prompt_token=flow_prompt_speech_token,
174
- prompt_feat=prompt_speech_feat,
175
- embedding=flow_embedding,
176
- uuid=this_uuid,
177
- finalize=True)
178
  yield {'tts_speech': this_tts_speech.cpu()}
179
  with self.lock:
180
  self.tts_speech_token_dict.pop(this_uuid)
181
  self.llm_end_dict.pop(this_uuid)
182
  self.mel_overlap_dict.pop(this_uuid)
183
  self.hift_cache_dict.pop(this_uuid)
184
- if torch.cuda.is_available():
185
- torch.cuda.synchronize()
 
43
  self.stream_scale_factor = 1
44
  assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
45
  self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
 
46
  self.lock = threading.Lock()
47
  # dict used to store session related variable
48
  self.tts_speech_token_dict = {}
 
92
  self.llm_end_dict[uuid] = True
93
 
94
  def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False):
95
+ tts_mel = self.flow.inference(token=token.to(self.device),
96
+ token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
97
+ prompt_token=prompt_token.to(self.device),
98
+ prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
99
+ prompt_feat=prompt_feat.to(self.device),
100
+ prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
101
+ embedding=embedding.to(self.device))
102
+ # mel overlap fade in out
103
+ if self.mel_overlap_dict[uuid] is not None:
104
+ tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
105
+ # append hift cache
106
+ if self.hift_cache_dict[uuid] is not None:
107
+ hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
108
+ tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
109
+ else:
110
+ hift_cache_source = torch.zeros(1, 1, 0)
111
+ # keep overlap mel and hift cache
112
+ if finalize is False:
113
+ self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
114
+ tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
115
+ tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
116
+ self.hift_cache_dict[uuid] = {'source': tts_source[:, :, -self.source_cache_len:], 'mel': tts_mel[:, :, -self.mel_cache_len:]}
117
+ tts_speech = tts_speech[:, :-self.source_cache_len]
118
+ else:
119
+ tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
 
120
  return tts_speech
121
 
122
  def inference(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
 
137
  time.sleep(0.1)
138
  if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
139
  this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len], dim=1)
140
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
141
+ prompt_token=flow_prompt_speech_token,
142
+ prompt_feat=prompt_speech_feat,
143
+ embedding=flow_embedding,
144
+ uuid=this_uuid,
145
+ finalize=False)
 
146
  yield {'tts_speech': this_tts_speech.cpu()}
147
  with self.lock:
148
  self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
 
153
  p.join()
154
  # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
155
  this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
156
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
157
+ prompt_token=flow_prompt_speech_token,
158
+ prompt_feat=prompt_speech_feat,
159
+ embedding=flow_embedding,
160
+ uuid=this_uuid,
161
+ finalize=True)
 
162
  yield {'tts_speech': this_tts_speech.cpu()}
163
  else:
164
  # deal with all tokens
165
  p.join()
166
  this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
167
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
168
+ prompt_token=flow_prompt_speech_token,
169
+ prompt_feat=prompt_speech_feat,
170
+ embedding=flow_embedding,
171
+ uuid=this_uuid,
172
+ finalize=True)
 
173
  yield {'tts_speech': this_tts_speech.cpu()}
174
  with self.lock:
175
  self.tts_speech_token_dict.pop(this_uuid)
176
  self.llm_end_dict.pop(this_uuid)
177
  self.mel_overlap_dict.pop(this_uuid)
178
  self.hift_cache_dict.pop(this_uuid)