Crystalcareai committed
Commit d3c4ad0
1 Parent(s): 866d907

Update inference.py

Files changed (1):
  1. inference.py +14 -107
inference.py CHANGED
@@ -1,121 +1,28 @@
  import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

- def compute_memory_used_pct(device):
-     memory_used = torch.cuda.max_memory_allocated(device) / (1024**3)
-     memory_pct = (
-         memory_used
-         / (torch.cuda.get_device_properties(device).total_memory / (1024**3))
-         * 100
-     )
-     return memory_pct
-
- model_path = "./out"
-
- n_ahead = 8
- n_ahead_talk = 4
- merged_talk_heads = True
-
- # Load the model
  model = AutoModelForCausalLM.from_pretrained(
      model_path,
-     max_thoughts=n_ahead + n_ahead_talk + 1,
-     merged_talk_heads=merged_talk_heads,
-     merged_lm_and_talk_heads=False,
-     merged_lm_and_think_heads=True,
-     use_concat_talk_head=True,
-     use_shallow_think=True,
-     use_shallow_talk=False,
-     use_complex_think_head=False,
-     use_complex_talk_head=True,
-     use_weighted_talk_head=True,
-     trust_remote_code=True,
-     torch_dtype=torch.bfloat16,
      device_map="auto",
  )

- # Load the tokenizer and assign it to the model instance for compatibility
  tokenizer = AutoTokenizer.from_pretrained(model_path)
- model.tokenizer = tokenizer
-
- model.use_end_thought_token = True
- model.use_start_thought_token = True
- model.wandb_enabled = True
- model.n_ahead = n_ahead
- model.n_passes = 2
- model.eval_mode = True
- model.first_run = False
- model.kill_after = 100
- model.rm_initialized = True
- model.original_mode = False
-
- # Custom generate function
- def custom_generate(model, input_ids, attention_mask, max_new_tokens, streamer, **kwargs):
-     with torch.no_grad():
-         finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=input_ids.device)
-         for cur_token_idx in range(max_new_tokens):
-             # Sample the next token
-             new_ids = model(
-                 input_ids[~finished_generating],
-                 attention_mask=attention_mask[~finished_generating]
-             )['logits']
-             # Mask out the start and end thought tokens so we don't accidentally sample them
-             new_ids[:, :, model.tokenizer.vocab_size:] = -float("inf")
-             for list_idx, answer_idx in enumerate((~finished_generating).nonzero(as_tuple=True)[0]):
-                 # Find the index of the last token that is not padding
-                 base_answer_ids = input_ids[answer_idx]
-                 new_answer_ids = new_ids[list_idx]
-                 last_token_idx = (base_answer_ids != model.tokenizer.pad_token_id).nonzero(as_tuple=True)[0].max()

-                 new_ids_sampled = torch.multinomial(
-                     torch.nn.functional.softmax(new_answer_ids[last_token_idx] / kwargs.get("temperature", 1.0), dim=-1), 1)
-                 # Assign the new id to the last token
-                 if last_token_idx + 1 >= len(base_answer_ids):
-                     # Add padding everywhere
-                     new_padding = torch.full((len(input_ids), 1), model.tokenizer.pad_token_id, dtype=torch.long,
-                                              device=input_ids.device)
-                     input_ids = torch.cat([input_ids, new_padding], dim=-1)
-                     attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
-                 attention_mask[answer_idx, last_token_idx + 1] = 1
-                 input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
-                 if new_ids_sampled == model.tokenizer.eos_token_id or new_ids_sampled == model.tokenizer.bos_token_id or new_ids_sampled == model.tokenizer.pad_token_id:
-                     finished_generating[answer_idx] = 1
-                 # Check if the end token is generated
-                 if new_ids_sampled == model.tokenizer.convert_tokens_to_ids("<|/assistant|>"):
-                     finished_generating[answer_idx] = 1
-             if finished_generating.all():
-                 break
-             streamer.put(new_ids_sampled)
-     return input_ids, attention_mask

- # Formulate your prompt
- prompt_template = "[INST] {prompt} [/INST]"

- prompt = "You're standing on the surface of the Earth. "\
-     "You walk one mile south, one mile west and one mile north. "\
-     "You end up exactly where you started. Where are you?"
-
- # Convert prompt to tokens
- tokens = tokenizer(prompt_template.format(prompt=prompt), return_tensors='pt').input_ids.to(model.device)
-
- # Generate an attention mask
- attention_mask = torch.where(tokens != tokenizer.pad_token_id, torch.ones_like(tokens), torch.zeros_like(tokens)).to(model.device)
-
- streamer = TextStreamer(tokenizer, skip_prompt=False, skip_special_tokens=True)
-
- # Generate output using the custom generate function
- output_ids, _ = custom_generate(
-     model,
-     input_ids=tokens,
-     attention_mask=attention_mask,
-     max_new_tokens=512,
      streamer=streamer,
-     temperature=0.9,
  )

- generated_text = ""
-
- print() # Print a newline after streaming is complete
-
- # Cleanup if necessary
- torch.cuda.empty_cache()
 
  import torch
+ from transformers import AutoTokenizer, TextStreamer, AutoModelForCausalLM

+ model_path = "cognitivecomputations/Quiet-STaR-Base"
  model = AutoModelForCausalLM.from_pretrained(
      model_path,
      device_map="auto",
+     low_cpu_mem_usage=True,
+     torch_dtype=torch.bfloat16,
+     trust_remote_code=True,
  )

  tokenizer = AutoTokenizer.from_pretrained(model_path)
+ streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

+ prompt = "Hello my name is"

+ tokens = tokenizer(
+     prompt,
+     return_tensors='pt'
+ ).input_ids.cuda()

+ generation_output = model.generate(
+     tokens,
      streamer=streamer,
+     max_new_tokens=512,
  )
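
The rewritten script calls the stock transformers generate() API, which decodes greedily by default, whereas the removed custom_generate loop sampled tokens with temperature 0.9. If sampled decoding is still wanted with the simplified script, the standard sampling arguments can be passed to generate(). This is a minimal sketch, not part of this commit: do_sample and temperature are ordinary generate() keyword arguments, and 0.9 is simply the value the removed custom_generate call used.

# Optional follow-up (not in the commit): approximate the old script's
# temperature-0.9 sampling using the stock generate() API.
generation_output = model.generate(
    tokens,                 # prompt token ids from the script above
    streamer=streamer,      # stream decoded tokens as they are produced
    max_new_tokens=512,
    do_sample=True,         # multinomial sampling instead of greedy decoding
    temperature=0.9,        # same temperature the removed custom_generate call used
)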