"""A layer that samples the next tokens from the model's outputs.""" | |
import itertools | |
from typing import Dict, List, Optional, Tuple | |
import torch | |
import torch.nn as nn | |
from vllm.model_executor.layers.ops.sample import sample as sample_triton | |
from vllm.model_executor.sampling_metadata import (SamplingMetadata, | |
SamplingTensors) | |
from vllm.sampling_params import SamplingParams, SamplingType | |
from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs, | |
SamplerOutput, SequenceData, SequenceGroupOutput, | |
SequenceOutput) | |
class Sampler(nn.Module):
    """Samples the next tokens from the model's outputs.

    This layer does the following:
    1. Discard the hidden states that are not used for sampling (i.e., all
       tokens except the final one in each prompt).
    2. Compute the logits for the next tokens.
    3. Apply presence, frequency and repetition penalties.
    4. Apply temperature scaling.
    5. Apply top-p and top-k truncation.
    6. Sample the next tokens.

    Here, each sequence group within the batch can have different sampling
    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).

    The structure of the logits tensor is coupled with the seq_groups in
    sampling_metadata. Typically, each sequence in each seq_group has one row
    in logits for the next token to be sampled; however, for a seq_group with
    a prompt request with the prompt_logprobs sampling parameter, there are
    rows in logits for each token in the input prompt.
    """

    def __init__(self, cfg_scale=1.0):
        super().__init__()
        self.cfg_scale = cfg_scale

        # Whether or not the SamplerOutput should have on-device tensors
        # containing the sampled token ids and probabilities. This is used by
        # speculative decoding.
        self.include_gpu_probs_tensor = False
def forward( | |
self, | |
logits: torch.Tensor, | |
sampling_metadata: SamplingMetadata, | |
) -> Optional[SamplerOutput]: | |
assert logits is not None | |
_, vocab_size = logits.shape | |
if self.cfg_scale > 1.0: | |
logits_combined = logits | |
cond_logits, uncond_logits = torch.split(logits_combined, len(logits_combined) // 2, dim=0) | |
logits = uncond_logits + (cond_logits - uncond_logits) * self.cfg_scale | |
logits = torch.cat([logits, logits], dim=0) | |
# Apply min_tokens penalty which sets stop tokens to -inf if min_tokens | |
# have not been generated yet | |
logits = _apply_min_tokens_penalty(logits, sampling_metadata) | |
# Prepare sampling tensors with pinned memory to avoid blocking. | |
(sampling_tensors, do_penalties, do_top_p_top_k, | |
do_min_p) = SamplingTensors.from_sampling_metadata( | |
sampling_metadata, vocab_size, logits.device, logits.dtype) | |
# Apply presence and frequency penalties. | |
if do_penalties: | |
logits = _apply_penalties(logits, sampling_tensors.prompt_tokens, | |
sampling_tensors.output_tokens, | |
sampling_tensors.presence_penalties, | |
sampling_tensors.frequency_penalties, | |
sampling_tensors.repetition_penalties) | |
# Apply temperature scaling. | |
# Use in-place division to avoid creating a new tensor. | |
logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1)) | |
if do_top_p_top_k: | |
logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps, | |
sampling_tensors.top_ks) | |
if do_min_p: | |
logits = _apply_min_p(logits, sampling_tensors.min_ps) | |
# We use float32 for probabilities and log probabilities. | |
# Compute the probabilities. | |
probs = torch.softmax(logits, dim=-1, dtype=torch.float) | |
# Compute the log probabilities. | |
# Use log_softmax to ensure numerical stability. | |
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) | |
# Sample the next tokens. | |
sample_results, maybe_sampled_tokens_tensor = _sample( | |
probs, | |
logprobs, | |
sampling_metadata, | |
sampling_tensors, | |
include_gpu_probs_tensor=self.include_gpu_probs_tensor, | |
modify_greedy_probs=self._should_modify_greedy_probs_inplace, | |
) | |
if self.cfg_scale > 1.0: | |
cond_result = sample_results[:len(sample_results) // 2] | |
sample_results = cond_result + cond_result | |
if self.include_gpu_probs_tensor: | |
assert maybe_sampled_tokens_tensor is not None | |
sampled_tokens_tensor = maybe_sampled_tokens_tensor | |
on_device_tensors = (probs, sampled_tokens_tensor) | |
else: | |
on_device_tensors = None | |
# Get the logprobs query results. | |
prompt_logprobs, sample_logprobs = _get_logprobs( | |
logprobs, sampling_metadata, sample_results) | |
return _build_sampler_output(sample_results, | |
sampling_metadata, | |
prompt_logprobs, | |
sample_logprobs, | |
on_device_tensors=on_device_tensors) | |
    @property
    def _should_modify_greedy_probs_inplace(self) -> bool:
        """Whether or not the sampler should modify the probability
        distribution of greedily-sampled tokens such that multinomial sampling
        would sample the greedily-sampled token.

        In other words, if True then we set the probability of the greedily-
        sampled token to 1.

        This is used by speculative decoding, which requires that the sampling
        method be encoded into the probability distribution.
        """
        # Modify greedy probs if include_gpu_probs_tensor is set.
        return self.include_gpu_probs_tensor
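

# Hedged sketch (not part of the original module): the classifier-free
# guidance step in Sampler.forward assumes the batch holds the conditional
# rows first and the unconditional rows second. The helper below, whose name
# `_sketch_cfg_blend` is invented for illustration, reproduces just that
# arithmetic on a standalone tensor.
def _sketch_cfg_blend(logits: torch.Tensor, cfg_scale: float) -> torch.Tensor:
    """Blend conditional/unconditional halves as Sampler.forward does."""
    cond_logits, uncond_logits = torch.split(logits, len(logits) // 2, dim=0)
    blended = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
    # Duplicate the blended rows so the batch keeps its original size.
    return torch.cat([blended, blended], dim=0)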


def _get_bin_counts_and_mask(
    tokens: torch.Tensor,
    vocab_size: int,
    num_seqs: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Compute the bin counts for the tokens.
    # vocab_size + 1 for padding.
    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
                             dtype=torch.long,
                             device=tokens.device)
    bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
    bin_counts = bin_counts[:, :vocab_size]
    mask = bin_counts > 0
    return bin_counts, mask
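

# Hedged sketch (illustrative only, not used by vLLM): for two sequences over
# a vocabulary of 4, _get_bin_counts_and_mask counts token occurrences per row
# and flags which vocab entries appeared at all. The extra padding column
# (index == vocab_size) absorbs padding token ids and is dropped before
# returning. The helper name below is invented for this example.
def _sketch_bin_counts_example() -> Tuple[torch.Tensor, torch.Tensor]:
    tokens = torch.tensor([[0, 0, 2], [1, 3, 3]])
    counts, mask = _get_bin_counts_and_mask(tokens, vocab_size=4, num_seqs=2)
    # counts == [[2, 0, 1, 0], [0, 1, 0, 2]]; mask is simply counts > 0.
    return counts, mask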


def _apply_min_tokens_penalty(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    # list of (row, token_id) indices in logits that will be set to -inf
    logits_to_penalize = []
    start_idx = 0
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        seq_ids, sampling_params = seq_group

        # handle prompt_logprobs by skipping rows in logits added for the
        # prompt tokens (prompt logprobs are not penalized)
        if (i < sampling_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            assert len(seq_ids) == 1
            start_idx += sampling_metadata.prompt_lens[i] - 1

        min_tokens = sampling_params.min_tokens
        if min_tokens > 0:
            seqs_to_penalize = []
            for j, seq_id in enumerate(seq_ids):
                seq_data = sampling_metadata.seq_data[seq_id]
                if len(seq_data.output_token_ids) < min_tokens:
                    seqs_to_penalize.append(j)

            if seqs_to_penalize:
                # convert to the index into logits
                seqs_to_penalize = [start_idx + j for j in seqs_to_penalize]
                # use set() to remove any duplicates
                token_ids_to_penalize = set(sampling_params.stop_token_ids +
                                            [sampling_params.eos_token_id])
                # itertools.product pairs each seq index with every token id
                logits_to_penalize.extend(
                    itertools.product(seqs_to_penalize, token_ids_to_penalize))

        start_idx += len(seq_ids)

    if logits_to_penalize:
        # use zip and * to group indices along each dimension
        # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
        logits[tuple(zip(*logits_to_penalize))] = -float("inf")

    # verifies that no rows in logits were missed unexpectedly
    assert start_idx == logits.shape[0]
    return logits
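

# Hedged sketch (illustrative only): the `tuple(zip(*logits_to_penalize))`
# trick above turns a list of (row, column) pairs into the two index sequences
# that PyTorch advanced indexing expects, so every listed position is written
# in a single assignment. The toy helper below, with an invented name, shows
# the same pattern on a small tensor.
def _sketch_pairwise_masking(logits: torch.Tensor) -> torch.Tensor:
    pairs = [(0, 2), (1, 0)]  # (row, token_id) positions to forbid
    logits[tuple(zip(*pairs))] = -float("inf")  # rows (0, 1), columns (2, 0)
    return logits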


def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                     output_tokens_tensor: torch.Tensor,
                     presence_penalties: torch.Tensor,
                     frequency_penalties: torch.Tensor,
                     repetition_penalties: torch.Tensor) -> torch.Tensor:
    num_seqs, vocab_size = logits.shape
    _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size,
                                              num_seqs)
    output_bin_counts, output_mask = _get_bin_counts_and_mask(
        output_tokens_tensor, vocab_size, num_seqs)

    repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
    repetition_penalties[~(prompt_mask | output_mask)] = 1.0
    logits = torch.where(logits > 0, logits / repetition_penalties,
                         logits * repetition_penalties)

    # We follow the definition in OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
    logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
    logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
    return logits
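

# Hedged sketch (illustrative only): a single-sequence walk-through of the
# OpenAI-style penalties applied above, with made-up numbers. Token 0 appears
# in the prompt, token 1 in the output, token 2 nowhere.
def _sketch_penalties_example() -> torch.Tensor:
    logits = torch.tensor([[2.0, 2.0, 2.0]])
    prompt_tokens = torch.tensor([[0]])
    output_tokens = torch.tensor([[1]])
    # Token 0: repetition penalty divides 2.0 -> 1.0.
    # Token 1: 2.0 / 2.0 = 1.0, then minus frequency (0.5 * 1) and presence
    #          (0.1) penalties -> 0.4.
    # Token 2: unseen, so it stays at 2.0.
    return _apply_penalties(logits, prompt_tokens, output_tokens,
                            presence_penalties=torch.tensor([0.1]),
                            frequency_penalties=torch.tensor([0.5]),
                            repetition_penalties=torch.tensor([2.0]))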


def _apply_top_k_top_p(
    logits: torch.Tensor,
    p: torch.Tensor,
    k: torch.Tensor,
) -> torch.Tensor:
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    # Apply top-k.
    top_k_mask = logits_sort.size(1) - k.to(torch.long)
    # Get all the top_k values.
    top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
    top_k_mask = logits_sort < top_k_mask
    logits_sort.masked_fill_(top_k_mask, -float("inf"))

    # Apply top-p.
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = probs_sort.cumsum(dim=-1)
    top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
    # at least one
    top_p_mask[:, -1] = False
    logits_sort.masked_fill_(top_p_mask, -float("inf"))

    # Re-sort the probabilities.
    src = torch.arange(logits_idx.shape[-1],
                       device=logits_idx.device).expand_as(logits_idx)
    logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1,
                                                           index=logits_idx,
                                                           src=src)
    logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv)
    return logits
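

# Hedged sketch (illustrative only): with one row of logits, top_k=2 keeps
# only the two largest logits, and top_p=0.9 additionally drops the
# low-probability tail whose cumulative mass stays below 1 - p. Surviving
# logits are scattered back into their original positions. The helper name is
# invented for this example.
def _sketch_top_k_top_p_example() -> torch.Tensor:
    logits = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])
    p = torch.tensor([0.9])
    k = torch.tensor([2])
    # Expected result (approximately): [-inf, -inf, -inf, 4.0, 5.0]
    return _apply_top_k_top_p(logits, p, k)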


def _apply_min_p(
    logits: torch.Tensor,
    min_p: torch.Tensor,
) -> torch.Tensor:
    """
    Adapted from
    https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17
    """
    probs = torch.softmax(logits, dim=-1)
    top_probs, _ = probs.max(dim=-1, keepdim=True)
    scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs
    tokens_to_remove = probs < scaled_min_p
    logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
    return logits
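

# Hedged sketch (illustrative only): min_p keeps only tokens whose probability
# is at least `min_p` times the probability of the most likely token. Note
# that _apply_min_p mutates both `logits` (masked_fill_) and `min_p`
# (unsqueeze_) in place. The helper name is invented for this example.
def _sketch_min_p_example() -> torch.Tensor:
    logits = torch.log(torch.tensor([[0.6, 0.3, 0.1]]))
    min_p = torch.tensor([0.5])
    # The threshold is 0.5 * 0.6 = 0.3, so only the 0.1 token is removed.
    return _apply_min_p(logits, min_p)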


def _greedy_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    samples: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    samples = samples.tolist()
    sample_idx = 0
    results = []
    for seq_group in selected_seq_groups:
        seq_ids, _ = seq_group
        num_parent_seqs = len(seq_ids)
        assert num_parent_seqs == 1, (
            "Greedy sampling should have only one seq.")
        parent_ids = list(range(num_parent_seqs))
        next_token_ids = [samples[sample_idx]]
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    return results


def _random_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    is_prompts: List[bool],
    random_samples: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    # Find the maximum best_of value of the prompt phase requests.
    random_samples = random_samples.cpu()
    sample_idx = 0
    results = []
    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
        seq_ids, sampling_params = seq_group
        num_parent_seqs = len(seq_ids)
        if is_prompt:
            # Prompt phase.
            parent_ids = [0] * sampling_params.best_of
            next_token_ids = random_samples[
                sample_idx, :sampling_params.best_of].tolist()
        else:
            # Generation phase.
            parent_ids = list(range(num_parent_seqs))
            next_token_ids = random_samples[sample_idx:sample_idx +
                                            num_parent_seqs, 0].tolist()
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    return results


def _beam_search_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    is_prompts: List[bool],
    seq_data: Dict[int, SequenceData],
    logprobs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    # We sample 2 * beam_width candidates to make sure that with high
    # probability we can get `beam_width` candidates in addition to
    # the finished sequences for the next iteration. See
    # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
    # for details. See also HF reference:
    # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
    #
    # NOTE: Beam search is not vectorized, so its speed can be slower than
    # other sampling methods.
    sample_idx = 0
    results = []
    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
        seq_ids, sampling_params = seq_group
        num_parent_seqs = len(seq_ids)
        beam_width = sampling_params.best_of
        seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
        if is_prompt:
            # Prompt phase.
            assert num_parent_seqs == 1, (
                "Prompt input should have only one seq.")
            parent_ids = [0] * (2 * beam_width)
            _, next_token_ids = torch.topk(seq_group_logprobs[0],
                                           2 * beam_width)
            next_token_ids = next_token_ids.tolist()
        else:
            # Generation phase.
            cumulative_logprobs = [
                seq_data[seq_id].cumulative_logprob for seq_id in seq_ids
            ]
            cumulative_logprobs = torch.tensor(
                cumulative_logprobs,
                dtype=torch.float,
                device=seq_group_logprobs.device)
            seq_group_logprobs = (seq_group_logprobs +
                                  cumulative_logprobs.unsqueeze(dim=1))
            _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
                                     2 * beam_width)
            topk_ids = topk_ids.tolist()
            vocab_size = seq_group_logprobs.size(-1)
            parent_ids = [i // vocab_size for i in topk_ids]
            next_token_ids = [i % vocab_size for i in topk_ids]
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    assert sample_idx == logprobs.size(0)
    return results
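

# Hedged sketch (illustrative only): in the generation phase above, topk runs
# over the flattened (num_parent_seqs * vocab_size) scores, so each flat index
# encodes both the parent beam and the token id. The helper below, with an
# invented name, shows the decoding step in isolation.
def _sketch_flat_topk_decode(flat_index: int, vocab_size: int) -> Tuple[int, int]:
    parent_id = flat_index // vocab_size
    next_token_id = flat_index % vocab_size
    return parent_id, next_token_id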


# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead.
# Note that we always sample with replacement.
# probs will be modified in place, but this is fine, as we pass
# in a copy already.
def _multinomial(
    probs: torch.Tensor,
    num_samples: int,
    seq_groups: Optional[List[Tuple[List[int], SamplingParams]]] = None,
    generators: Optional[List[torch.Generator]] = None,
) -> torch.Tensor:
    if num_samples > 1:
        # This is equivalent to torch.repeat_interleaved (which also
        # forces a GPU<->CPU sync).
        # This allows us to do sampling with replacement by creating
        # num_samples copies of each row in the tensor, and then
        # batch sampling the resulting tensor.
        probs = probs[:, None, :].expand(probs.shape[0], num_samples,
                                         probs.shape[1]).contiguous().view(
                                             -1, probs.shape[1])
    q = torch.empty_like(probs)
    if seq_groups is None:
        q.exponential_()
    else:
        sample_idx = 0
        for (seq_ids, _), generator in zip(seq_groups, generators):
            next_sample_idx = sample_idx + len(seq_ids) * num_samples
            q[sample_idx:next_sample_idx].exponential_(generator=generator)
            sample_idx = next_sample_idx
    return probs.div_(q).argmax(dim=1).view(-1, num_samples)
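

# Hedged sketch (illustrative only): _multinomial relies on the "exponential
# race" identity -- if q_i ~ Exp(1) independently, then argmax_i(p_i / q_i)
# is distributed according to the probabilities p_i, which is what
# torch.multinomial would return, minus the GPU<->CPU sync. The helper below
# uses an invented name and a non-in-place division to keep the caller's
# tensor intact.
def _sketch_exponential_race(probs: torch.Tensor) -> torch.Tensor:
    """Draw one categorical sample per row of `probs` (rows sum to 1)."""
    q = torch.empty_like(probs).exponential_()
    return probs.div(q).argmax(dim=-1)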


def _sample_with_torch(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    include_gpu_probs_tensor: bool,
    modify_greedy_probs: bool,
) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
    categorized_seq_group_ids = {t: [] for t in SamplingType}
    categorized_sample_indices = sampling_metadata.categorized_sample_indices
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        _, sampling_params = seq_group
        sampling_type = sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)

    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
    sample_metadata = {}
    multinomial_samples = {}

    # Create output tensor for sampled token ids.
    if include_gpu_probs_tensor:
        sampled_token_ids_tensor = torch.empty(logprobs.shape[0],
                                               1,
                                               dtype=torch.long,
                                               device=logprobs.device)
    else:
        sampled_token_ids_tensor = None

    # Counterintuitively, having two loops here is actually faster.
    # The first loop can run without waiting on GPU<->CPU sync.
    for sampling_type in SamplingType:
        sample_indices = categorized_sample_indices[sampling_type][:, 0]
        num_tokens = len(sample_indices)
        if num_tokens == 0:
            continue

        seq_group_ids = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
        is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
        sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
                                          is_prompts, sample_indices)
        long_sample_indices = sample_indices.long()

        if sampling_type == SamplingType.GREEDY:
            greedy_samples = torch.argmax(logprobs[long_sample_indices],
                                          dim=-1)

            if include_gpu_probs_tensor:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor[
                    long_sample_indices] = greedy_samples.unsqueeze(-1)

            if modify_greedy_probs:
                # If required, modify the probabilities such that sampling from
                # the modified distribution would always sample the argmax
                # token id.
                _modify_greedy_probs_inplace(logprobs, probs,
                                             long_sample_indices,
                                             greedy_samples)

        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            max_best_of_in_batch = 1
            for seq_group, is_prompt in zip(seq_groups, is_prompts):
                if is_prompt:
                    _, sampling_params = seq_group
                    max_best_of_in_batch = max(max_best_of_in_batch,
                                               sampling_params.best_of)
            seeded_args = {} if sampling_type == SamplingType.RANDOM else {
                "seq_groups": seq_groups,
                "generators": sampling_metadata.generators,
            }

            multinomial_samples[sampling_type] = _multinomial(
                probs[long_sample_indices], max_best_of_in_batch,
                **seeded_args)

            if include_gpu_probs_tensor:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor[
                    long_sample_indices] = multinomial_samples[sampling_type]

        elif sampling_type == SamplingType.BEAM:
            beam_search_logprobs = logprobs[sample_indices]
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")

    # GPU<->CPU sync happens in the loop below.
    # This also converts the sample output to Python objects.
    for sampling_type in SamplingType:
        if sampling_type not in sample_metadata:
            continue
        (seq_group_ids, seq_groups, is_prompts,
         sample_indices) = sample_metadata[sampling_type]
        if sampling_type == SamplingType.GREEDY:
            sample_results = _greedy_sample(seq_groups, greedy_samples)
        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            sample_results = _random_sample(seq_groups, is_prompts,
                                            multinomial_samples[sampling_type])
        elif sampling_type == SamplingType.BEAM:
            sample_results = _beam_search_sample(seq_groups, is_prompts,
                                                 sampling_metadata.seq_data,
                                                 beam_search_logprobs)
        sample_results_dict.update(zip(seq_group_ids, sample_results))

    sample_results = [
        sample_results_dict[i]
        for i in range(len(sampling_metadata.seq_groups))
    ]
    return sample_results, sampled_token_ids_tensor


def _sample_with_triton_kernel(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sampling_tensors: SamplingTensors,
) -> List[Tuple[List[int], List[int]]]:
    categorized_seq_group_ids = {t: [] for t in SamplingType}
    categorized_sample_indices = sampling_metadata.categorized_sample_indices
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        _, sampling_params = seq_group
        sampling_type = sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)

    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
    sample_metadata = {}
    max_best_of_in_batch = 1

    # Counterintuitively, having two loops here is actually faster.
    # The first loop can run without waiting on GPU<->CPU sync.
    for sampling_type in SamplingType:
        sample_indices = categorized_sample_indices[sampling_type][:, 0]
        sampled_token_indices = categorized_sample_indices[sampling_type][:, 1]
        num_tokens = len(sample_indices)
        if num_tokens == 0:
            continue

        seq_group_ids = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
        is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
        sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
                                          is_prompts, sample_indices,
                                          sampled_token_indices)
        if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM,
                             SamplingType.RANDOM_SEED):
            for seq_group, is_prompt in zip(seq_groups, is_prompts):
                if is_prompt:
                    _, sampling_params = seq_group
                    max_best_of_in_batch = max(max_best_of_in_batch,
                                               sampling_params.best_of)
        elif sampling_type == SamplingType.BEAM:
            beam_search_logprobs = logprobs[sample_indices]
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")

    sampled_tokens, _, _ = sample_triton(
        probs=probs,
        seeds=sampling_tensors.sampling_seeds,
        max_best_of=max_best_of_in_batch,
        sample_indices=sampling_tensors.sample_indices,
        logprobs=logprobs,
        # don't save logprobs because we have logic for that below
        # TODO: use this instead of the CPU-based logic below
        save_logprobs=False,
    )

    # GPU<->CPU sync happens in the loop below.
    for sampling_type in SamplingType:
        if sampling_type not in sample_metadata:
            continue
        (seq_group_ids, seq_groups, is_prompts, sample_indices,
         sampled_token_indices) = sample_metadata[sampling_type]
        if sampling_type == SamplingType.GREEDY:
            sample_results = _greedy_sample(
                seq_groups, sampled_tokens[sampled_token_indices][:, 0])
        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            sample_results = _random_sample(
                seq_groups, is_prompts, sampled_tokens[sampled_token_indices])
        elif sampling_type == SamplingType.BEAM:
            sample_results = _beam_search_sample(seq_groups, is_prompts,
                                                 sampling_metadata.seq_data,
                                                 beam_search_logprobs)
        sample_results_dict.update(zip(seq_group_ids, sample_results))

    sample_results = [
        sample_results_dict[i]
        for i in range(len(sampling_metadata.seq_groups))
    ]
    return sample_results


def _sample(
    probs: torch.Tensor, logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
    include_gpu_probs_tensor: bool, modify_greedy_probs: bool
) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
    return _sample_with_torch(
        probs,
        logprobs,
        sampling_metadata,
        include_gpu_probs_tensor=include_gpu_probs_tensor,
        modify_greedy_probs=modify_greedy_probs,
    )

    # TODO: Enable once Triton kernel & associated code is faster.
    # return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
    #                                   sampling_tensors)


def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """
    This function calculates the ranks of the chosen tokens in a logprob
    tensor.

    Args:
        x (torch.Tensor): 2D logprob tensor of shape (N, M)
            where N is the no. of tokens and M is the vocab dim.
        indices (torch.Tensor): List of chosen token indices.

    Returns:
        torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
            Each element in the returned tensor represents the rank
            of the chosen token in the input logprob tensor.
    """
    vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
             indices]
    return (x > vals[:, None]).long().sum(1).add_(1)
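

# Hedged sketch (illustrative only): the rank of a chosen token is 1 plus the
# number of tokens with a strictly greater logprob in its row, so the best
# token has rank 1. The helper name below is invented for this example.
def _sketch_get_ranks_example() -> torch.Tensor:
    logprobs = torch.tensor([[-0.1, -2.0, -3.0],
                             [-1.5, -0.2, -0.7]])
    chosen = torch.tensor([2, 1])
    # Row 0: token 2 is the worst of three -> rank 3.
    # Row 1: token 1 is the best -> rank 1.
    # Expected result: tensor([3, 1]).
    return _get_ranks(logprobs, chosen)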


def _get_logprobs(
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sample_results: List[Tuple[List[int], List[int]]],
) -> Tuple[List[Optional[List[Optional[Dict[int, float]]]]], List[List[Dict[
        int, float]]]]:
    # Prepare query indices
    batched_logprobs_query_seq_indices: List[int] = []
    batched_logprobs_query_token_indices: List[int] = []
    # at least get one logprob for each token
    largest_num_logprobs = 1
    sample_idx = 0
    for i, (seq_group, sample_result) in enumerate(
            zip(sampling_metadata.seq_groups, sample_results)):
        seq_ids, sampling_params = seq_group
        next_token_ids, parent_ids = sample_result
        num_parent_seqs = len(seq_ids)
        if (i < sampling_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            largest_num_logprobs = max(largest_num_logprobs,
                                       sampling_params.prompt_logprobs)
            prompt_len = sampling_metadata.prompt_lens[i]
            prompt_tokens = sampling_metadata.seq_data[
                seq_ids[0]].prompt_token_ids
            batched_logprobs_query_seq_indices.extend(
                sample_idx + j for j in range(prompt_len - 1))
            batched_logprobs_query_token_indices.extend(
                token_id for token_id in prompt_tokens[1:])
            sample_idx += prompt_len - 1
        batched_logprobs_query_seq_indices.extend(
            [sample_idx + parent_id for parent_id in parent_ids])
        batched_logprobs_query_token_indices.extend(next_token_ids)
        if sampling_params.logprobs is not None:
            largest_num_logprobs = max(largest_num_logprobs,
                                       sampling_params.logprobs)
        sample_idx += num_parent_seqs
    assert sample_idx == logprobs.size(0)

    batched_logprobs_query_seq_indices_gpu = torch.tensor(
        batched_logprobs_query_seq_indices, device=logprobs.device)
    batched_logprobs_query_token_indices_gpu = torch.tensor(
        batched_logprobs_query_token_indices, device=logprobs.device)

    # Batched query for logprobs of selected token
    batched_logprobs_query_result = logprobs[[
        batched_logprobs_query_seq_indices_gpu,
        batched_logprobs_query_token_indices_gpu
    ]]

    batched_ranks_query_result = _get_ranks(
        logprobs[batched_logprobs_query_seq_indices_gpu],
        batched_logprobs_query_token_indices_gpu)

    # Batched query for logprobs of topk tokens
    if largest_num_logprobs > 0:
        top_logprobs, top_token_ids = torch.topk(logprobs,
                                                 largest_num_logprobs,
                                                 dim=-1)
        top_logprobs = top_logprobs.cpu()
        top_token_ids = top_token_ids.cpu()
    else:
        top_logprobs, top_token_ids = None, None

    batched_logprobs_query_result = batched_logprobs_query_result.cpu()
    batched_ranks_query_result = batched_ranks_query_result.cpu()

    # Gather results
    result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
    result_sample_logprobs: List[SampleLogprobs] = []
    sample_idx = 0
    query_result_idx = 0
    for i, (seq_group, sample_result) in enumerate(
            zip(sampling_metadata.seq_groups, sample_results)):
        seq_ids, sampling_params = seq_group
        next_token_ids, parent_ids = sample_result

        # Prompt logprobs
        if (i < sampling_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            num_logprobs = sampling_params.prompt_logprobs
            prompt_tokens = sampling_metadata.seq_data[
                seq_ids[0]].prompt_token_ids
            group_prompt_logprobs: PromptLogprobs = [None]
            for token_id in prompt_tokens[1:]:
                prompt_logprobs_dict = {
                    token_id:
                    (batched_logprobs_query_result[query_result_idx].item(),
                     batched_ranks_query_result[query_result_idx].item())
                }
                if num_logprobs > 0:
                    prompt_logprobs_dict.update(
                        zip(
                            top_token_ids[sample_idx, :num_logprobs].tolist(),
                            zip(
                                top_logprobs[
                                    sample_idx, :num_logprobs].tolist(),
                                range(1, num_logprobs + 1))))
                group_prompt_logprobs.append({
                    token_id: Logprob(*logprob_rank)
                    for token_id, logprob_rank in prompt_logprobs_dict.items()
                })
                sample_idx += 1
                query_result_idx += 1
            result_prompt_logprobs.append(group_prompt_logprobs)
        else:
            result_prompt_logprobs.append(None)

        # Sample logprobs
        num_logprobs = sampling_params.logprobs
        if num_logprobs is None:
            num_logprobs = 0
        group_sample_logprobs: SampleLogprobs = []
        for next_token_id, parent_id in zip(next_token_ids, parent_ids):
            sample_logprobs_dict = {
                next_token_id:
                (batched_logprobs_query_result[query_result_idx].item(),
                 batched_ranks_query_result[query_result_idx].item())
            }
            query_result_idx += 1
            if num_logprobs >= 0:
                sample_logprobs_dict.update(
                    zip(
                        top_token_ids[sample_idx +
                                      parent_id, :num_logprobs].tolist(),
                        zip(
                            top_logprobs[sample_idx +
                                         parent_id, :num_logprobs].tolist(),
                            range(1, num_logprobs + 1))))
            group_sample_logprobs.append({
                token_id: Logprob(*logprob_rank)
                for token_id, logprob_rank in sample_logprobs_dict.items()
            })
        result_sample_logprobs.append(group_sample_logprobs)
        sample_idx += len(seq_ids)

    return result_prompt_logprobs, result_sample_logprobs


def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
                                 sample_indices: torch.Tensor,
                                 greedy_samples: torch.Tensor) -> None:
    """Modify the probability distributions of the greedily-sampled tokens
    such that each sampled token has a "probability" of 1.0. This is required
    by speculative decoding, which depends on the sampling method being
    encoded within the probability distribution for correctness.

    # Why do we only need to do this for greedy sampling?

    vLLM's sampler performs the following steps for greedy or multinomial
    (random) sampling:
        1. Get logits from model.
        2. Modify logits according to per-sequence sampling parameters.
            - Multiply by temperature, top-k and top-p masking, penalize
              tokens according to their frequency, etc.
        3. Sample a token.
            - Random sampling simply samples from the modified probability
              distribution.
            - Greedy sampling performs `argmax` to obtain the token with the
              highest likelihood.

    Ignoring greedy sampling for a moment, we find that the computed
    probability distribution has the following property: we can sample from it
    independently and find that the token sampled by the Sampler has a
    frequency corresponding to how often we see it in our sampling. In other
    words, for tokens sampled with vLLM's random SamplingType, the computed
    probability distribution encodes the sampling methodology completely.

    Greedy sampling does not normally have this property. vLLM modifies logits
    according to sampling params, then performs `argmax`, then returns the
    sampled token and the computed probability distribution. If we sample from
    the distribution, we'll find the likelihood of the greedily-sampled token
    is not always 1.0.

    Since lossless speculative decoding requires that the sampling methodology
    be encoded within the probability distribution, we are motivated to modify
    the probability distribution such that the sampled token has probability 1
    when speculative decoding is used.

    NOTE: Alternatively, we could use an extremely low temperature to achieve
    greedy sampling using multinomial computation and unite the codepaths.
    This has implications on the overall design of the sampler, e.g. how to
    record accurate logprobs for the user, so this improvement is deferred to
    later.
    """
    logprobs[sample_indices, :] = -float('inf')
    logprobs[sample_indices, greedy_samples] = 0.0
    probs[sample_indices, :] = 0
    probs[sample_indices, greedy_samples] = 1.0
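

# Hedged sketch (illustrative only): after the in-place rewrite above, each
# selected row becomes a one-hot distribution over its argmax token, so
# multinomial sampling from `probs` reproduces the greedy choice exactly. The
# helper name and toy values are invented for this example.
def _sketch_modify_greedy_probs_example() -> Tuple[torch.Tensor, torch.Tensor]:
    logits = torch.tensor([[1.0, 3.0, 2.0]])
    probs = torch.softmax(logits, dim=-1)
    logprobs = torch.log_softmax(logits, dim=-1)
    greedy = torch.argmax(logprobs, dim=-1)  # tensor([1])
    _modify_greedy_probs_inplace(logprobs, probs, torch.tensor([0]), greedy)
    # probs is now [[0., 1., 0.]] and logprobs [[-inf, 0., -inf]].
    return probs, logprobs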


def _build_sampler_output(
    sample_results: List[Tuple[List[int], List[int]]],
    sampling_metadata: SamplingMetadata,
    prompt_logprobs: List[Optional[PromptLogprobs]],
    sample_logprobs: List[SampleLogprobs],
    on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor]],
) -> SamplerOutput:
    """Construct Python objects with the output of sampling.

    Args:
        on_device_tensors: Tuple containing on-device tensors with the
            probabilities used in sampling and the sampled token ids. This
            allows post-processing without copies to CPU/serialization, e.g.
            in speculative decoding rejection sampling.
    """
    sampler_output = []
    for (seq_group, sample_result, group_prompt_logprobs,
         group_sample_logprobs) in zip(sampling_metadata.seq_groups,
                                       sample_results, prompt_logprobs,
                                       sample_logprobs):
        seq_ids, _ = seq_group
        next_token_ids, parent_ids = sample_result
        seq_outputs = []
        for parent_id, next_token_id, logprobs in zip(parent_ids,
                                                      next_token_ids,
                                                      group_sample_logprobs):
            seq_outputs.append(
                SequenceOutput(seq_ids[parent_id], next_token_id, logprobs))
        sampler_output.append(
            SequenceGroupOutput(seq_outputs, group_prompt_logprobs))

    # If not specified, store None values in SamplerOutput.
    if on_device_tensors is not None:
        sampled_token_probs, sampled_token_ids = on_device_tensors
    else:
        sampled_token_probs, sampled_token_ids = (None, None)

    return SamplerOutput(
        outputs=sampler_output,
        sampled_token_probs=sampled_token_probs,
        sampled_token_ids=sampled_token_ids,
    )