Update app.py
Browse files
app.py
CHANGED
@@ -11,7 +11,7 @@ from threading import Thread
|
|
11 |
import torch
|
12 |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
13 |
|
14 |
-
|
15 |
DEFAULT_MAX_NEW_TOKENS = 1024
|
16 |
|
17 |
|
@@ -34,8 +34,8 @@ def predict(message, history, system_prompt, temperature, max_tokens):
|
|
34 |
input_ids = enc.input_ids
|
35 |
attention_mask = enc.attention_mask
|
36 |
|
37 |
-
if input_ids.shape[1] >
|
38 |
-
input_ids = input_ids[:, -
|
39 |
|
40 |
input_ids = input_ids.to(device)
|
41 |
attention_mask = attention_mask.to(device)
|
|
|
11 |
import torch
|
12 |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
13 |
|
14 |
+
MAX_LENGTH = 4096
|
15 |
DEFAULT_MAX_NEW_TOKENS = 1024
|
16 |
|
17 |
|
|
|
34 |
input_ids = enc.input_ids
|
35 |
attention_mask = enc.attention_mask
|
36 |
|
37 |
+
if input_ids.shape[1] > MAX_LENGTH:
|
38 |
+
input_ids = input_ids[:, -MAX_LENGTH:]
|
39 |
|
40 |
input_ids = input_ids.to(device)
|
41 |
attention_mask = attention_mask.to(device)
|