import torch
from transformers import AutoTokenizer, TextStreamer, AutoModelForCausalLM

model_path = "cognitivecomputations/Quiet-STaR-Base"

# Load the model with its custom modeling code (trust_remote_code=True).
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)

tokenizer = AutoTokenizer.from_pretrained(model_path)

# Stream generated tokens to stdout as they are produced.
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

prompt = "Hello my name is"
tokens = tokenizer(
    prompt,
    return_tensors="pt",
).input_ids.cuda()

generation_output = model.generate(
    tokens,
    streamer=streamer,
    max_new_tokens=512,
)