Spaces:
Runtime error
Runtime error
change device
Browse files
app.py
CHANGED
@@ -18,7 +18,7 @@ def parse_args():
|
|
18 |
return parser.parse_args()
|
19 |
|
20 |
def predict(message, history, system_prompt, temperature, max_tokens):
|
21 |
-
global model, tokenizer
|
22 |
instruction = "<|im_start|>system\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n<|im_end|>\n"
|
23 |
for human, assistant in history:
|
24 |
instruction += '<|im_start|>user\n' + human + '\n<|im_end|>\n<|im_start|>assistant\n' + assistant
|
@@ -33,8 +33,8 @@ def predict(message, history, system_prompt, temperature, max_tokens):
|
|
33 |
if input_ids.shape[1] > MAX_MAX_NEW_TOKENS:
|
34 |
input_ids = input_ids[:, -MAX_MAX_NEW_TOKENS:]
|
35 |
|
36 |
-
input_ids = input_ids.
|
37 |
-
attention_mask = attention_mask.
|
38 |
generate_kwargs = dict(
|
39 |
{"input_ids": input_ids, "attention_mask": attention_mask},
|
40 |
streamer=streamer,
|
@@ -59,7 +59,8 @@ if __name__ == "__main__":
|
|
59 |
args = parse_args()
|
60 |
tokenizer = AutoTokenizer.from_pretrained("stabilityai/stable-code-instruct-3b")
|
61 |
model = AutoModelForCausalLM.from_pretrained("stabilityai/stable-code-instruct-3b")
|
62 |
-
|
|
|
63 |
gr.ChatInterface(
|
64 |
predict,
|
65 |
title="Stable Code Instruct Chat - Demo",
|
|
|
18 |
return parser.parse_args()
|
19 |
|
20 |
def predict(message, history, system_prompt, temperature, max_tokens):
|
21 |
+
global model, tokenizer, device
|
22 |
instruction = "<|im_start|>system\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n<|im_end|>\n"
|
23 |
for human, assistant in history:
|
24 |
instruction += '<|im_start|>user\n' + human + '\n<|im_end|>\n<|im_start|>assistant\n' + assistant
|
|
|
33 |
if input_ids.shape[1] > MAX_MAX_NEW_TOKENS:
|
34 |
input_ids = input_ids[:, -MAX_MAX_NEW_TOKENS:]
|
35 |
|
36 |
+
input_ids = input_ids.to(device)
|
37 |
+
attention_mask = attention_mask.to(device)
|
38 |
generate_kwargs = dict(
|
39 |
{"input_ids": input_ids, "attention_mask": attention_mask},
|
40 |
streamer=streamer,
|
|
|
59 |
args = parse_args()
|
60 |
tokenizer = AutoTokenizer.from_pretrained("stabilityai/stable-code-instruct-3b")
|
61 |
model = AutoModelForCausalLM.from_pretrained("stabilityai/stable-code-instruct-3b")
|
62 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
63 |
+
model = model.to(device)
|
64 |
gr.ChatInterface(
|
65 |
predict,
|
66 |
title="Stable Code Instruct Chat - Demo",
|