Spaces:
Runtime error
Runtime error
Update tacotron.py
Browse files- tacotron.py +1 -0
tacotron.py
CHANGED
@@ -373,6 +373,7 @@ class Tacotron(pax.Module):
|
|
373 |
mel = x[..., :-1]
|
374 |
eos_logit = x[..., -1]
|
375 |
eos_pr = jax.nn.sigmoid(eos_logit[0, -1])
|
|
|
376 |
rng_key, eos_rng_key = jax.random.split(rng_key)
|
377 |
eos = jax.random.bernoulli(eos_rng_key, p=eos_pr)
|
378 |
return attn_state, decoder_rnn_states, rng_key, (mel, eos)
|
|
|
373 |
mel = x[..., :-1]
|
374 |
eos_logit = x[..., -1]
|
375 |
eos_pr = jax.nn.sigmoid(eos_logit[0, -1])
|
376 |
+
eos_pr = jnp.where(eos_pr < 0.1, 0.0, eos_pr)
|
377 |
rng_key, eos_rng_key = jax.random.split(rng_key)
|
378 |
eos = jax.random.bernoulli(eos_rng_key, p=eos_pr)
|
379 |
return attn_state, decoder_rnn_states, rng_key, (mel, eos)
|