da03 committed on
Commit
0ad2aca
1 Parent(s): 3e9942b
Files changed (1) hide show
  1. app.py +10 -3
app.py CHANGED
@@ -37,11 +37,18 @@ def predict_product(num1, num2):
37
  eos_token_id = tokenizer.eos_token_id
38
  past_key_values = None
39
  for _ in range(100): # Set a maximum limit to prevent infinite loops
40
- outputs = model.generate(generated_ids, max_new_tokens=1, do_sample=False, past_key_values=past_key_values, return_dict_in_generate=True, use_cache=True)
41
- generated_ids = torch.cat((generated_ids, outputs.sequences[:, -1:]), dim=-1)
 
 
 
 
42
  past_key_values = outputs.past_key_values
43
 
44
- if outputs.sequences[0, -1].item() == eos_token_id:
 
 
 
45
  break
46
 
47
  output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
 
37
  eos_token_id = tokenizer.eos_token_id
38
  past_key_values = None
39
  for _ in range(100): # Set a maximum limit to prevent infinite loops
40
+ outputs = model(
41
+ input_ids=generated_ids,
42
+ past_key_values=past_key_values,
43
+ use_cache=True
44
+ )
45
+ logits = outputs.logits
46
  past_key_values = outputs.past_key_values
47
 
48
+ next_token_id = torch.argmax(logits[:, -1, :], dim=-1)
49
+ generated_ids = torch.cat((generated_ids, next_token_id.unsqueeze(-1)), dim=-1)
50
+
51
+ if next_token_id.item() == eos_token_id:
52
  break
53
 
54
  output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)