File size: 657 Bytes
705c81f
d3c4ad0
2f4f471
d3c4ad0
495a5d0
 
 
d3c4ad0
 
 
495a5d0
2f4f471
 
d3c4ad0
495a5d0
d3c4ad0
495a5d0
d3c4ad0
 
 
 
705c81f
d3c4ad0
 
495a5d0
d3c4ad0
495a5d0
0399b0e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
from transformers import AutoTokenizer, TextStreamer, AutoModelForCausalLM

# Hugging Face Hub repository id of the checkpoint to load.
model_path = "cognitivecomputations/Quiet-STaR-Base"

# Load the causal LM in bfloat16 and let `device_map="auto"` decide where the
# weights are placed (GPU(s) when available, otherwise CPU).
# NOTE(security): `trust_remote_code=True` runs Python code shipped in the
# model repository; only keep it enabled for repositories you trust.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)

tokenizer = AutoTokenizer.from_pretrained(model_path)

# Stream decoded text as it is generated, without echoing the prompt or
# printing special tokens.
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

prompt = "Hello my name is"

# Send the prompt ids to the device the model actually lives on instead of
# hard-coding `.cuda()`: with `device_map="auto"` the model may have been
# placed on CPU (no GPU present) or on a non-default GPU, and an
# unconditional `.cuda()` would either raise or cause a device mismatch.
tokens = tokenizer(
    prompt,
    return_tensors="pt",
).input_ids.to(model.device)

# Generate up to 512 new tokens; the streamer emits text incrementally while
# `generation_output` keeps the full returned token sequence.
generation_output = model.generate(
    tokens,
    streamer=streamer,
    max_new_tokens=512,
)