File size: 5,654 Bytes
496bf8a
efcdb1c
ab3a30c
 
 
 
 
 
 
 
496bf8a
ab3a30c
e85fa31
72c65b6
d536921
72c65b6
 
ab3a30c
efcdb1c
ab3a30c
 
 
 
72c65b6
 
 
 
 
 
d86bc7f
290deb7
bddd843
553f0a4
ab3a30c
 
 
 
 
290deb7
496bf8a
 
 
b4e6550
82022e9
496bf8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ab3a30c
 
 
 
6a42799
b4e6550
d86bc7f
 
 
 
 
0ceab50
d86bc7f
8a4d97a
b4e6550
 
 
914c603
b4e6550
 
d86bc7f
72c65b6
82022e9
72c65b6
93acf16
c8a6713
 
b4e6550
ee4aecd
7caca05
 
24baff1
 
 
36654bb
 
 
 
b4e6550
 
5086b00
b4e6550
82022e9
 
 
b4e6550
82022e9
978cac7
c8a6713
ab3a30c
a7c5b39
ab3a30c
 
d86bc7f
 
ab3a30c
 
73ce57b
c3db1ad
b4e6550
e055ee5
c3db1ad
73ce57b
d86bc7f
 
978cac7
d86bc7f
 
 
d6fc925
82022e9
b4e6550
 
ab3a30c
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import io
import math
from typing import Optional

import numpy as np
import spaces
import gradio as gr
import torch

from parler_tts import ParlerTTSForConditionalGeneration
from pydub import AudioSegment
from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
from huggingface_hub import InferenceClient
import nltk
import random
# Sentence-splitter data for nltk.sent_tokenize. NLTK >= 3.9 loads the
# "punkt_tab" resource instead of the pickled "punkt" one, so fetch both to
# keep sentence splitting working across NLTK versions.
nltk.download('punkt')
nltk.download('punkt_tab')


# Prefer CUDA, then Apple MPS, then CPU; half precision only off the CPU.
device = "cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
torch_dtype = torch.float16 if device != "cpu" else torch.float32

# Base Parler-TTS checkpoint: supplies the tokenizers and feature extractor below.
repo_id = "parler-tts/parler_tts_mini_v0.1"

# Checkpoint whose weights are loaded into `model` for speech generation.
jenny_repo_id = "ylacombe/parler-tts-mini-jenny-30H"

model = ParlerTTSForConditionalGeneration.from_pretrained(
    jenny_repo_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True
).to(device)

# Hugging Face inference client; used to generate the story text.
client = InferenceClient()

description_tokenizer = AutoTokenizer.from_pretrained(repo_id)
# Left padding for the batched, variable-length sentence prompts.
prompt_tokenizer = AutoTokenizer.from_pretrained(repo_id, padding_side="left")
feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)

SAMPLE_RATE = feature_extractor.sampling_rate
SEED = 42


def numpy_to_mp3(audio_array, sampling_rate):
    """Encode a mono audio waveform as MP3 bytes.

    Floating-point input is peak-normalised so its loudest sample maps to the
    full signed 16-bit range, then converted to int16. Integer input is handed
    to pydub unchanged, with the array's item size used as the sample width.

    Args:
        audio_array: NumPy array of mono samples (written out via ``tobytes``,
            so it is presumably 1-D).
        sampling_rate: Sample rate of ``audio_array`` in Hz.

    Returns:
        The MP3-encoded audio as ``bytes`` (320 kbps).
    """
    if np.issubdtype(audio_array.dtype, np.floating):
        # Scale in float64: the model may emit float16, which cannot represent
        # 32767 exactly (it rounds to 32768 and would wrap when cast to int16).
        samples = audio_array.astype(np.float64)
        peak = np.max(np.abs(samples)) if samples.size else 0.0
        if peak > 0:
            # Map the loudest sample onto full scale; silence stays silent.
            samples = samples / peak * 32767
        audio_array = np.clip(samples, -32768, 32767).astype(np.int16)

    audio_segment = AudioSegment(
        audio_array.tobytes(),
        frame_rate=sampling_rate,
        sample_width=audio_array.dtype.itemsize,
        channels=1
    )

    # Export at a high bitrate to maximise quality.
    with io.BytesIO() as mp3_io:
        audio_segment.export(mp3_io, format="mp3", bitrate="320k")
        return mp3_io.getvalue()

# Properties of the loaded model's audio encoder: the sample rate of generated
# waveforms (used when logging and MP3-encoding them) and its frame rate.
sampling_rate, frame_rate = (
    model.audio_encoder.config.sampling_rate,
    model.audio_encoder.config.frame_rate,
)


def generate_story(subject: str, setting: str) -> "tuple[None, None, str]":
    """Ask the hosted chat model to write a short bedtime story.

    Args:
        subject: Main character of the story (e.g. "Princess").
        setting: Where the story takes place (e.g. "Forest").

    Returns:
        ``(None, None, story)``: the two ``None`` values clear the visible
        story textbox and audio player, and ``story`` is the generated text.
    """
    system_prompt = (
        "You are an award-winning children's bedtime story author lauded for your inventive stories. "
        "You want to write a bed time story for your child. They will give you the subject and setting "
        "and you will write the entire story. It should be targetted at children 5 and younger and take about "
        "a minute to read"
    )
    messages = [
        # Must be exactly "system" for chat templates to treat it as the system prompt.
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"Please tell me a story about a {subject} in {setting}"},
    ]
    # A random seed per request so repeated clicks can yield different stories.
    response = client.chat_completion(messages, max_tokens=1024, seed=random.randint(1, 5000))
    gr.Info("Story Generated", duration=3)
    story = response.choices[0].message.content
    return None, None, story


@spaces.GPU(duration=120)
def generate_base(story):
    """Synthesise speech for ``story``, producing one waveform per sentence.

    The story is flattened onto one line, split into sentences with NLTK, and
    all sentences are generated as a single padded batch, each conditioned on
    the same voice description.

    Args:
        story: The story text to read aloud.

    Returns:
        ``(None, None, waveforms)`` where ``waveforms`` is a list of NumPy
        arrays, one per sentence, each trimmed to its ``audios_length`` entry.
        If the story contains no sentences, ``waveforms`` is empty.
    """
    sentences = nltk.sent_tokenize(story.replace("\n", " ").strip())
    if not sentences:
        # Nothing to read: don't call the tokenizers/model with an empty batch.
        return None, None, []

    gr.Info("Generating Audio", duration=3)
    description = "Jenny speaks at an average pace with a calm delivery in a very confined sounding environment with clear audio quality."
    story_tokens = prompt_tokenizer(sentences, return_tensors="pt", padding=True).to(device)
    # Every sentence shares the identical description, so this batch needs no padding.
    description_tokens = description_tokenizer([description] * len(sentences), return_tensors="pt").to(device)
    speech_output = model.generate(
        input_ids=description_tokens.input_ids,
        prompt_input_ids=story_tokens.input_ids,
        attention_mask=description_tokens.attention_mask,
        prompt_attention_mask=story_tokens.attention_mask,
        return_dict_in_generate=True,
    )
    # Trim each batched sequence to its own entry in `audios_length`.
    waveforms = [
        output.cpu().numpy()[:output_length]
        for output, output_length in zip(speech_output.sequences, speech_output.audios_length)
    ]
    return None, None, waveforms


def stream_audio(hidden_story, speech_output):
    """Yield the story text together with each waveform encoded as MP3.

    Args:
        hidden_story: Story text to display while the audio plays.
        speech_output: Iterable of waveforms (NumPy arrays) sampled at the
            module-level ``sampling_rate``.

    Yields:
        ``(hidden_story, mp3_bytes)`` tuples, one per waveform, in order.
    """
    gr.Info("Reading Story")

    for waveform in speech_output:
        duration_s = round(waveform.shape[0] / sampling_rate, 2)
        print(f"Sample of length: {duration_s} seconds")
        yield hidden_story, numpy_to_mp3(waveform, sampling_rate=sampling_rate)


with gr.Blocks() as block:
    gr.HTML(
        """
        <h1> Bedtime Story Reader 😴🔊 </h1>
        <p> Powered by <a href="https://github.com/huggingface/parler-tts"> Parler-TTS</a> </p>
        """
    )
    with gr.Group():
        with gr.Row():
            subject = gr.Dropdown(value="Princess", choices=["Prince", "Princess", "Dog", "Cat"], label="Subject")
            setting = gr.Dropdown(value="Forest", choices=["Forest", "Kingdom", "Jungle", "Underwater", "Pirate Ship"], label="Setting")
        with gr.Row():
            run_button = gr.Button("Generate Story", variant="primary")
    with gr.Row():
        with gr.Group():
            audio_out = gr.Audio(label="Bed time story",  streaming=True, autoplay=True)
            story = gr.Textbox(label="Story")

    inputs = [subject, setting]
    # `hidden_story` keeps the generated text so it can be shown again once the
    # audio streams; `state` carries the per-sentence waveforms from
    # generate_base to stream_audio.
    state = gr.State()
    hidden_story = gr.State()

    # Pipeline: write the story -> synthesise every sentence -> stream as MP3.
    (
        run_button.click(generate_story, inputs=inputs, outputs=[story, audio_out, hidden_story])
        .success(fn=generate_base, inputs=hidden_story, outputs=[story, audio_out, state])
        .success(stream_audio, inputs=[hidden_story, state], outputs=[story, audio_out])
    )

block.queue()
block.launch(share=True)