Spaces:
Build error
Build error
File size: 1,999 Bytes
74f2c64 50d21df 74f2c64 3cc2a8f 74f2c64 50d21df 74f2c64 3cc2a8f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
import pytest
import torch
import numpy as np
from src import predict, file_readers, config
import test_config
def test_load_model():
    """
    Check predict.load_model, which loads the silero TTS model.

    The loaded model must expose a speaker list whose first entry is
    'en_0' and whose shape is (119,).
    """
    tts_model = predict.load_model()
    speakers = tts_model.speakers
    assert speakers[0] == 'en_0'
    assert np.shape(speakers) == (119,)
def test_generate_audio():
    """
    Tests generate_audio, which takes the TTS model and file input,
    and uses the predict & write_audio functions to output the audio file.

    Two corpus sections are passed in, so the test expects the part000 and
    part001 wav outputs to exist and no part002 output. All three paths are
    removed before and after the run so results from a previous (possibly
    failed) run cannot leak into this one.
    """
    ebook_path = test_config.data_path / "test.epub"
    wav1_path = config.output_path / 'the_picture_of_dorian_gray_part000.wav'
    wav2_path = config.output_path / 'the_picture_of_dorian_gray_part001.wav'
    wav3_path = config.output_path / 'the_picture_of_dorian_gray_part002.wav'
    wav_paths = (wav1_path, wav2_path, wav3_path)

    # Remove stale outputs first: a leftover file would make the is_file()
    # checks below pass (or the "not part002" check fail) for the wrong reason.
    for wav_path in wav_paths:
        wav_path.unlink(missing_ok=True)

    corpus, title = file_readers.read_epub(ebook_path)
    model = predict.load_model()
    speaker = 'en_110'
    try:
        predict.generate_audio(corpus[0:2], title, model, speaker)
        assert wav1_path.is_file()
        assert wav2_path.is_file()
        assert not wav3_path.is_file()
    finally:
        # Always clean up, even when generation or an assertion fails;
        # missing_ok avoids a secondary FileNotFoundError masking the failure.
        for wav_path in wav_paths:
            wav_path.unlink(missing_ok=True)
def test_predict():
    """
    Tests predict function, generates audio tensors for each token in the text section,
    and appends them together along with a generated file path for output.

    The seeded audio output for test_predict.txt (speaker 'en_0') is
    concatenated and compared against the stored reference tensor in
    test_predict.pt. The returned file path is not checked.
    """
    seed = 1337
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    model = predict.load_model()

    tensor_path = test_config.data_path / "test_predict.pt"
    # Load onto CPU so the reference is readable even if it was saved from a
    # GPU tensor and the test machine has no CUDA device.
    test_tensor = torch.load(tensor_path, map_location='cpu')

    text_path = test_config.data_path / "test_predict.txt"
    # Explicit encoding: the default is locale-dependent (e.g. cp1252 on Windows).
    with open(text_path, 'r', encoding='utf-8') as file:
        text = file_readers.preprocess_text(file)

    title = 'test_predict'
    section_index = 'part001'
    speaker = 'en_0'
    audio_list, _ = predict.predict(text, section_index, title, model, speaker)

    # Join the per-chunk tensors into one (1, N) row; presumably the shape the
    # reference was saved with (TODO confirm). Move to CPU to match the
    # reference's device, since assert_close also compares devices.
    audio_tensor = torch.cat(audio_list).reshape(1, -1).cpu()
    # NOTE(review): rtol=0.9 accepts deviations of up to ~90% of each
    # reference value, so this is only a coarse regression check.
    torch.testing.assert_close(audio_tensor, test_tensor, atol=1e-3, rtol=0.9)
|