sethuiyer committed
Commit
05adcf7
1 Parent(s): 8e3f0fb

Create entropic_cot.py

Files changed (1)
  1. entropic_cot.py +114 -0
entropic_cot.py ADDED
@@ -0,0 +1,114 @@
+ import torch
+ import torch.nn.functional as F
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from typing import List, Dict, Tuple
+
+ def cot_decode_speculative(
+     model: AutoModelForCausalLM,
+     tokenizer: AutoTokenizer,
+     messages: List[Dict[str, str]],
+     k: int = 3,
+     max_new_tokens: int = 512
+ ) -> Tuple[str, float]:
+     """
+     Generate text by speculatively evaluating candidate first tokens with
+     confidence- and entropy-based metrics.
+
+     This implements a Chain-of-Thought (CoT) decoding strategy: the top-k candidate
+     first tokens are each probed with a one-token generation, each branch is scored
+     by the entropy (and variance of entropy) of the model's attention, and the
+     lowest-entropy branch is committed before generating the full continuation.
+     Top-p (nucleus) filtering is applied to the candidates, and a path score is
+     computed from per-step confidence, the entropy score, and generation length.
+
+     Args:
+         model: The pre-trained language model.
+         tokenizer: The corresponding tokenizer.
+         messages: A list of dictionaries, each with "role" and "content" keys.
+         k: The number of top first tokens to consider.
+         max_new_tokens: The maximum number of tokens to generate.
+
+     Returns:
+         A tuple of (generated text, path score).
+     """
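+     # Score intuition (computed at the end of this function):
+     #   path_score = (avg_confidence ** min_diff * len(answer) / max_new_tokens) ** (1/3)
+     # A small attention-entropy score (min_diff) pushes the confidence term toward 1,
+     # while the length ratio penalises answers that stop far short of max_new_tokens.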
+
+     # Format the input. Prefer the tokenizer's chat template when one is defined;
+     # hasattr alone is not enough, since the attribute can exist but be None.
+     if getattr(tokenizer, 'chat_template', None):
+         input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+     else:  # Fallback for tokenizers without a chat template.
+         input_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
+         input_text += "\nassistant:"
+
+     device = model.device  # Avoid hard-coding "cuda" so the function also runs on CPU
+     input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
+
+     # Handle missing pad_token_id, common in some models.
+     if tokenizer.pad_token_id is None:
+         tokenizer.pad_token_id = tokenizer.eos_token_id
+     attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)
+
+     with torch.no_grad():  # No gradients needed for inference.
+         outputs = model(input_ids, attention_mask=attention_mask)  # Prefill pass to score the next token
+         # Probabilities (softmax already applied), not raw logits.
+         first_token_probs = F.softmax(outputs.logits[0, -1, :], dim=-1)
+         top_k_probs, top_k_indices = torch.topk(first_token_probs, k)  # Top-k candidate first tokens
+         cumulative_probs = torch.cumsum(top_k_probs, dim=0)  # Cumulative probabilities for top-p
+         top_p_mask = cumulative_probs <= 0.9  # Top-p (nucleus) filtering of the candidates
+         if not torch.any(top_p_mask):  # Ensure at least one candidate survives
+             top_p_mask[0] = True
+         top_k_probs, top_k_indices = top_k_probs[top_p_mask], top_k_indices[top_p_mask]
+
+     # The prefill KV cache is not reused: each generate() call below re-encodes the full sequence.
+     min_diff, best_idx = float('inf'), None  # Track the lowest-entropy candidate
+     new_attention_mask = torch.cat([attention_mask, torch.ones(1, 1, dtype=torch.long, device=device)], dim=-1)  # Extend mask by one position for the appended candidate
+
+     # Speculative step: probe each surviving candidate first token with a one-token
+     # generation and score the branch by the entropy of the model's attention.
+     for idx in top_k_indices:
+         new_token = idx.unsqueeze(0).unsqueeze(0)  # Shape (1, 1) for concatenation
+         new_tokens = torch.cat([input_ids, new_token], dim=-1)  # Prompt + candidate token
+         with torch.no_grad():
+             output = model.generate(
+                 new_tokens,
+                 attention_mask=new_attention_mask,
+                 max_new_tokens=1,
+                 output_scores=True,
+                 output_attentions=True,  # Needed for the attention-entropy score
+                 return_dict_in_generate=True
+             )
+         last_layer_attn = output.attentions[0][-1]  # First step, last layer: (batch, heads, q_len, kv_len)
+         # Attention weights are already softmax-normalized, so no extra softmax is needed;
+         # take each head's distribution over the context at the final query position.
+         attn_probs = last_layer_attn[0, :, -1, :]  # (heads, kv_len)
+         entropy = -torch.sum(attn_probs * torch.log2(attn_probs + 1e-12), dim=-1)  # Entropy per head
+         avg_entropy, varentropy = torch.mean(entropy), torch.var(entropy)  # Mean and variance across heads
+         diff = avg_entropy * 0.8 + varentropy * 0.2  # Weighted entropy/varentropy score
+
+         if diff < min_diff:  # Keep the candidate with the lowest score
+             min_diff, best_idx = diff, idx
+
+     new_token = best_idx.unsqueeze(0).unsqueeze(0)  # The winning first token
+     new_tokens = torch.cat([input_ids, new_token], dim=-1)  # Commit it to the sequence
+
+     # Generate the full continuation from the chosen first token.
+     output = model.generate(
+         new_tokens,
+         attention_mask=new_attention_mask,
+         max_new_tokens=max_new_tokens,
+         output_scores=True,
+         return_dict_in_generate=True
+     )
+
+     answer_ids = output.sequences[0][len(input_ids[0]):]  # Chosen first token + generated continuation
+     answer_text = tokenizer.decode(answer_ids, skip_special_tokens=True)
+
+     # Confidence: the probability gap between the top-2 tokens at each generation step.
+     sum_confidence = 0
+     for step_scores in output.scores:
+         probs_step = F.softmax(step_scores[0], dim=-1)  # Distribution at this step
+         top_probs, _ = torch.topk(probs_step, k=2, dim=-1)  # Two most likely tokens
+         sum_confidence += top_probs[0] - top_probs[1]  # Large gap = confident step
+     # Average over the full answer; answer_ids includes the pre-selected first token,
+     # so the denominator is len(output.scores) + 1.
+     avg_confidence = sum_confidence / len(answer_ids)
+     avg_confidence = avg_confidence - 0.2 if avg_confidence >= 0.9 else avg_confidence  # Heuristic penalty for overconfident paths
+
+     # Path score: confidence raised to the entropy score (a small exponent pushes the
+     # term toward 1), scaled by the fraction of the token budget actually used.
+     path_score = avg_confidence ** min_diff * (len(answer_ids) / max_new_tokens)
+     return answer_text, round(path_score.item() ** 0.33, 4)  # Cube-root compression of the score
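
A minimal usage sketch (not part of the commit): the checkpoint name below is only an example, and any instruction-tuned causal LM with a chat template should work.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from entropic_cot import cot_decode_speculative

device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # example checkpoint only

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device).eval()

messages = [{"role": "user", "content": "A bat and a ball cost $1.10 in total. "
            "The bat costs $1.00 more than the ball. How much does the ball cost?"}]
answer, score = cot_decode_speculative(model, tokenizer, messages, k=3, max_new_tokens=256)
print(f"path score: {score}\n{answer}")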