Crystalcareai committed
Commit e5bb001
1 Parent(s): d3c4ad0

Create Inference-improved.py

Files changed (1)
  1. Inference-improved.py +108 -0
Inference-improved.py ADDED
@@ -0,0 +1,108 @@
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

model_path = "cognitivecomputations/Quiet-STaR-Base"

n_ahead = 8
n_ahead_talk = 4
merged_talk_heads = True

# Load the model
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    max_thoughts=n_ahead + n_ahead_talk + 1,
    merged_talk_heads=merged_talk_heads,
    merged_lm_and_talk_heads=False,
    merged_lm_and_think_heads=True,
    use_concat_talk_head=True,
    use_shallow_think=True,
    use_shallow_talk=False,
    use_complex_think_head=False,
    use_complex_talk_head=True,
    use_weighted_talk_head=True,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Load the tokenizer and assign it to the model instance for compatibility
tokenizer = AutoTokenizer.from_pretrained(model_path)
model.tokenizer = tokenizer

model.use_end_thought_token = True
model.use_start_thought_token = True
model.wandb_enabled = True
model.n_ahead = n_ahead
model.n_passes = 2
model.eval_mode = True
model.first_run = False
model.kill_after = 100
model.rm_initialized = True
model.original_mode = False

def custom_generate(model, input_ids, attention_mask, max_new_tokens, streamer, **kwargs):
    with torch.no_grad():
        finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=input_ids.device)
        for cur_token_idx in range(max_new_tokens):
            # Sample the next token
            new_ids = model(
                input_ids[~finished_generating],
                attention_mask=attention_mask[~finished_generating]
            )['logits']
            # Mask out the start and end thought tokens so we don't accidentally sample them
            new_ids[:, :, model.tokenizer.vocab_size:] = -float("inf")
            for list_idx, answer_idx in enumerate((~finished_generating).nonzero(as_tuple=True)[0]):
                # Find the index of the last token that is not padding
                base_answer_ids = input_ids[answer_idx]
                new_answer_ids = new_ids[list_idx]
                last_token_idx = (base_answer_ids != model.tokenizer.pad_token_id).nonzero(as_tuple=True)[0].max()

                new_ids_sampled = torch.multinomial(
                    torch.nn.functional.softmax(new_answer_ids[last_token_idx] / kwargs.get("temperature", 1.0), dim=-1), 1)
                # Assign the new id to the last token
                if last_token_idx + 1 >= len(base_answer_ids):
                    # Add padding everywhere
                    new_padding = torch.full((len(input_ids), 1), model.tokenizer.pad_token_id, dtype=torch.long,
                                             device=input_ids.device)
                    input_ids = torch.cat([input_ids, new_padding], dim=-1)
                    attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
                attention_mask[answer_idx, last_token_idx + 1] = 1
                input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
                if new_ids_sampled == model.tokenizer.eos_token_id or new_ids_sampled == model.tokenizer.bos_token_id or new_ids_sampled == model.tokenizer.pad_token_id:
                    finished_generating[answer_idx] = 1
                # Check if the end token is generated
                if new_ids_sampled == model.tokenizer.convert_tokens_to_ids("<|/assistant|>"):
                    finished_generating[answer_idx] = 1
            if finished_generating.all():
                break
            streamer.put(new_ids_sampled)
    return input_ids, attention_mask

prompt = " How would a typical person answer each of the following questions about causation? Frank T., had an ongoing dispute with his neighbor over a stretch of land and one day decided to shoot his neighbor in the body. Frank T. had no experience with guns, his hand slipped on the barrel of the gun, and the shot went wild. Nonetheless, the bullet bounced off a large boulder several feet away and hit the neighbor's body, causing significant injury. Did Frank T. intentionally shoot his neighbor in the body?"

# The generation loop compares tokens against pad_token_id, so make sure one is set
# (fall back to the eos token if this tokenizer does not define a pad token)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Convert the prompt to tokens
tokens = tokenizer(prompt, return_tensors='pt').input_ids.to(model.device)

# Generate an attention mask
attention_mask = torch.where(tokens != tokenizer.pad_token_id, torch.ones_like(tokens), torch.zeros_like(tokens)).to(model.device)

streamer = TextStreamer(tokenizer, skip_prompt=False, skip_special_tokens=True)

output_ids, _ = custom_generate(
    model,
    input_ids=tokens,
    attention_mask=attention_mask,
    max_new_tokens=512,
    streamer=streamer,
    temperature=0.9,
)

generated_text = ""

print()
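
A possible follow-up, not part of the committed file: once streaming finishes, the sequences returned by custom_generate can also be decoded into plain text. A minimal sketch, assuming a single prompt and that pad/special tokens should be dropped:

# Decode the full sequence (prompt plus completion) returned by custom_generate
generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(generated_text)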