# Qwen2.5-7B-Anvita / entropic_cot.py
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Dict, Tuple


def cot_decode_speculative(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    messages: List[Dict[str, str]],
    k: int = 3,
    max_new_tokens: int = 512
) -> Tuple[str, float]:
"""
Generates text using a speculative decoding approach with confidence and entropy-based metrics.
This function implements a Chain-of-Thought (CoT) decoding strategy that explores multiple
potential next tokens and selects the one leading to a generation with lower entropy and
higher confidence. It incorporates top-p sampling and calculates a path score based on confidence,
entropy difference, and generation length.
Args:
model: The pre-trained language model.
tokenizer: The corresponding tokenizer.
messages: A list of dictionaries, where each dictionary represents a message with "role" and "content" keys.
k: The number of top tokens to consider for speculative decoding.
max_new_tokens: The maximum number of tokens to generate.
Returns:
A tuple containing the generated text and the calculated path score.
"""
    # Format the input. Prefer the tokenizer's chat template; otherwise fall back to a plain transcript.
    if getattr(tokenizer, 'chat_template', None):  # Use the chat template when one is defined.
        input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    else:  # Fallback for tokenizers without chat template support.
        input_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
        input_text += "\nassistant:"
    device = model.device
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)  # Move the prompt to the model's device
    # Handle a missing pad_token_id, common in some models.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    # Single unpadded sequence: attend to every position. (Comparing token ids against the
    # pad/eos id would wrongly mask chat-template end-of-turn tokens that reuse the eos id.)
    attention_mask = torch.ones_like(input_ids)
    with torch.no_grad():  # No gradient tracking needed for inference.
        # Forward pass over the prompt to get the distribution for the first generated token.
        # (The prompt's KV cache is not reused below; each candidate generate() call
        # recomputes the prefix.)
        outputs = model(input_ids, attention_mask=attention_mask)
    first_token_probs = F.softmax(outputs.logits[0, -1, :], dim=-1)  # Probabilities of the next token
    top_k_probs, top_k_indices = torch.topk(first_token_probs, k)  # Top-k candidate first tokens
    # Top-p (nucleus) filtering over the top-k candidates.
    cumulative_probs = torch.cumsum(top_k_probs, dim=0)
    top_p_mask = cumulative_probs <= 0.9
    if not torch.any(top_p_mask):  # Ensure at least one candidate survives
        top_p_mask[0] = True
    top_k_probs, top_k_indices = top_k_probs[top_p_mask], top_k_indices[top_p_mask]
    min_diff, best_idx = float('inf'), None  # Track the best candidate and its entropy objective
    new_attention_mask = torch.cat([attention_mask, torch.ones(1, 1, dtype=torch.long, device=device)], dim=-1)  # Extend mask by one position
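    # Candidate evaluation heuristic (see the loop below): for each surviving candidate first
    # token, run a single generation step and inspect the last layer's attention. Low entropy
    # (focused attention) with low varentropy (agreement across heads) is treated as a sign of
    # a confident decoding path; candidates are ranked by 0.8 * mean_entropy + 0.2 * varentropy
    # and the minimizer is committed to.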
    # Speculative decoding: evaluate each top-k candidate first token.
    for idx in top_k_indices:
        new_token = idx.unsqueeze(0).unsqueeze(0)  # Shape (1, 1) candidate token
        new_tokens = torch.cat([input_ids, new_token], dim=-1)  # Prompt + candidate token
        with torch.no_grad():
            output = model.generate(  # Generate a single step to inspect its attention pattern
                new_tokens,
                attention_mask=new_attention_mask,
                max_new_tokens=1,
                output_scores=True,
                output_attentions=True,  # Needed for the attention entropy calculation
                return_dict_in_generate=True,
            )
        # Last layer's attention for the single generated step: (batch, heads, query, key).
        all_attentions = output.attentions[0][-1]
        # Attention weights are already softmax-normalized, so take the row for the final
        # query position across all heads and measure its entropy over the key dimension.
        attn_probs = all_attentions[:, :, -1, :]
        entropy = -torch.sum(attn_probs * torch.log2(attn_probs + 1e-12), dim=-1)  # Entropy per head
        avg_entropy, avg_varentropy = torch.mean(entropy), torch.var(entropy)  # Mean and variance across heads
        diff = avg_entropy * 0.8 + avg_varentropy * 0.2  # Weighted entropy/varentropy objective
        if diff < min_diff:  # Keep the candidate with the lowest objective
            min_diff, best_idx = diff, idx
    new_token = best_idx.unsqueeze(0).unsqueeze(0)  # Prepare the chosen best token
    new_tokens = torch.cat([input_ids, new_token], dim=-1)  # Append it to the prompt

    # Generate the full sequence starting from the chosen best first token.
    output = model.generate(
        new_tokens,
        attention_mask=new_attention_mask,
        max_new_tokens=max_new_tokens,
        output_scores=True,
        return_dict_in_generate=True,
    )
    answer_ids = output.sequences[0][len(input_ids[0]):]  # Generated tokens, including the chosen first token
    answer_text = tokenizer.decode(answer_ids, skip_special_tokens=True)  # Decode to text
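    # Path scoring: per-step confidence is the gap between the top-2 token probabilities,
    # averaged over len(answer_ids) (which also counts the speculatively chosen first token,
    # for which output.scores has no entry). The average is raised to the power of the chosen
    # candidate's entropy objective (min_diff), scaled by the fraction of the token budget
    # used, and compressed with a cube root before being returned.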
    sum_confidence = 0
    for step in range(len(output.scores)):
        logits_step = output.scores[step][0]  # Logits for the current step
        probs_step = F.softmax(logits_step, dim=-1)  # Probabilities for the current step
        top_probs, _ = torch.topk(probs_step, k=2, dim=-1)  # Top-2 probabilities
        confidence = top_probs[0] - top_probs[1]  # Confidence: gap between the top-2 probabilities
        sum_confidence += confidence  # Accumulate over all steps
    avg_confidence = sum_confidence / len(answer_ids)  # Average confidence per generated token
    avg_confidence = avg_confidence - 0.2 if avg_confidence >= 0.9 else avg_confidence  # Damp near-saturated confidence
    path_score = avg_confidence ** (min_diff) * (len(answer_ids) / max_new_tokens)  # Combine confidence, entropy objective, and length
    return answer_text, round(path_score.item() ** 0.33, 4)  # Return generated text and score
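

# Hedged usage sketch: a minimal example of calling cot_decode_speculative. The repository id
# below is an assumption inferred from the page header (user "sethuiyer", model
# "Qwen2.5-7B-Anvita"); substitute any chat-capable causal LM. Loading in bfloat16 on a GPU is
# likewise an assumption about the deployment environment.
if __name__ == "__main__":
    model_id = "sethuiyer/Qwen2.5-7B-Anvita"  # assumed Hugging Face repo id
    device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
    ).to(device)
    model.eval()

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "A train covers 60 km in 45 minutes. What is its average speed in km/h?"},
    ]
    answer, score = cot_decode_speculative(model, tokenizer, messages, k=3, max_new_tokens=256)
    print(f"Path score: {score}")
    print(answer)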