[TODO] FP16 Inference
#4, pinned, opened by hexgrad
Help wanted modifying the inference code to enable FP16 inference. Here are the steps taken so far:
- A very simple script, halve.py, cuts the model precision in half, from FP32 down to FP16. The new model is saved as kokoro-v0_19-half.pth, and we know it was cut in half because the file size drops from 320 MB to 160 MB. Quick maffs: 80M params, 4 => 2 bytes per param, so yes, it's supposed to be 320 => 160 MB.
- Run the cell below in Colab to confirm that the halved model at fp16/kokoro-v0_19-half.pth still works:
# 1️⃣ Install dependencies silently
!git clone https://huggingface.co/hexgrad/Kokoro-82M
%cd Kokoro-82M
!apt-get -qq -y install espeak-ng > /dev/null 2>&1
!pip install -q phonemizer torch transformers scipy munch
# 2️⃣ Build the model and load the default voicepack
from models import build_model
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'
MODEL = build_model('fp16/kokoro-v0_19-half.pth', device) # Half precision model upcast to fp32
VOICEPACK = torch.load('voices/af.pt', weights_only=True).to(device)
# 3️⃣ Call generate, which returns a 24 kHz audio waveform and a string of output phonemes
from kokoro import generate
text = "How could I know? It's an unanswerable question. Like asking an unborn child if they'll lead a good life. They haven't even been born."
audio, out_ps = generate(MODEL, text, VOICEPACK)
# 4️⃣ Display the 24 kHz audio and print the output phonemes
from IPython.display import display, Audio
display(Audio(data=audio, rate=24000, autoplay=True))
print(out_ps)
- Listen to the outputs: fp32.wav and fp16.wav
- The current inference code implicitly upcasts the half-precision model to FP32 before doing inference, so we're not actually gaining any inference speed (or memory footprint reduction, I think) from the FP16 model. You can verify this yourself using timing functions. This is where you come in (maybe).
Your mission, should you choose to accept it, is to modify the inference code to enable FP16 inference. Get the speedup, while keeping the audio output identical/similar.
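To make the halving step and the implicit upcast concrete, here is a minimal sketch on a toy model. It is not the actual halve.py or the Kokoro loading code, neither of which is shown in this post. The helper names (halve_state_dict, param_bytes, mean_forward_seconds) and the toy nn.Sequential are made up for illustration. It shows one plausible mechanism for the upcast: loading an FP16 state dict into a module that was constructed in FP32 copies the values into the existing FP32 parameters, so the parameters stay FP32. Whether this is what happens inside build_model is an assumption.

```python
import time

import torch
import torch.nn as nn


def halve_state_dict(state_dict):
    """Cast floating-point tensors to FP16; leave integer tensors untouched."""
    return {k: v.half() if v.is_floating_point() else v for k, v in state_dict.items()}


def param_bytes(tensors):
    """Total storage in bytes for an iterable of tensors."""
    return sum(t.numel() * t.element_size() for t in tensors)


def mean_forward_seconds(module, x, repeats=20):
    """Average wall-clock time of one forward pass (hypothetical timing helper)."""
    with torch.no_grad():
        module(x)  # warm-up pass
        if x.is_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(repeats):
            module(x)
        if x.is_cuda:
            torch.cuda.synchronize()
    return (time.perf_counter() - start) / repeats


# Back-of-the-envelope check of the file sizes quoted in the post.
approx_params = 80_000_000
print(approx_params * 4 / 1e6, "MB ->", approx_params * 2 / 1e6, "MB")

# Toy stand-in for the real model (assumption: any nn.Module behaves the same way here).
toy = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 256))
fp32_sd = toy.state_dict()
fp16_sd = halve_state_dict(fp32_sd)
size_ratio = param_bytes(fp32_sd.values()) / param_bytes(fp16_sd.values())

# A freshly built module is FP32; load_state_dict copies FP16 values into FP32 parameters.
reloaded = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 256))
reloaded.load_state_dict(fp16_sd)
reloaded_dtypes = {p.dtype for p in reloaded.parameters()}

print("size ratio:", size_ratio)
print("dtypes after loading the FP16 checkpoint:", reloaded_dtypes)
print("mean forward time (s):", mean_forward_seconds(reloaded, torch.randn(8, 256)))
```

If the same dtype check on the real model reports only torch.float32, that would be consistent with the upcast described above. Real FP16 inference would then need both the parameters and the inputs in half precision, for example via .half() on the modules, which is most useful on a GPU.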
- The inference code is deliberately slimmed down to make it easier to read the relevant pieces, relative to the entire StyleTTS2 repo.
- My previous attempts have failed: the outputs bricked into noise or silence, and I haven't done much debugging. If/when I get a chance to clean and upload the failed code, I may put it under the fp16 folder.
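For anyone picking up the debugging, one common cause of noise or silence in FP16 is numeric overflow: the largest finite FP16 value is 65504, and anything larger becomes inf, which then tends to spread as inf or NaN through later layers. Whether that is what broke the earlier attempts is not confirmed here. The sketch below uses forward hooks to list the submodules whose output contains non-finite values. The function find_nonfinite_modules and the toy Scale module are hypothetical, written only for this illustration.

```python
import torch
import torch.nn as nn


def find_nonfinite_modules(model, *inputs):
    """Run one forward pass and return names of submodules whose output has inf/NaN."""
    offenders, handles = [], []

    def make_hook(name):
        def hook(module, args, output):
            # Only plain tensor outputs are checked in this sketch.
            if isinstance(output, torch.Tensor) and not torch.isfinite(output.float()).all():
                offenders.append(name)
        return hook

    for name, module in model.named_modules():
        if name:  # skip the root module, whose name is the empty string
            handles.append(module.register_forward_hook(make_hook(name)))
    try:
        with torch.no_grad():
            model(*inputs)
    finally:
        for handle in handles:
            handle.remove()
    return offenders


class Scale(nn.Module):
    """Toy layer that multiplies its input by a constant (illustration only)."""

    def __init__(self, factor):
        super().__init__()
        self.factor = factor

    def forward(self, x):
        return x * self.factor


# 100 * 100 = 10,000 still fits in FP16; 10,000 * 100 = 1,000,000 does not.
toy = nn.Sequential(Scale(100.0), Scale(100.0))
x = torch.full((4,), 100.0, dtype=torch.float16)
offenders = find_nonfinite_modules(toy, x)

print("FP16 max:", torch.finfo(torch.float16).max)
print("first non-finite output at submodule:", offenders[0])
```

Running the same function over the real modules in FP16 would point at the first layer that overflows. That layer (or the ops around it) could then be kept in FP32 while the rest runs in half precision, which is roughly what mixed-precision tools such as torch.autocast aim to do.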