Update app.py
Browse files
app.py
CHANGED
@@ -5,70 +5,81 @@ from parler_tts import ParlerTTSForConditionalGeneration
|
|
5 |
from transformers import AutoTokenizer
|
6 |
import gradio as gr
|
7 |
import re
|
|
|
8 |
|
9 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
class EnglishNumberNormalizer:
|
11 |
def __call__(self, text):
|
12 |
-
#
|
|
|
|
|
|
|
|
|
13 |
return text
|
14 |
|
15 |
number_normalizer = EnglishNumberNormalizer()
|
16 |
|
|
|
17 |
def preprocess(text):
|
|
|
18 |
text = number_normalizer(text).strip()
|
19 |
-
text = text.replace("-", " ")
|
20 |
-
if text[-1] not in ".!?":
|
21 |
-
text = f"{text}."
|
22 |
|
23 |
-
|
|
|
24 |
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
|
|
|
|
|
29 |
abbreviations = re.findall(abbreviations_pattern, text)
|
30 |
for abv in abbreviations:
|
31 |
-
|
32 |
-
|
|
|
|
|
33 |
return text
|
34 |
|
35 |
-
# Vérification de la disponibilité de CUDA
|
36 |
-
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
37 |
-
|
38 |
-
# Chargement du modèle et du tokenizer
|
39 |
-
try:
|
40 |
-
model = ParlerTTSForConditionalGeneration.from_pretrained("CONCREE/Adia_TTS", torch_dtype=torch.float16).to(device)
|
41 |
-
tokenizer = AutoTokenizer.from_pretrained("CONCREE/Adia_TTS")
|
42 |
-
except Exception as e:
|
43 |
-
raise RuntimeError(f"Erreur lors du chargement du modèle : {e}")
|
44 |
-
|
45 |
# Texte et description par défaut
|
46 |
-
default_prompt = "
|
47 |
default_description = """A crystal clear and distinct voice, with a moderate reading rate that facilitates understanding. The tone is monotonous, without variations or inflections, which provides a uniform listening experience. The voice is free of background noise and allows for continuous reading, without inappropriate pauses, thus ensuring a constant and pleasant flow."""
|
48 |
|
|
|
|
|
49 |
# Fonction pour générer l'audio sans segmentation
|
50 |
def generate_audio(prompt, description):
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
|
58 |
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
|
63 |
-
|
64 |
-
|
65 |
|
66 |
-
|
67 |
-
|
68 |
|
69 |
-
|
70 |
-
except Exception as e:
|
71 |
-
raise RuntimeError(f"Erreur lors de la génération de l'audio : {e}")
|
72 |
|
73 |
# Fonction pour mettre à jour le compteur de caractères
|
74 |
def update_char_counter(text):
|
@@ -79,7 +90,7 @@ def update_char_counter(text):
|
|
79 |
def create_interface():
|
80 |
with gr.Blocks() as demo:
|
81 |
# Ajouter une image ou un logo
|
82 |
-
gr.Markdown("")
|
83 |
|
84 |
# Titre et description
|
85 |
gr.Markdown("# 🌟 Bienvenue sur Adia TTS 🌟")
|
@@ -104,7 +115,7 @@ def create_interface():
|
|
104 |
default_description,
|
105 |
],
|
106 |
[
|
107 |
-
"""Entreprenariat ci Senegal dafa am solo lool ci yokkuteg koom-koom, di gëna yokk liggéey ak indi gis-gis yu bees ci dëkk bi. Ndaw yi am këru liggéey dañuy am xéewal yu amul fenn ndax ecosystem bi dafay màgg.""",
|
108 |
default_description,
|
109 |
],
|
110 |
],
|
|
|
5 |
from transformers import AutoTokenizer
|
6 |
import gradio as gr
|
7 |
import re
|
8 |
+
from num2words import num2words
|
9 |
|
10 |
+
# Vérification de la disponibilité de CUDA
|
11 |
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
12 |
+
|
13 |
+
# Chargement du modèle et du tokenizer
|
14 |
+
try:
|
15 |
+
model = ParlerTTSForConditionalGeneration.from_pretrained("CONCREE/Adia_TTS", torch_dtype=torch.float16).to(device)
|
16 |
+
tokenizer = AutoTokenizer.from_pretrained("CONCREE/Adia_TTS")
|
17 |
+
except Exception as e:
|
18 |
+
raise RuntimeError(f"Erreur lors du chargement du modèle : {e}")
|
19 |
+
|
20 |
+
|
21 |
+
|
22 |
+
|
23 |
+
# Normalisation des nombres
|
24 |
class EnglishNumberNormalizer:
|
25 |
def __call__(self, text):
|
26 |
+
# Trouver tous les nombres dans le texte
|
27 |
+
numbers = re.findall(r'\d+', text)
|
28 |
+
for number in numbers:
|
29 |
+
# Convertir le nombre en mots
|
30 |
+
text = text.replace(number, num2words(int(number), lang='fr'))
|
31 |
return text
|
32 |
|
33 |
number_normalizer = EnglishNumberNormalizer()
|
34 |
|
35 |
+
# Fonction de prétraitement
|
36 |
def preprocess(text):
|
37 |
+
# Normaliser les nombres
|
38 |
text = number_normalizer(text).strip()
|
|
|
|
|
|
|
39 |
|
40 |
+
# Remplacer les tirets par des espaces
|
41 |
+
text = text.replace("-", " ")
|
42 |
|
43 |
+
# Ajouter un point à la fin si le texte ne se termine pas par une ponctuation
|
44 |
+
if not text.endswith(('.', '!', '?')):
|
45 |
+
text += "."
|
46 |
|
47 |
+
# Traiter les abréviations
|
48 |
+
abbreviations_pattern = r'\b[A-Z][A-Z\.]+\b'
|
49 |
abbreviations = re.findall(abbreviations_pattern, text)
|
50 |
for abv in abbreviations:
|
51 |
+
# Séparer les lettres des abréviations (par exemple, "U.S.A." -> "U S A")
|
52 |
+
separated_abv = " ".join(abv.replace(".", ""))
|
53 |
+
text = text.replace(abv, separated_abv)
|
54 |
+
|
55 |
return text
|
56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
# Texte et description par défaut
|
58 |
+
default_prompt = "Abdoul nena souba dinagnou am reunion pour waxtaan li des"
|
59 |
default_description = """A crystal clear and distinct voice, with a moderate reading rate that facilitates understanding. The tone is monotonous, without variations or inflections, which provides a uniform listening experience. The voice is free of background noise and allows for continuous reading, without inappropriate pauses, thus ensuring a constant and pleasant flow."""
|
60 |
|
61 |
+
|
62 |
+
|
63 |
# Fonction pour générer l'audio sans segmentation
|
64 |
def generate_audio(prompt, description):
|
65 |
+
# Prétraiter le texte
|
66 |
+
prompt = preprocess(prompt)
|
67 |
+
|
68 |
+
# Génération des IDs d'entrée
|
69 |
+
input_ids = tokenizer(description.strip(), return_tensors="pt").input_ids.to(device)
|
70 |
+
prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
|
|
|
71 |
|
72 |
+
# Générer l'audio
|
73 |
+
generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
|
74 |
+
audio_arr = generation.cpu().numpy().squeeze() # Transformer en tableau numpy
|
75 |
|
76 |
+
# Taux d'échantillonnage
|
77 |
+
sampling_rate = model.config.sampling_rate
|
78 |
|
79 |
+
# Normaliser l'audio
|
80 |
+
audio_arr = audio_arr / np.max(np.abs(audio_arr))
|
81 |
|
82 |
+
return sampling_rate, audio_arr
|
|
|
|
|
83 |
|
84 |
# Fonction pour mettre à jour le compteur de caractères
|
85 |
def update_char_counter(text):
|
|
|
90 |
def create_interface():
|
91 |
with gr.Blocks() as demo:
|
92 |
# Ajouter une image ou un logo
|
93 |
+
gr.Markdown("") # Remplacez l'URL par le chemin de votre image
|
94 |
|
95 |
# Titre et description
|
96 |
gr.Markdown("# 🌟 Bienvenue sur Adia TTS 🌟")
|
|
|
115 |
default_description,
|
116 |
],
|
117 |
[
|
118 |
+
"""Entreprenariat ci Senegal dafa am solo lool ci yokkuteg koom-koom, di gëna yokk liggéey ak indi gis-gis yu bees ci dëkk bi. Ndaw yi am këru liggéey dañuy am xéewal yu amul fenn ndax ecosystem bi dafay màgg, te inisiatiif yu réew mi ak yu prive yi ñoo leen di jàppale.""",
|
119 |
default_description,
|
120 |
],
|
121 |
],
|