sethuiyer committed
Commit a1056d4
1 Parent(s): 89ef7e1

Update entropic_cot.py

Files changed (1)
  1. entropic_cot.py +241 -110
entropic_cot.py CHANGED
@@ -1,114 +1,245 @@
  import torch
  import torch.nn.functional as F
  from transformers import AutoTokenizer, AutoModelForCausalLM
- from typing import List, Dict, Tuple
-
- def cot_decode_speculative(
-     model: AutoModelForCausalLM,
-     tokenizer: AutoTokenizer,
-     messages: List[Dict[str, str]],
-     k: int = 3,
-     max_new_tokens: int = 512
- ) -> Tuple[str, float]:
-     """
-     Generates text using a speculative decoding approach with confidence and entropy-based metrics.
-
-     This function implements a Chain-of-Thought (CoT) decoding strategy that explores multiple
-     potential next tokens and selects the one leading to a generation with lower entropy and
-     higher confidence. It incorporates top-p sampling and calculates a path score based on confidence,
-     entropy difference, and generation length.
-
-     Args:
-         model: The pre-trained language model.
-         tokenizer: The corresponding tokenizer.
-         messages: A list of dictionaries, where each dictionary represents a message with "role" and "content" keys.
-         k: The number of top tokens to consider for speculative decoding.
-         max_new_tokens: The maximum number of tokens to generate.
-
-     Returns:
-         A tuple containing the generated text and the calculated path score.
-     """
-
-     # Format the input based on tokenizer capabilities. Handles both chat template and standard formats.
-     if hasattr(tokenizer, 'chat_template'):  # Efficiently uses chat template if available.
-         input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-     else:  # Fallback for tokenizers without chat template support.
-         input_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
-         input_text += "\nassistant:"
-
-     input_ids = tokenizer.encode(input_text, return_tensors="pt").to("cuda")  # GPU usage specified
-
-     # Handle missing pad_token_id, common in some models.
-     if tokenizer.pad_token_id is None:
-         tokenizer.pad_token_id = tokenizer.eos_token_id
-     attention_mask = (input_ids != tokenizer.pad_token_id).long().to("cuda")
-
-
-     with torch.no_grad():  # No gradient calculation needed for inference.
-         outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)  # Use caching for efficiency
-         past_key_values = outputs.past_key_values  # Store past key/values for faster generation
-         first_token_logits = F.softmax(outputs.logits[0, -1, :], dim=-1)  # Probabilities of the next token
-         top_k_logits, top_k_indices = torch.topk(first_token_logits, k)  # Get top k logits and indices
-         cumulative_probs = torch.cumsum(top_k_logits, dim=0)  # Calculate cumulative probabilities for top-p
-         top_p_mask = cumulative_probs <= 0.9  # Apply top-p filtering (nucleus sampling)
-         if not torch.any(top_p_mask):  # Ensure at least one token is selected
-             top_p_mask[0] = True
-         top_k_logits, top_k_indices = top_k_logits[top_p_mask], top_k_indices[top_p_mask]  # Filter based on top-p
-
-
-     min_diff, best_idx, past_key_values = float('inf'), None, None  # Initialize variables for best token selection
-     new_attention_mask = torch.cat([attention_mask, torch.ones(1, 1).long().to("cuda")], dim=-1)  # Extend attention mask
-
-
-     # Speculative decoding: Evaluate top-k tokens
-     for idx in top_k_indices:  # Iterate through top-k candidate tokens.
-         new_token = idx.unsqueeze(0).unsqueeze(0)  # Prepare the token for generation
-         new_tokens = torch.cat([input_ids, new_token], dim=-1)  # Add the token to the input sequence
-         with torch.no_grad():
-             output = model.generate(  # Generate one token to evaluate entropy
-                 new_tokens,
-                 attention_mask=new_attention_mask,
-                 max_new_tokens=1,
-                 output_scores=True,
-                 output_attentions=True,  # Needed for attention entropy calculation
-                 return_dict_in_generate=True,
-                 past_key_values=past_key_values
              )
-         all_attentions = output.attentions[0][-1]  # Extract last layer's attention weights
-         attn_probs = F.softmax(all_attentions[:, -1, :], dim=-1)  # Calculate attention probabilities
-         entropy = -torch.sum(attn_probs * torch.log2(attn_probs + 1e-12), dim=-1)  # Calculate entropy
-         avg_entropy, avg_varentropy = torch.mean(entropy), torch.var(entropy)  # Compute mean and variance of entropy
-         diff = avg_entropy * 0.8 + avg_varentropy * 0.2  # Combine entropy metrics (weighted average)
-
-         if diff < min_diff:  # Select token with lowest entropy difference
-             min_diff, best_idx = diff, idx
-
-     new_token = best_idx.unsqueeze(0).unsqueeze(0)  # Prepare the chosen best token.
-     new_tokens = torch.cat([input_ids, new_token], dim=-1)  # Append the token to the input sequence.
-
-     # Generate the full sequence with the chosen best first token.
-     output = model.generate(
-         new_tokens,
-         attention_mask=new_attention_mask,
-         max_new_tokens=max_new_tokens,
-         output_scores=True,
-         return_dict_in_generate=True,
-         past_key_values=past_key_values
-     )
-
-     answer_ids = output.sequences[0][len(input_ids[0]):]  # Extract generated tokens
-     answer_text = tokenizer.decode(answer_ids, skip_special_tokens=True)  # Decode to text
-
-     sum_confidence = 0
-     for step in range(len(output.scores)):
-         logits_step = output.scores[step][0]  # Logits for current step
-         probs_step = F.softmax(logits_step, dim=-1)  # Probabilities for current step
-         top_probs, _ = torch.topk(probs_step, k=2, dim=-1)  # Get top 2 probabilities
-         confidence = top_probs[0] - top_probs[1]  # Calculate confidence as difference between top 2
-         sum_confidence += confidence  # Accumulate confidence over all steps
-     avg_confidence = sum_confidence / len(answer_ids)  # Average confidence per token
-     avg_confidence = avg_confidence - 0.2 if avg_confidence >= 0.9 else avg_confidence  # Adjust confidence if too high
-
-
-     path_score = avg_confidence ** (min_diff) * (len(answer_ids) / max_new_tokens)  # Calculate path score
-     return answer_text, round(path_score.item() ** 0.33, 4)  # Return generated text and score
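
For reference, a minimal usage sketch of the removed function (an illustration, not part of the commit; the model name is an arbitrary assumption, and a CUDA device is required because the function hard-codes .to("cuda")):

from transformers import AutoModelForCausalLM, AutoTokenizer
from entropic_cot import cot_decode_speculative

model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # hypothetical model choice
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to("cuda")

messages = [{"role": "user", "content": "What is 17 * 24? Think step by step."}]
answer, path_score = cot_decode_speculative(model, tokenizer, messages, k=3, max_new_tokens=256)
print(path_score, answer)
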
  import torch
  import torch.nn.functional as F
  from transformers import AutoTokenizer, AutoModelForCausalLM
+ from typing import List, Dict, Tuple, Optional, NamedTuple
+ from enum import Enum, auto
+ from dataclasses import dataclass
+ import warnings
+ warnings.filterwarnings("ignore", category=FutureWarning)
+
+
+ class DecoderState(Enum):
+     GREEDY_UNTIL_NEWLINE = auto()
+     SELECT_AFTER_NEWLINE = auto()
+     TERMINATED = auto()
+
+ class CacheState(NamedTuple):
+     past_key_values: Tuple
+     last_position: int
+
+ @dataclass
+ class GenerationState:
+     tokens: torch.Tensor
+     attention_mask: torch.Tensor
+     cache_state: Optional[CacheState] = None
+     entropy_diffs: List[float] = None
+     current_length: int = 0
+     _token_buffer: Optional[torch.Tensor] = None
+     _attn_buffer: Optional[torch.Tensor] = None
+
+     def __post_init__(self):
+         self.entropy_diffs = []
+         # Pre-allocate buffers for token and attention mask growth
+         max_length = self.tokens.size(1) + 1024  # reasonable buffer size
+         self._token_buffer = torch.zeros(
+             (1, max_length),
+             dtype=self.tokens.dtype,
+             device=self.tokens.device
+         )
+         self._attn_buffer = torch.ones(
+             (1, max_length),
+             dtype=self.attention_mask.dtype,
+             device=self.attention_mask.device
+         )
+         # Copy initial tokens and attention mask
+         self._token_buffer[:, :self.tokens.size(1)] = self.tokens
+         self._attn_buffer[:, :self.attention_mask.size(1)] = self.attention_mask
+
+     def extend(self, new_token: torch.Tensor):
+         """Efficient in-place extension of state"""
+         current_len = self.tokens.size(1)
+         if len(new_token.shape) == 0:
+             new_token = new_token.unsqueeze(0)
+
+         # Use pre-allocated buffers
+         self._token_buffer[:, current_len] = new_token
+         self.tokens = self._token_buffer[:, :current_len + 1]
+         self.attention_mask = self._attn_buffer[:, :current_len + 1]
+         self.current_length += 1
+
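# Illustration, not part of the commit: GenerationState grows its token and attention
# buffers in place rather than concatenating tensors every step. A minimal sketch of
# that behaviour, assuming a tiny dummy prompt of three token ids:
#     state = GenerationState(tokens=torch.tensor([[1, 2, 3]]),
#                             attention_mask=torch.ones(1, 3, dtype=torch.long))
#     state.extend(torch.tensor(4))            # a scalar token is unsqueezed internally
#     assert state.tokens.tolist() == [[1, 2, 3, 4]]
#     assert state.attention_mask.shape == (1, 4)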
+ class SpeculativeDecoder:
+     def __init__(
+         self,
+         model: AutoModelForCausalLM,
+         tokenizer: AutoTokenizer,
+         device: Optional[torch.device] = None,
+         max_new_tokens: int = 512,
+         k: int = 3,
+         use_cache: bool = True
+     ):
+         self.model = model
+         self.tokenizer = tokenizer
+         self.device = device or next(model.parameters()).device
+         self.max_new_tokens = max_new_tokens
+         self.k = k
+         self.use_cache = use_cache
+
+         # Pre-compute constants
+         self.newline_token = tokenizer.encode("\n", add_special_tokens=False)[0]
+         if tokenizer.pad_token_id is None:
+             tokenizer.pad_token_id = tokenizer.eos_token_id
+
+         # Pre-allocate reusable tensors
+         self.batch_attention_mask = torch.ones(k, 1, dtype=torch.long, device=self.device)
+
+         # Prepare model for inference
+         if hasattr(model, 'eval'):
+             model.eval()
+
+         # Enable Flash Attention if available
+         if hasattr(model, 'enable_flash_attention'):
+             try:
+                 model.enable_flash_attention()
+             except Exception as e:
+                 warnings.warn(f"Failed to enable Flash Attention: {e}")
+
+     @staticmethod
+     @torch.jit.script
+     def calculate_entropy(probs: torch.Tensor) -> torch.Tensor:
+         """JIT-compiled entropy calculation"""
+         return -torch.sum(probs * torch.log2(probs + 1e-12), dim=-1)
+
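# Illustration, not part of the commit: calculate_entropy is base-2 Shannon entropy over
# the last dimension, so a uniform 4-way attention distribution gives roughly 2.0 bits
# and a one-hot distribution gives roughly 0.0:
#     probs = torch.tensor([[0.25, 0.25, 0.25, 0.25], [1.0, 0.0, 0.0, 0.0]])
#     SpeculativeDecoder.calculate_entropy(probs)   # approximately tensor([2.0, 0.0])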
+     def set_k(self, k: int):
+         self.k = k
+         self.batch_attention_mask = torch.ones(k, 1, dtype=torch.long, device=self.device)
+
+     def prepare_inputs(self, messages: List[Dict[str, str]]) -> torch.Tensor:
+         """Efficient input preparation"""
+         if hasattr(self.tokenizer, 'chat_template'):
+             input_text = self.tokenizer.apply_chat_template(
+                 messages,
+                 tokenize=False,
+                 add_generation_prompt=True
              )
+         else:
+             input_text = "\n".join(f"{msg['role']}: {msg['content']}" for msg in messages) + "\nassistant:"
+
+         return self.tokenizer(
+             input_text,
+             return_tensors="pt",
+             padding=False
+         ).input_ids.to(self.device)
+
+     def select_least_entropic_token(self, state: GenerationState) -> Tuple[torch.Tensor, float]:
+         """Optimized token selection with vectorized operations"""
+         with torch.no_grad(), torch.cuda.amp.autocast(enabled=True):
+             # Initial logits computation
+             outputs = self.model(
+                 input_ids=state.tokens[:, -1:] if state.cache_state else state.tokens,
+                 attention_mask=state.attention_mask,
+                 past_key_values=state.cache_state.past_key_values if state.cache_state else None,
+                 use_cache=True
+             )
+
+             state.cache_state = CacheState(outputs.past_key_values, state.tokens.size(1)) if self.use_cache else None
+
+             # Efficient top-k selection
+             logits = outputs.logits[0, -1]
+             top_k_probs, top_k_indices = torch.topk(F.softmax(logits, dim=-1), self.k)
+
+             # Prepare batch inputs
+             batch_tokens = top_k_indices.unsqueeze(1)
+
+             # Efficient cache expansion
+             if state.cache_state:
+                 batch_past_kv = [
+                     (
+                         layer_past[0].expand(self.k, -1, -1, -1),
+                         layer_past[1].expand(self.k, -1, -1, -1)
+                     )
+                     for layer_past in state.cache_state.past_key_values
+                 ]
+             else:
+                 batch_past_kv = None
+
+             # Single forward pass for all candidates
+             batch_outputs = self.model(
+                 input_ids=batch_tokens,
+                 attention_mask=self.batch_attention_mask,
+                 past_key_values=batch_past_kv,
+                 use_cache=True,
+                 output_attentions=True
+             )
+
+             # Efficient attention processing
+             middle_layer = len(batch_outputs.attentions) // 2
+             batch_attn_probs = F.softmax(
+                 batch_outputs.attentions[middle_layer][:, :, -1, :],
+                 dim=-1
+             )
+
+             # Vectorized entropy calculation
+             old_entropy = self.calculate_entropy(batch_attn_probs[:, :, :-1])
+             new_entropy = self.calculate_entropy(batch_attn_probs)
+
+             # Efficient difference calculation
+             entropy_var = torch.var(
+                 torch.stack([old_entropy, new_entropy]),
+                 dim=0,
+                 keepdim=True
+             ) + 1e-6
+             diffs = ((new_entropy - old_entropy) / entropy_var).mean(dim=-1).squeeze(0)
+             min_idx = diffs.argmin()
+             return top_k_indices[min_idx].unsqueeze(0), diffs[min_idx].item()
+
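# Illustration, not part of the commit: the ranking above prefers the candidate whose
# extra step lowers attention entropy the most, normalised by the variance of the
# old/new entropy pair. A simplified per-candidate sketch (ignoring the per-head mean),
# with two hypothetical candidates:
#     old = torch.tensor([2.0, 2.0])    # entropy before appending each candidate
#     new = torch.tensor([1.2, 2.1])    # entropy after appending each candidate
#     var = torch.var(torch.stack([old, new]), dim=0) + 1e-6
#     ((new - old) / var).argmin()      # tensor(0): the entropy-reducing candidate wins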
+     def greedy_decode(self, state: GenerationState) -> torch.Tensor:
+         """Optimized greedy decoding"""
+         with torch.no_grad(), torch.cuda.amp.autocast(enabled=True):
+             outputs = self.model(
+                 input_ids=state.tokens[:, -1:] if state.cache_state else state.tokens,
+                 attention_mask=state.attention_mask,
+                 past_key_values=state.cache_state.past_key_values if state.cache_state else None,
+                 use_cache=True
+             )
+
+             state.cache_state = CacheState(
+                 outputs.past_key_values,
+                 state.tokens.size(1)
+             ) if self.use_cache else None
+
+             return outputs.logits[0, -1].argmax()
+
+     def __call__(
+         self,
+         messages: List[Dict[str, str]]
+     ) -> Tuple[str, float]:
+         """Main decoding loop with optimized state transitions"""
+         input_ids = self.prepare_inputs(messages)
+
+         state = GenerationState(
+             tokens=input_ids,
+             attention_mask=torch.ones_like(input_ids)
+         )
+
+         current_state = DecoderState.SELECT_AFTER_NEWLINE
+
+         while current_state != DecoderState.TERMINATED and state.current_length < self.max_new_tokens:
+             if current_state == DecoderState.SELECT_AFTER_NEWLINE:
+                 next_token, entropy_diff = self.select_least_entropic_token(state)
+                 state.entropy_diffs.append(entropy_diff)
+                 current_state = DecoderState.GREEDY_UNTIL_NEWLINE
+
+             else:  # GREEDY_UNTIL_NEWLINE
+                 next_token = self.greedy_decode(state)
+
+                 if next_token.item() == self.tokenizer.eos_token_id:
+                     current_state = DecoderState.TERMINATED
+                 elif next_token.item() == self.newline_token:
+                     current_state = DecoderState.SELECT_AFTER_NEWLINE
+
+             state.extend(next_token)
+
+         # Efficient post-processing
+         generated_ids = state.tokens[0, len(input_ids[0]):]
+         generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
+
+         # Vectorized score calculation
+         if state.entropy_diffs:
+             avg_entropy_diff = torch.tensor(state.entropy_diffs).mean().item()
+         else:
+             avg_entropy_diff = 1.0
+
+         completion_ratio = len(generated_ids) / self.max_new_tokens
+         score = (1.0 / (avg_entropy_diff/100 + 1e-12)) * completion_ratio
+
+         return generated_text, round(score ** 0.33, 4)
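
For reference, a minimal usage sketch of the new SpeculativeDecoder (an illustration, not part of the commit; the model name is an arbitrary assumption, and a CUDA-capable GPU is assumed since the forward passes run under torch.cuda.amp.autocast):

from transformers import AutoModelForCausalLM, AutoTokenizer
from entropic_cot import SpeculativeDecoder

model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # hypothetical model choice
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to("cuda")

decoder = SpeculativeDecoder(model, tokenizer, max_new_tokens=256, k=3)
messages = [{"role": "user", "content": "Explain why the sky is blue in two sentences."}]
text, score = decoder(messages)
print(score, text)  # the score folds in the average entropy drop and the completion ratio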