Spaces:
Running
on
A10G
Running
on
A10G
Jonathan Fly
commited on
Commit
•
8764625
1
Parent(s):
d56cc00
Use greedy sampling path when temp is 0.0 to avoid division by zero (#53)
Browse files- audiocraft/models/lm.py +2 -1
audiocraft/models/lm.py
CHANGED
@@ -363,7 +363,8 @@ class LMModel(StreamingModule):
|
|
363 |
logits = logits.permute(0, 1, 3, 2) # [B, K, card, T]
|
364 |
logits = logits[..., -1] # [B x K x card]
|
365 |
|
366 |
-
if
|
|
|
367 |
probs = torch.softmax(logits / temp, dim=-1)
|
368 |
if top_p > 0.0:
|
369 |
next_token = utils.sample_top_p(probs, p=top_p)
|
|
|
363 |
logits = logits.permute(0, 1, 3, 2) # [B, K, card, T]
|
364 |
logits = logits[..., -1] # [B x K x card]
|
365 |
|
366 |
+
# Apply softmax for sampling if temp > 0. Else, do greedy sampling to avoid zero division error.
|
367 |
+
if use_sampling and temp > 0.0:
|
368 |
probs = torch.softmax(logits / temp, dim=-1)
|
369 |
if top_p > 0.0:
|
370 |
next_token = utils.sample_top_p(probs, p=top_p)
|