taka-yamakoshi commited on
Commit
9149baa
·
1 Parent(s): 601d754
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -56,5 +56,5 @@ if __name__=='__main__':
56
  outputs = model(input_ids)
57
  logprobs = jax.nn.log_softmax(outputs.logits, axis = -1)
58
  st.write(logprobs.shape)
59
- preds = [np.random.choice(jnp.arange(len(probs)),p=jnp.exp(probs)/jnp.sum(jnp.exp(probs))) for probs in logprobs[0]]
60
  st.write([tokenizer.decode([token]) for token in preds])
 
56
  outputs = model(input_ids)
57
  logprobs = jax.nn.log_softmax(outputs.logits, axis = -1)
58
  st.write(logprobs.shape)
59
+ preds = [np.random.choice(np.arange(len(probs)),p=np.exp(probs)/np.sum(np.exp(probs))) for probs in logprobs[0]]
60
  st.write([tokenizer.decode([token]) for token in preds])