Spaces:
Runtime error
Runtime error
taka-yamakoshi
commited on
Commit
·
9149baa
1
Parent(s):
601d754
fix
Browse files
app.py
CHANGED
@@ -56,5 +56,5 @@ if __name__=='__main__':
|
|
56 |
outputs = model(input_ids)
|
57 |
logprobs = jax.nn.log_softmax(outputs.logits, axis = -1)
|
58 |
st.write(logprobs.shape)
|
59 |
-
preds = [np.random.choice(
|
60 |
st.write([tokenizer.decode([token]) for token in preds])
|
|
|
56 |
outputs = model(input_ids)
|
57 |
logprobs = jax.nn.log_softmax(outputs.logits, axis = -1)
|
58 |
st.write(logprobs.shape)
|
59 |
+
preds = [np.random.choice(np.arange(len(probs)),p=np.exp(probs)/np.sum(np.exp(probs))) for probs in logprobs[0]]
|
60 |
st.write([tokenizer.decode([token]) for token in preds])
|