CosyVoice committed on
Commit
ffa28e3
·
1 Parent(s): 8555ab4

update token args

Browse files
Files changed (2) hide show
  1. cosyvoice/cli/model.py +3 -6
  2. cosyvoice/llm/llm.py +1 -1
cosyvoice/cli/model.py CHANGED
@@ -31,8 +31,8 @@ class CosyVoiceModel:
31
  self.llm = llm
32
  self.flow = flow
33
  self.hift = hift
34
- self.token_min_hop_len = 100
35
- self.token_max_hop_len = 200
36
  self.token_overlap_len = 20
37
  # mel fade in out
38
  self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
@@ -87,10 +87,7 @@ class CosyVoiceModel:
87
  prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
88
  prompt_speech_token=llm_prompt_speech_token.to(self.device),
89
  prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
90
- embedding=llm_embedding.to(self.device).half(),
91
- sampling=25,
92
- max_token_text_ratio=30,
93
- min_token_text_ratio=3):
94
  self.tts_speech_token_dict[uuid].append(i)
95
  self.llm_end_dict[uuid] = True
96
 
 
31
  self.llm = llm
32
  self.flow = flow
33
  self.hift = hift
34
+ self.token_min_hop_len = 2 * self.flow.input_frame_rate
35
+ self.token_max_hop_len = 4 * self.flow.input_frame_rate
36
  self.token_overlap_len = 20
37
  # mel fade in out
38
  self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
 
87
  prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
88
  prompt_speech_token=llm_prompt_speech_token.to(self.device),
89
  prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
90
+ embedding=llm_embedding.to(self.device).half()):
 
 
 
91
  self.tts_speech_token_dict[uuid].append(i)
92
  self.llm_end_dict[uuid] = True
93
 
cosyvoice/llm/llm.py CHANGED
@@ -197,7 +197,7 @@ class TransformerLM(torch.nn.Module):
197
  offset = 0
198
  att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
199
  for i in range(max_len):
200
- y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=0, required_cache_size=-1,
201
  att_cache=att_cache, cnn_cache=cnn_cache,
202
  att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
203
  device=lm_input.device)).to(torch.bool))
 
197
  offset = 0
198
  att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
199
  for i in range(max_len):
200
+ y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1,
201
  att_cache=att_cache, cnn_cache=cnn_cache,
202
  att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
203
  device=lm_input.device)).to(torch.bool))