import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

model_path = "cognitivecomputations/Quiet-STaR-Base"

# Quiet-STaR ships custom modeling code, so trust_remote_code=True is required.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)

tokenizer = AutoTokenizer.from_pretrained(model_path)

# Stream generated text to stdout as it is produced, omitting the prompt
# and any special tokens.
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

prompt = "Hello my name is"

# Tokenize the prompt and move the ids to the model's device; with
# device_map="auto" this is safer than calling .cuda() directly.
tokens = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

generation_output = model.generate(
    tokens,
    streamer=streamer,
    max_new_tokens=512,
)
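
# The streamer already prints tokens as they arrive, but generate() also
# returns the full sequence of token ids. A minimal sketch of recovering
# the complete text (prompt included) from that return value:
output_text = tokenizer.decode(generation_output[0], skip_special_tokens=True)
print(output_text)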