Create entropic_cot.py
entropic_cot.py +114 -0
entropic_cot.py
ADDED
@@ -0,0 +1,114 @@
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Dict, Tuple


def cot_decode_speculative(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    messages: List[Dict[str, str]],
    k: int = 3,
    max_new_tokens: int = 512
) -> Tuple[str, float]:
    """
    Generates text using a speculative decoding approach with confidence- and entropy-based metrics.

    This function implements a Chain-of-Thought (CoT) decoding strategy that explores multiple
    candidate first tokens and selects the one whose continuation shows lower attention entropy
    and higher confidence. It applies top-p filtering to the candidates and computes a path score
    from the confidence, the entropy score, and the generation length.

    Args:
        model: The pre-trained language model.
        tokenizer: The corresponding tokenizer.
        messages: A list of dictionaries, where each dictionary represents a message with "role" and "content" keys.
        k: The number of top tokens to consider for speculative decoding.
        max_new_tokens: The maximum number of tokens to generate.

    Returns:
        A tuple containing the generated text and the calculated path score.
    """

    # Format the input based on tokenizer capabilities: use the chat template when one is defined.
    if getattr(tokenizer, "chat_template", None):  # hasattr alone is not enough: the attribute can exist but be None.
        input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    else:  # Fallback for tokenizers without chat template support.
        input_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
        input_text += "\nassistant:"

    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(model.device)  # Keep inputs on the model's device

    # Handle missing pad_token_id, common in some models.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    # Single unpadded prompt: attend to every position (comparing against pad/eos could wrongly mask genuine eos tokens).
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():  # No gradient calculation needed for inference.
        outputs = model(input_ids, attention_mask=attention_mask)  # One forward pass over the prompt for next-token logits
    first_token_probs = F.softmax(outputs.logits[0, -1, :], dim=-1)  # Probability distribution over the first new token
    top_k_probs, top_k_indices = torch.topk(first_token_probs, k)  # Top-k candidate first tokens
    cumulative_probs = torch.cumsum(top_k_probs, dim=0)  # Cumulative probabilities for top-p filtering
    top_p_mask = cumulative_probs <= 0.9  # Keep candidates inside the 0.9 nucleus
    if not torch.any(top_p_mask):  # Ensure at least one candidate is kept
        top_p_mask[0] = True
    top_k_probs, top_k_indices = top_k_probs[top_p_mask], top_k_indices[top_p_mask]  # Filter candidates by top-p

    min_diff, best_idx = float('inf'), None  # Track the candidate with the lowest entropy score
    new_attention_mask = torch.cat(
        [attention_mask, torch.ones((1, 1), dtype=torch.long, device=attention_mask.device)], dim=-1
    )  # Attention mask extended by one position for the appended candidate token

    # Speculative step: evaluate each top-k candidate first token independently (the prompt KV cache is not reused here).
    for idx in top_k_indices:
        new_token = idx.unsqueeze(0).unsqueeze(0)  # Shape (1, 1) so it can be concatenated to the prompt
        new_tokens = torch.cat([input_ids, new_token], dim=-1)  # Prompt plus the candidate first token
        with torch.no_grad():
            output = model.generate(  # Generate a single token to probe the attention pattern
                new_tokens,
                attention_mask=new_attention_mask,
                max_new_tokens=1,
                output_scores=True,
                output_attentions=True,  # Needed for the attention-entropy calculation
                return_dict_in_generate=True,
            )
        last_layer_attn = output.attentions[0][-1]  # Last layer's attention weights: (batch, heads, query_len, key_len)
        attn_probs = last_layer_attn[:, :, -1, :]  # Per-head attention distribution of the candidate position (already normalized)
        entropy = -torch.sum(attn_probs * torch.log2(attn_probs + 1e-12), dim=-1)  # Attention entropy per head
        avg_entropy, avg_varentropy = torch.mean(entropy), torch.var(entropy)  # Mean and variance of the per-head entropies
        diff = avg_entropy * 0.8 + avg_varentropy * 0.2  # Weighted combination of entropy and varentropy

        if diff < min_diff:  # Keep the candidate with the lowest combined score
            min_diff, best_idx = diff, idx

    new_token = best_idx.unsqueeze(0).unsqueeze(0)  # Chosen best first token, shape (1, 1)
    new_tokens = torch.cat([input_ids, new_token], dim=-1)  # Append it to the input sequence

    # Generate the full continuation starting from the chosen first token.
    output = model.generate(
        new_tokens,
        attention_mask=new_attention_mask,
        max_new_tokens=max_new_tokens,
        output_scores=True,
        return_dict_in_generate=True,
    )

    answer_ids = output.sequences[0][len(input_ids[0]):]  # Generated tokens, including the chosen first token
    answer_text = tokenizer.decode(answer_ids, skip_special_tokens=True)  # Decode to text

    sum_confidence = 0
    for step in range(len(output.scores)):
        logits_step = output.scores[step][0]  # Scores for the token generated at this step
        probs_step = F.softmax(logits_step, dim=-1)  # Probability distribution at this step
        top_probs, _ = torch.topk(probs_step, k=2, dim=-1)  # Two most likely tokens
        confidence = top_probs[0] - top_probs[1]  # Confidence = gap between the top-2 probabilities
        sum_confidence += confidence  # Accumulate confidence over all scored steps
    avg_confidence = sum_confidence / len(answer_ids)  # Average confidence per answer token (the chosen first token has no score entry)
    avg_confidence = avg_confidence - 0.2 if avg_confidence >= 0.9 else avg_confidence  # Damp suspiciously high confidence

    # Path score: confidence raised to the entropy score, scaled by relative answer length, then approximately cube-rooted.
    path_score = avg_confidence ** min_diff * (len(answer_ids) / max_new_tokens)
    return answer_text, round(path_score.item() ** 0.33, 4)  # Return generated text and score
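
For context, here is a minimal usage sketch of the added function. It is not part of the commit; the checkpoint name is a placeholder assumption, and any chat-tuned causal LM that fits on the available hardware should work the same way.

# Illustrative usage only -- not part of entropic_cot.py.
# "Qwen/Qwen2.5-1.5B-Instruct" is an arbitrary placeholder checkpoint.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from entropic_cot import cot_decode_speculative

model_name = "Qwen/Qwen2.5-1.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to("cuda")
model.eval()

messages = [
    {"role": "user",
     "content": "A bat and a ball cost $1.10 in total. The bat costs $1.00 more than the ball. "
                "How much does the ball cost?"},
]

answer, score = cot_decode_speculative(model, tokenizer, messages, k=3, max_new_tokens=256)
print(f"path score: {score}")
print(answer)

Higher scores correspond to runs where the model was confident (large gap between the top-2 token probabilities at each step), the chosen first token had low attention entropy, and the answer used a larger fraction of the token budget.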