# superposed/llama/superposed_generation.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import json
import os
import sys
import time
from pathlib import Path
from typing import List, Optional
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import (
get_model_parallel_rank,
initialize_model_parallel,
model_parallel_is_initialized,
)
from superposed.llama.model import ModelArgs
from superposed.llama.superposed_model import SuperposedTransformer
from superposed.llama.tokenizer import Tokenizer
from superposed.llama.superpose import Superpose
from superposed.llama.utils import *
from superposed.ngrams.ngram_models import make_models
class SuperposedLlama:
@staticmethod
def build(
ckpt_dir: str,
tokenizer_path: str,
max_seq_len: int,
max_batch_size: int,
device = None,
model_parallel_size: Optional[int] = None,
seed: int = 1,
):
if not torch.distributed.is_initialized():
torch.distributed.init_process_group("nccl")
if not model_parallel_is_initialized():
if model_parallel_size is None:
model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
initialize_model_parallel(model_parallel_size)
local_rank = int(os.environ.get("LOCAL_RANK", 0))
        if device is None:
torch.cuda.set_device(local_rank)
device = torch.cuda.current_device()
torch.manual_seed(seed)
if local_rank > 0:
sys.stdout = open(os.devnull, "w")
start_time = time.time()
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
assert model_parallel_size == len(
checkpoints
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
ckpt_path = checkpoints[get_model_parallel_rank()]
checkpoint = torch.load(ckpt_path, map_location="cpu")
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())
model_args: ModelArgs = ModelArgs(
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
**params,
)
tokenizer = Tokenizer(model_path=tokenizer_path)
model_args.vocab_size = tokenizer.n_words
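        # Default to fp16 CUDA tensors so the transformer weights are created in half precision on GPU.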
torch.set_default_tensor_type(torch.cuda.HalfTensor)
# Set up superposed decoding
model = SuperposedTransformer(model_args)
model.load_state_dict(checkpoint, strict=False)
print(f"Loaded in {time.time() - start_time:.2f} seconds")
return SuperposedLlama(model, tokenizer, device)
def __init__(self, model: SuperposedTransformer, tokenizer: Tokenizer, device):
print(device)
self.model = model.to(device).eval()
self.tokenizer = tokenizer
self.device = device
@torch.inference_mode()
def sup_generate(
self,
prompt_tokens: List[List[int]],
smoothing,
max_gen_len: int,
n_token_sample: int,
        alpha: float, # weight on bigram probs
        temp: float,
n_drafts: int = 1, # number of beams
verbose: bool = False,
i_weights = None,
i_length = None,
ngrams = None,
get_time: bool = False,
penalty = 200
):
"""
Run multi-sequence generation using superposed embeddings.
Args:
prompt_tokens (List[List[int]]): Initial tokenized prompts
max_gen_len (int): Maximum numbers of tokens to generate
alpha (float): Alpha value
temp (float): Temperature
n_drafts (int): Number of drafts
verbose (bool): Whether to save intermediate embeddings for analysis
bsz (int): Batch size (default = 16)
i_weights (List[float]): Ngram interpolation weights
i_length (List[int]): Ngram models to interpolate (1 for bigram, 2 for trigram, etc.)
ngrams (Tuple): Ngram models
get_time (bool): Return information on time spent doing Ngram lookup
penalty (float): Penalty on uninterpolated drafts
Returns:
(alive_seq, alive_ppl), (fin_seq, fin_ppl): Tuple of (n_prompts, n_drafts, seqlen),
(n_prompts, n_drafts) for sequences still generating and sequences that have finished.
"""
# Check batch size and prompt lengths
params = self.model.params
bsz = len(prompt_tokens)
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
min_prompt_len = min(len(t) for t in prompt_tokens)
max_prompt_len = max(len(t) for t in prompt_tokens)
prompt_len = min_prompt_len
assert max_prompt_len <= params.max_seq_len
total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)
pad_id = self.tokenizer.pad_id
# Initialize token tensor and pad where necessary
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=self.device)
for k, t in enumerate(prompt_tokens):
tokens[k, :len(t)] = torch.tensor(t, dtype=torch.long, device=self.device)
# If no generation is possible
if min_prompt_len == total_len:
raise RuntimeError("no generation possible")
# Initialize decoding object
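        # Replicate the padded prompts once per draft so every draft starts from the same tokens.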
initial_tokens = tokens.unsqueeze(1).repeat(1, n_drafts, 1)
superpose = Superpose(initial_tokens,
tokenizer=self.tokenizer,
vocab_size=params.vocab_size,
smoothing=smoothing,
alpha=alpha,
i_weights=i_weights,
i_length=i_length,
ngrams=ngrams,
get_time=get_time,
penalty=penalty)
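        # `unseen_first` marks prompts that have not yet generated their first token.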
unseen_first = torch.ones(bsz)
# Superposition matrix
token_weights = torch.zeros(bsz, self.model.vocab_size)
if verbose:
state_list = []
prev_pos = 0
# Begin inference
for cur_pos in range(min_prompt_len, total_len):
input_text_mask = tokens != pad_id
# Take model step
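            # The first decoding step consumes the plain prompt tokens, so no superposed
            # mixture exists yet and token_weights is cleared.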
if cur_pos == min_prompt_len:
token_weights = None
logits = self.model.forward(tokens[:, prev_pos:cur_pos],
start_pos=prev_pos,
token_weights=token_weights,
verbose=verbose)
if verbose:
logits, states = logits
# Softmax
if temp > 0:
probs = torch.softmax(logits[:, -1] / temp, dim=-1)
else:
raise RuntimeError("Temperature must be greater than 0 while mixing")
if verbose:
states["end_probs"] = probs
state_list.append(states)
# Flag prompts on first generation
is_first = torch.mul(tokens[:, cur_pos] == pad_id, unseen_first)
unseen_first[is_first.nonzero(as_tuple=True)[0]] = 0
# Flag prompts not yet generating
still_prompt = input_text_mask[:, cur_pos]
# Superposition pass
token_weights = superpose(probs, still_prompt, is_first, cur_pos, n_token_sample)
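            # token_weights is a (bsz, vocab_size) mixture over the drafts' candidate tokens;
            # the model uses it to form superposed input embeddings on the next forward pass.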
# Do not superpose for prompts not yet generating
keep_idx = input_text_mask[:, cur_pos].ravel().nonzero()
keep_token_weights = torch.zeros_like(token_weights)
keep_token_weights[keep_idx, tokens[keep_idx, cur_pos]] = 1
token_weights = torch.where(input_text_mask[:, cur_pos].unsqueeze(1).expand(-1, self.model.vocab_size),
keep_token_weights, token_weights)
prev_pos = cur_pos
results = superpose.return_results(prompt_len)
        if verbose:
            torch.save(state_list, "../embeddings.pt")
        return results
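
# Minimal usage sketch (illustrative only): the paths, prompt text, and hyperparameter
# values below are placeholders, and distributed/NCCL setup is assumed to be handled by
# the launcher (e.g. torchrun), as in the standard Llama examples.
if __name__ == "__main__":
    llama = SuperposedLlama.build(
        ckpt_dir="path/to/llama-2-7b/",            # placeholder checkpoint directory
        tokenizer_path="path/to/tokenizer.model",  # placeholder tokenizer path
        max_seq_len=512,
        max_batch_size=8,
    )
    prompt_tokens = [llama.tokenizer.encode("The weather today is", bos=True, eos=False)]
    (alive_seq, alive_ppl), (fin_seq, fin_ppl) = llama.sup_generate(
        prompt_tokens=prompt_tokens,
        smoothing=None,      # placeholder; see Superpose for supported values
        max_gen_len=32,
        n_token_sample=9,    # placeholder hyperparameters
        alpha=0.5,
        temp=0.1,
        n_drafts=3,
    )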