jason-on-salt-a40 commited on
Commit
78774ba
·
1 Parent(s): c1908d8

hf model download

Browse files
Files changed (3) hide show
  1. app.py +11 -16
  2. models/voicecraft.py +8 -2
  3. requirements.txt +2 -1
app.py CHANGED
@@ -93,27 +93,22 @@ def load_models(whisper_backend_name, whisper_model_name, alignment_model_name,
93
  transcribe_model = WhisperxModel(whisper_model_name, align_model)
94
 
95
  voicecraft_name = f"{voicecraft_model_name}.pth"
96
- ckpt_fn = f"{MODELS_PATH}/{voicecraft_name}"
 
 
 
 
97
  encodec_fn = f"{MODELS_PATH}/encodec_4cb2048_giga.th"
98
- if not os.path.exists(ckpt_fn):
99
- os.system(f"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/{voicecraft_name}\?download\=true")
100
- os.system(f"mv {voicecraft_name}\?download\=true {MODELS_PATH}/{voicecraft_name}")
101
  if not os.path.exists(encodec_fn):
102
  os.system(f"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/encodec_4cb2048_giga.th")
103
- os.system(f"mv encodec_4cb2048_giga.th {MODELS_PATH}/encodec_4cb2048_giga.th")
104
 
105
- ckpt = torch.load(ckpt_fn, map_location="cpu")
106
- model = voicecraft.VoiceCraft(ckpt["config"])
107
- model.load_state_dict(ckpt["model"])
108
- model.to(device)
109
- model.eval()
110
  voicecraft_model = {
111
- "ckpt": ckpt,
 
112
  "model": model,
113
  "text_tokenizer": TextTokenizer(backend="espeak"),
114
  "audio_tokenizer": AudioTokenizer(signature=encodec_fn)
115
  }
116
-
117
  return gr.Accordion()
118
 
119
 
@@ -255,8 +250,8 @@ def run(seed, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p,
255
 
256
  prompt_end_frame = int(min(audio_dur, prompt_end_time) * info.sample_rate)
257
  _, gen_audio = inference_one_sample(voicecraft_model["model"],
258
- voicecraft_model["ckpt"]["config"],
259
- voicecraft_model["ckpt"]["phn2num"],
260
  voicecraft_model["text_tokenizer"], voicecraft_model["audio_tokenizer"],
261
  audio_path, target_transcript, device, decode_config,
262
  prompt_end_frame)
@@ -284,8 +279,8 @@ def run(seed, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p,
284
  mask_interval = torch.LongTensor(mask_interval)
285
 
286
  _, gen_audio = inference_one_sample(voicecraft_model["model"],
287
- voicecraft_model["ckpt"]["config"],
288
- voicecraft_model["ckpt"]["phn2num"],
289
  voicecraft_model["text_tokenizer"], voicecraft_model["audio_tokenizer"],
290
  audio_path, target_transcript, mask_interval, device, decode_config)
291
  gen_audio = gen_audio[0].cpu()
 
93
  transcribe_model = WhisperxModel(whisper_model_name, align_model)
94
 
95
  voicecraft_name = f"{voicecraft_model_name}.pth"
96
+ model = voicecraft.VoiceCraftHF.from_pretrained(f"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}")
97
+ phn2num = model.args.phn2num
98
+ config = model.args
99
+ model.to(device)
100
+
101
  encodec_fn = f"{MODELS_PATH}/encodec_4cb2048_giga.th"
 
 
 
102
  if not os.path.exists(encodec_fn):
103
  os.system(f"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/encodec_4cb2048_giga.th")
 
104
 
 
 
 
 
 
105
  voicecraft_model = {
106
+ "config": config,
107
+ "phn2num": phn2num,
108
  "model": model,
109
  "text_tokenizer": TextTokenizer(backend="espeak"),
110
  "audio_tokenizer": AudioTokenizer(signature=encodec_fn)
111
  }
 
112
  return gr.Accordion()
113
 
114
 
 
250
 
251
  prompt_end_frame = int(min(audio_dur, prompt_end_time) * info.sample_rate)
252
  _, gen_audio = inference_one_sample(voicecraft_model["model"],
253
+ voicecraft_model["config"],
254
+ voicecraft_model["phn2num"],
255
  voicecraft_model["text_tokenizer"], voicecraft_model["audio_tokenizer"],
256
  audio_path, target_transcript, device, decode_config,
257
  prompt_end_frame)
 
279
  mask_interval = torch.LongTensor(mask_interval)
280
 
281
  _, gen_audio = inference_one_sample(voicecraft_model["model"],
282
+ voicecraft_model["config"],
283
+ voicecraft_model["phn2num"],
284
  voicecraft_model["text_tokenizer"], voicecraft_model["audio_tokenizer"],
285
  audio_path, target_transcript, mask_interval, device, decode_config)
286
  gen_audio = gen_audio[0].cpu()
models/voicecraft.py CHANGED
@@ -17,7 +17,8 @@ from .modules.transformer import (
17
  TransformerEncoderLayer,
18
  )
19
  from .codebooks_patterns import DelayedPatternProvider
20
-
 
21
  def top_k_top_p_filtering(
22
  logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
23
  ):
@@ -1403,4 +1404,9 @@ class VoiceCraft(nn.Module):
1403
  res = res - int(self.args.n_special)
1404
  flatten_gen = flatten_gen - int(self.args.n_special)
1405
 
1406
- return res, flatten_gen[0].unsqueeze(0)
 
 
 
 
 
 
17
  TransformerEncoderLayer,
18
  )
19
  from .codebooks_patterns import DelayedPatternProvider
20
+ from huggingface_hub import PyTorchModelHubMixin
21
+ from argparse import Namespace
22
  def top_k_top_p_filtering(
23
  logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
24
  ):
 
1404
  res = res - int(self.args.n_special)
1405
  flatten_gen = flatten_gen - int(self.args.n_special)
1406
 
1407
+ return res, flatten_gen[0].unsqueeze(0)
1408
+
1409
+ class VoiceCraftHF(VoiceCraft, PyTorchModelHubMixin, repo_url="https://github.com/jasonppy/VoiceCraft", tags=["Text-to-Speech", "VoiceCraft"]):
1410
+ def __init__(self, config: dict):
1411
+ args = Namespace(**config)
1412
+ super().__init__(args)
requirements.txt CHANGED
@@ -5,4 +5,5 @@ nltk>=3.8.1
5
  openai-whisper>=20231117
6
  spaces
7
  aeneas==1.7.3.0
8
- whisperx==3.1.1
 
 
5
  openai-whisper>=20231117
6
  spaces
7
  aeneas==1.7.3.0
8
+ whisperx==3.1.1
9
+ huggingface-hub==0.22.2