da03 committed on
Commit
3e9942b
1 Parent(s): f7dc2d2
Files changed (1)
  1. app.py +5 -3
app.py CHANGED
@@ -35,11 +35,13 @@ def predict_product(num1, num2):
     valid_input = False
 
     eos_token_id = tokenizer.eos_token_id
+    past_key_values = None
     for _ in range(100):  # Set a maximum limit to prevent infinite loops
-        outputs = model.generate(generated_ids, max_new_tokens=1, do_sample=False)
-        generated_ids = torch.cat((generated_ids, outputs[:, -1:]), dim=-1)
+        outputs = model.generate(generated_ids, max_new_tokens=1, do_sample=False, past_key_values=past_key_values, return_dict_in_generate=True, use_cache=True)
+        generated_ids = torch.cat((generated_ids, outputs.sequences[:, -1:]), dim=-1)
+        past_key_values = outputs.past_key_values
 
-        if outputs[0, -1].item() == eos_token_id:
+        if outputs.sequences[0, -1].item() == eos_token_id:
             break
 
     output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
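Note: the change reuses the key/value cache across the one-token generate calls, so each iteration no longer re-encodes the entire prefix. Below is a minimal sketch of the same caching pattern written against the model's forward pass instead of repeated generate calls; the checkpoint name ("gpt2") and the prompt are placeholder assumptions and are not taken from app.py.

# Sketch only: greedy decoding with a reused KV cache (assumes a Hugging Face
# causal LM; "gpt2" and the prompt are placeholders, not from app.py).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()

generated_ids = tokenizer("12 * 34 =", return_tensors="pt").input_ids
eos_token_id = tokenizer.eos_token_id
past_key_values = None

with torch.no_grad():
    for _ in range(100):  # cap on generated tokens, mirroring the loop in app.py
        # Once a cache exists, only the newest token needs to be fed to the model;
        # the cached keys/values stand in for the rest of the prefix.
        input_ids = generated_ids if past_key_values is None else generated_ids[:, -1:]
        outputs = model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True)
        past_key_values = outputs.past_key_values
        next_id = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)  # greedy, like do_sample=False
        generated_ids = torch.cat((generated_ids, next_id), dim=-1)
        if next_id.item() == eos_token_id:
            break

output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print(output_text)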