Commit 516bd70 by ylacombe (HF staff)
1 parent: 2c30452

Add Number Normalization and other fix

For reference, here are the typical mistakes the model seems to make at the moment:
- It can't pronounce abbreviations, though it seems fine when the letters are separated
- It mispronounces the last word of a prompt that doesn't end with a punctuation mark
- It can't pronounce numbers
- It mistakes hyphens for long pauses (e.g. text-to-speech)

This PR shows how to correct most of these issues with proper prompt normalization.
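
The two simplest fixes in the list, hyphens and the missing final punctuation mark, can be illustrated in isolation. This is only a rough sketch using the standard library: the helper names `replace_hyphens` and `ensure_terminal_punctuation` are hypothetical and do not appear in the commit, and the guard for empty input is an addition that the committed code does not have.

```python
from string import punctuation


def replace_hyphens(text):
    """Swap hyphens for spaces so 'text-to-speech' is not read with pauses."""
    return text.replace("-", " ")


def ensure_terminal_punctuation(text):
    """Append a period when the prompt does not already end with punctuation."""
    text = text.strip()
    # Assumption: an empty prompt is returned unchanged instead of being indexed.
    if text and text[-1] not in punctuation:
        return f"{text}."
    return text


prompt = ensure_terminal_punctuation(replace_hyphens("A text-to-speech demo"))
print(prompt)  # A text to speech demo.
```

Number handling is left out here because the commit delegates it to `EnglishNumberNormalizer` from `transformers`.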

Files changed (1)
  1. app.py +28 -4
app.py CHANGED
@@ -1,6 +1,10 @@
 import spaces
 import gradio as gr
 import torch
+from transformers.models.speecht5.number_normalizer import EnglishNumberNormalizer
+from string import punctuation
+import re
+
 
 from parler_tts import ParlerTTSForConditionalGeneration
 from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
@@ -38,11 +42,31 @@ examples = [
     ],
 ]
 
+number_normalizer = EnglishNumberNormalizer()
+
+def preprocess(text):
+    text = number_normalizer(text).strip()
+    text = text.replace("-", " ")
+    if text[-1] not in punctuation:
+        text = f"{text}."
+
+    abbreviations_pattern = r'\b[A-Z][A-Z\.]+\b'
+
+    def separate_abb(chunk):
+        chunk = chunk.replace(".","")
+        print(chunk)
+        return " ".join(chunk)
+
+    abbreviations = re.findall(abbreviations_pattern, text)
+    for abv in abbreviations:
+        if abv in text:
+            text = text.replace(abv, separate_abb(abv))
+    return text
+
 
-@spaces.GPU
 def gen_tts(text, description):
     inputs = tokenizer(description, return_tensors="pt").to(device)
-    prompt = tokenizer(text, return_tensors="pt").to(device)
+    prompt = tokenizer(preprocess(text), return_tensors="pt").to(device)
 
     set_seed(SEED)
     generation = model.generate(
@@ -140,9 +164,9 @@ with gr.Blocks(css=css) as block:
     and torch compile, that will improve the latency by 2-4x. If you want to find out more about how this model was trained and even fine-tune it yourself, check-out the
     <a href="https://github.com/huggingface/parler-tts"> Parler-TTS</a> repository on GitHub.</p>
 
-    <p>The Parler-TTS codebase and its associated checkpoints are licensed under <a href='https://github.com/huggingface/parler-tts?tab=Apache-2.0-1-ov-file#readme'> Apache 2.0</a>.</p>
+    <p>The Parler-TTS codebase and its associated checkpoints are licensed under <a href='https://github.com/huggingface/parler-tts?tab=Apache-2.0-1-ov-file#readme'> Apache 2.0</a></p>.
     """
 )
 
 block.queue()
-block.launch(share=True)
+block.launch(share=True)
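
One thing to watch in the abbreviation step: because it combines `re.findall` with `str.replace`, a short abbreviation that is a prefix of a longer one can split the longer one before its own turn comes. The sketch below compares that logic (reproduced without the debug `print`) with a possible single-pass alternative based on `re.sub` and a replacement callback. The names `spell_out`, `separate_with_loop`, and `separate_with_sub` are hypothetical and not part of the commit; only the regular expression is taken from it.

```python
import re

# Same pattern as in the commit: one capital letter followed by one or
# more capitals or periods, delimited by word boundaries.
ABBREVIATION_PATTERN = re.compile(r"\b[A-Z][A-Z\.]+\b")


def spell_out(chunk):
    """Drop periods and space out the letters: 'U.S' -> 'U S'."""
    return " ".join(chunk.replace(".", ""))


def separate_with_loop(text):
    """The commit's findall + str.replace logic, minus the debug print."""
    for abv in ABBREVIATION_PATTERN.findall(text):
        if abv in text:
            text = text.replace(abv, spell_out(abv))
    return text


def separate_with_sub(text):
    """Hypothetical single-pass alternative using a replacement callback."""
    return ABBREVIATION_PATTERN.sub(lambda m: spell_out(m.group(0)), text)


sample = "The US and the USA"
print(separate_with_loop(sample))  # The U S and the U SA
print(separate_with_sub(sample))   # The U S and the U S A
```

With `re.sub`, each match is rewritten exactly once at its own position, so earlier replacements cannot alter later matches. Whether this matters in practice depends on the prompts the demo receives.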