BigSalmon commited on
Commit
136c22a
1 Parent(s): 33b82a2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -8,7 +8,8 @@ from transformers.activations import get_activation
8
  from transformers import AutoTokenizer, AutoModelWithLMHead, AutoModelForCausalLM
9
 
10
  st.title('GPT2:')
11
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
12
 
13
  @st.cache(allow_output_mutation=True)
14
  def get_model():
@@ -63,11 +64,11 @@ with st.form(key='my_form'):
63
  text = tokenizer.encode(prompt)
64
  myinput, past_key_values = torch.tensor([text]), None
65
  myinput = myinput
66
- myinput= myinput.to(device)
67
  logits, past_key_values = model(myinput, past_key_values = past_key_values, return_dict=False)
68
  logits = logits[0,-1]
69
  probabilities = torch.nn.functional.softmax(logits)
70
- best_logits, best_indices = logits.topk(350)
71
  best_words = [tokenizer.decode([idx.item()]) for idx in best_indices]
72
  text.append(best_indices[0].item())
73
  best_probabilities = probabilities[best_indices].tolist()
 
8
  from transformers import AutoTokenizer, AutoModelWithLMHead, AutoModelForCausalLM
9
 
10
  st.title('GPT2:')
11
+ #device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
+ number_of_outputs = st.sidebar.slider("Number of Outputs", 50, 350)
13
 
14
  @st.cache(allow_output_mutation=True)
15
  def get_model():
 
64
  text = tokenizer.encode(prompt)
65
  myinput, past_key_values = torch.tensor([text]), None
66
  myinput = myinput
67
+ #myinput= myinput
68
  logits, past_key_values = model(myinput, past_key_values = past_key_values, return_dict=False)
69
  logits = logits[0,-1]
70
  probabilities = torch.nn.functional.softmax(logits)
71
+ best_logits, best_indices = logits.topk(number_of_outputs)
72
  best_words = [tokenizer.decode([idx.item()]) for idx in best_indices]
73
  text.append(best_indices[0].item())
74
  best_probabilities = probabilities[best_indices].tolist()