Ethan Shen
Initial commit
dda1539
raw
history blame
No virus
2.77 kB
import torch
def log_prob_to_prob(log_probs: torch.Tensor, temp: float = 1) -> torch.Tensor:
    """
    Convert log probabilities to probability distribution and normalize.

    Args:
        log_probs (torch.Tensor): Log probs (n_prompts, n_drafts, vocab_size)
        temp (float): Temperature the shifted log probs are divided by
            before the softmax. Defaults to 1.
    Returns:
        Probability distribution (n_prompts, n_drafts, vocab_size)
    """
    # Stability constant: SUBTRACT the per-row max so the largest entry
    # becomes 0. Softmax is invariant to a per-row shift, so the result is
    # mathematically unchanged, but adding the max (as was done before)
    # doubles the magnitude and can overflow to +/-inf, yielding NaNs.
    row_max = torch.max(log_probs, dim=-1, keepdim=True).values
    shifted = log_probs - row_max
    return torch.softmax(shifted / temp, dim=-1)
def decode(tokenizer, encoding):
    """
    Decode a tensor of token ids to a string, dropping padding.

    Everything from the first -1 entry onwards is treated as padding and
    discarded before decoding.

    Args:
        tokenizer (Any): Tokenizer; its ``decode`` is called with a list of ints
        encoding (torch.Tensor): Encoding. Expected to be 1-D, since the
            first pad position is read with ``.item()``.
    Returns:
        decoding (str)
    """
    # nonzero() returns one row per position where the mask is True.
    pad_locs = (encoding == -1).nonzero()
    # Compare the number of matches, not the match indices themselves
    # (the original `len(pad_locs > 0)` only worked by accident).
    if len(pad_locs) > 0:
        encoding = encoding[:pad_locs[0].item()]
    return tokenizer.decode(encoding.to(torch.int32).tolist())
def print_gen(gens, logprobs, tokenizer, n_drafts, prompt_len, output_file):
    """
    Print out generations for debugging.

    Args:
        gens: Generations to print, shaped either
            (n_prompts, n_drafts, seq_len) or already flattened to
            (n_prompts * n_drafts, seq_len)
        logprobs: Log probs of each generation; flattened internally to
            (n_prompts * n_drafts)
        tokenizer (any): Tokenizer
        n_drafts (int): Number of drafts per prompt. Only consulted when
            gens is 2-D; for 3-D gens the value is taken from gens.shape[1].
        prompt_len (int): Number of tokens in prompt
        output_file: Stream passed to ``print(..., file=output_file)``
    Raises:
        ValueError: If gens is neither 2-D nor 3-D.
    """
    n_dims = len(gens.shape)
    if n_dims == 3:
        # The tensor's own layout is authoritative (original behaviour).
        n_drafts = gens.shape[1]
    elif n_dims != 2:
        raise ValueError(
            f"gens must be 2-D or 3-D, got shape {tuple(gens.shape)}"
        )
    flat_gens = gens.reshape(-1, gens.shape[-1])
    flat_logprobs = logprobs.flatten()
    for i, gen in enumerate(flat_gens):
        decoded = decode(tokenizer, gen)
        # Position of this draft within its prompt's group.
        draft_idx = i % n_drafts
        # first draft of this prompt
        if draft_idx == 0:
            print("---------------", file=output_file)
            prompt = decode(tokenizer, gen[:prompt_len])
            print(f"prompt: {prompt}", file=output_file)
        print(f"logprob: {flat_logprobs[i]} {draft_idx}: {decoded}", file=output_file)
def print_probs(next_probs, tokenizer, output_file):
    """
    Dump the leading next-token candidates of every draft for debugging.

    For each (prompt, draft) pair, the n_drafts + 2 highest-probability
    tokens are written together with their log probs and probs.

    Args:
        next_probs (torch.Tensor): Next token probabilities (n_prompts, n_drafts, vocab_size)
        tokenizer (any): Tokenizer
        output_file: Stream passed to ``print(..., file=output_file)``
    """
    print("\tReminder: At most first n_drafts from seq can be selected.", file=output_file)
    _, n_drafts, _ = next_probs.shape
    # Show a couple more candidates than there are drafts.
    top_k = n_drafts + 2
    for p_idx, prompt_probs in enumerate(next_probs):
        print(f"\tPrompt {p_idx}:", file=output_file)
        for draft_probs in prompt_probs:
            top_probs, top_ids = draft_probs.topk(top_k, dim=-1)
            # Decode each candidate id on its own so tokens stay separate.
            tokens = [tokenizer.decode([tok_id]) for tok_id in top_ids.tolist()]
            print(f"\t\tTokens: {tokens}", file=output_file)
            top_log_probs = torch.log(top_probs)
            print(f"\t\tLog Probs: {top_log_probs}", file=output_file)
            print(f"\t\tProbs: {top_probs}", file=output_file)