Spaces:
Running
on
Zero
Running
on
Zero
da03
commited on
Commit
•
0ad2aca
1
Parent(s):
3e9942b
app.py
CHANGED
@@ -37,11 +37,18 @@ def predict_product(num1, num2):
|
|
37 |
eos_token_id = tokenizer.eos_token_id
|
38 |
past_key_values = None
|
39 |
for _ in range(100): # Set a maximum limit to prevent infinite loops
|
40 |
-
outputs = model
|
41 |
-
|
|
|
|
|
|
|
|
|
42 |
past_key_values = outputs.past_key_values
|
43 |
|
44 |
-
|
|
|
|
|
|
|
45 |
break
|
46 |
|
47 |
output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
|
|
37 |
eos_token_id = tokenizer.eos_token_id
|
38 |
past_key_values = None
|
39 |
for _ in range(100): # Set a maximum limit to prevent infinite loops
|
40 |
+
outputs = model(
|
41 |
+
input_ids=generated_ids,
|
42 |
+
past_key_values=past_key_values,
|
43 |
+
use_cache=True
|
44 |
+
)
|
45 |
+
logits = outputs.logits
|
46 |
past_key_values = outputs.past_key_values
|
47 |
|
48 |
+
next_token_id = torch.argmax(logits[:, -1, :], dim=-1)
|
49 |
+
generated_ids = torch.cat((generated_ids, next_token_id.unsqueeze(-1)), dim=-1)
|
50 |
+
|
51 |
+
if next_token_id.item() == eos_token_id:
|
52 |
break
|
53 |
|
54 |
output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|