# Spaces:
# Runtime error
# Runtime error
import torch | |
from typing import Optional, Tuple, Union, List, Callable | |
from transformers.generation.logits_process import LogitsProcessor | |
from transformers.generation.beam_search import BeamSearchScorer | |
from transformers.deepspeed import is_deepspeed_zero3_enabled | |
from transformers.generation.utils import ( | |
LogitsProcessorList, | |
StoppingCriteriaList, | |
GenerationConfig, | |
GenerationMixin, | |
) | |
from transformers import LlamaForCausalLM | |
import warnings | |
import torch.distributed as dist | |
from torch import nn | |
import copy | |
class SteamGenerationMixin(LlamaForCausalLM):
    """``LlamaForCausalLM`` with streaming (generator-based) decoding.

    ``stream_generate`` selects a decoding mode from a ``GenerationConfig``
    using the same rules as ``GenerationMixin.generate``, but returns a
    generator that yields the whole ``input_ids`` tensor (prompt plus tokens
    generated so far) after every decoding step, so callers can render
    partial output.

    Implemented modes: greedy search, multinomial sampling, beam search and
    beam sampling. Group beam search, contrastive search and constrained
    generation raise ``NotImplementedError``.

    NOTE(review): ``greedy_search``, ``beam_search`` and ``beam_sample``
    override the ``GenerationMixin`` methods of the same names with different
    signatures and generator semantics, so the stock ``generate()`` presumably
    cannot dispatch to them on this class. The names are kept because
    existing callers may use them directly. ("Steam" in the class name is
    presumably a typo for "Stream"; kept for the same reason.)
    """

    # Support for streaming generation.
    # TODO: group_beam_search
    def stream_generate(
        self,
        input_ids: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[
            Callable[[int, torch.Tensor], List[int]]
        ] = None,
        **kwargs,
    ):
        """Start step-by-step generation and return a generator of results.

        This method itself is not a generator: it validates/prepares the
        configuration eagerly and returns the generator of the selected mode.

        Args:
            input_ids: Prompt token ids; ``shape[0]`` is the batch size and
                ``shape[-1]`` the prompt length. Required despite the
                ``None`` default.
            generation_config: Base configuration; defaults to
                ``self.generation_config``. It is deep-copied, never mutated.
            logits_processor: Extra processors merged with the ones derived
                from the configuration.
            stopping_criteria: Extra criteria merged with the ones derived
                from the configuration.
            prefix_allowed_tokens_fn: Optional per-step allowed-token hook.
            **kwargs: ``GenerationConfig`` overrides (``max_new_tokens``,
                ``do_sample``, ...); keys the config does not know (e.g.
                ``attention_mask``) are forwarded to the model.

        Returns:
            A generator yielding ``input_ids`` tensors after each step. In
            beam modes the last value is the finalized best sequences.

        Raises:
            ValueError: if ``input_ids`` is missing.
            NotImplementedError: for unsupported generation modes.
        """
        if input_ids is None:
            raise ValueError("`input_ids` must be provided to `stream_generate`.")

        # Beam modes rely on `LlamaForCausalLM._reorder_cache`. Do not replace
        # it with `self.base_model._reorder_cache`: the inner `LlamaModel`
        # does not implement it.

        # Under DeepSpeed ZeRO-3 every rank must keep running `forward` until
        # all ranks are done (see the synced_gpus handling in the loops).
        synced_gpus = bool(
            is_deepspeed_zero3_enabled()
            and dist.is_available()
            and dist.is_initialized()
            and dist.get_world_size() > 1
        )

        # Prompt-learning PEFT configs prepend virtual tokens, so the mask has
        # to cover them too. Without such a config the mask is left as is.
        num_virtual_tokens = getattr(
            getattr(self, "peft_config", None), "num_virtual_tokens", None
        )
        attention_mask = kwargs.get("attention_mask", None)
        if attention_mask is not None and num_virtual_tokens:
            prefix_attention_mask = torch.ones(
                attention_mask.shape[0],
                num_virtual_tokens,
                dtype=attention_mask.dtype,
                device=attention_mask.device,
            )
            kwargs["attention_mask"] = torch.cat(
                (prefix_attention_mask, attention_mask), dim=1
            )

        if kwargs.get("position_ids", None) is not None:
            warnings.warn(
                "Position ids are not supported for parameter efficient tuning. Ignoring position ids."
            )
        if kwargs.get("token_type_ids", None) is not None:
            warnings.warn(
                "Token type ids are not supported for parameter efficient tuning. Ignoring token type ids"
            )
        # Drop the keys entirely: a `token_type_ids=None` entry would reach
        # `_update_model_kwargs_for_generation`, which extends it every step.
        kwargs.pop("position_ids", None)
        kwargs.pop("token_type_ids", None)

        input_ids_seq_length = input_ids.shape[-1]
        if generation_config is None:
            generation_config = self.generation_config
        generation_config = copy.deepcopy(generation_config)
        # `update` applies the known config keys and returns the leftovers,
        # which become model kwargs.
        model_kwargs = generation_config.update(**kwargs)

        # Without a pad token, finished sequences are padded with the (first)
        # EOS token, as `GenerationMixin.generate` does.
        if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
            fallback_pad = generation_config.eos_token_id
            if isinstance(fallback_pad, list):
                fallback_pad = fallback_pad[0] if fallback_pad else None
            if fallback_pad is not None:
                warnings.warn(
                    f"Setting `pad_token_id` to `eos_token_id`:{fallback_pad} for open-end generation."
                )
                generation_config.pad_token_id = fallback_pad

        has_default_max_length = (
            kwargs.get("max_length") is None
            and generation_config.max_length is not None
        )
        if has_default_max_length and generation_config.max_new_tokens is None:
            warnings.warn(
                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
                " recommend using `max_new_tokens` to control the maximum length of the generation.",
                UserWarning,
            )
        elif generation_config.max_new_tokens is not None:
            generation_config.max_length = (
                generation_config.max_new_tokens + input_ids_seq_length
            )
        if generation_config.min_new_tokens is not None:
            generation_config.min_length = (
                generation_config.min_new_tokens + input_ids_seq_length
            )
        if input_ids_seq_length >= generation_config.max_length:
            input_ids_string = (
                "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
            )
            warnings.warn(
                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
                " increasing `max_new_tokens`.",
                UserWarning,
            )

        # Set generation parameters if not already defined.
        logits_processor = (
            logits_processor if logits_processor is not None else LogitsProcessorList()
        )
        stopping_criteria = (
            stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        )

        # Determine the generation mode.
        is_constraint_gen_mode = (
            generation_config.constraints is not None
            or generation_config.force_words_ids is not None
        )
        is_contrastive_search_gen_mode = (
            generation_config.top_k is not None
            and generation_config.top_k > 1
            and generation_config.do_sample is False
            and generation_config.penalty_alpha is not None
            and generation_config.penalty_alpha > 0
        )
        # All implemented modes share these preconditions.
        is_supported_family = (
            generation_config.num_beam_groups == 1
            and not is_constraint_gen_mode
            and not is_contrastive_search_gen_mode
        )
        single_beam = generation_config.num_beams == 1
        multi_beam = generation_config.num_beams > 1
        do_sample = generation_config.do_sample

        # Distribution pre-processing and stopping criteria.
        logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=input_ids,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            logits_processor=logits_processor,
        )
        stopping_criteria = self._get_stopping_criteria(
            generation_config=generation_config, stopping_criteria=stopping_criteria
        )

        if is_supported_family and single_beam and do_sample is False:
            return self.greedy_search(
                input_ids,
                logits_processor,
                stopping_criteria,
                generation_config,
                synced_gpus,
                **model_kwargs,
            )
        if is_supported_family and single_beam and do_sample is True:
            # One row per requested return sequence.
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids=input_ids,
                expand_size=generation_config.num_return_sequences,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )
            return self.stream_sample(
                generation_config,
                input_ids,
                logits_processor,
                self._get_logits_warper(generation_config),
                stopping_criteria,
                synced_gpus,
                **model_kwargs,
            )
        if is_supported_family and multi_beam and do_sample is False:
            return self.beam_search(
                generation_config,
                input_ids,
                logits_processor,
                stopping_criteria,
                synced_gpus,
                **model_kwargs,
            )
        if is_supported_family and multi_beam and do_sample is True:
            # `beam_sample` expands `input_ids` itself.
            return self.beam_sample(
                input_ids,
                logits_processor,
                self._get_logits_warper(generation_config),
                stopping_criteria,
                generation_config,
                synced_gpus,
                **model_kwargs,
            )
        raise NotImplementedError(
            "not implement: stream_generate supports greedy search, sampling, beam search"
            f" and beam sampling only (got num_beams={generation_config.num_beams},"
            f" num_beam_groups={generation_config.num_beam_groups},"
            f" do_sample={generation_config.do_sample})"
        )

    @staticmethod
    def _stream_eos_and_pad(generation_config):
        """Return ``(eos_token_ids, pad_token_id)``; EOS is normalised to a list or ``None``."""
        eos_token_id = generation_config.eos_token_id
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        return eos_token_id, generation_config.pad_token_id

    @staticmethod
    def _stream_all_peers_finished(this_peer_finished, device):
        """All-reduce the per-rank "still running" flags; ``True`` once every rank is done."""
        # Each rank sends 0.0 if finished, 1.0 otherwise; the sum is 0.0 only
        # when all ranks have finished.
        flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(device)
        dist.all_reduce(flag, op=dist.ReduceOp.SUM)
        return flag.item() == 0.0

    @staticmethod
    def _stream_pad_finished(next_tokens, unfinished_sequences, eos_token_id, pad_token_id):
        """Replace the new token of already-finished rows with ``pad_token_id``.

        Raises:
            ValueError: if EOS ids are configured but ``pad_token_id`` is not.
        """
        if eos_token_id is None:
            return next_tokens
        if pad_token_id is None:
            raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
        return next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

    @staticmethod
    def _stream_update_unfinished(unfinished_sequences, next_tokens, eos_token_id_tensor):
        """Zero the ``unfinished_sequences`` entries whose new token equals any EOS id."""
        if eos_token_id_tensor is None:
            return unfinished_sequences
        # (num_eos, batch) comparison; a row stays unfinished only if its token
        # differs from every EOS id.
        return unfinished_sequences.mul(
            next_tokens.tile(eos_token_id_tensor.shape[0], 1)
            .ne(eos_token_id_tensor.unsqueeze(1))
            .prod(dim=0)
        )

    def stream_sample(
        self,
        generation_config,
        input_ids,
        logits_processor,
        logits_warper,
        stopping_criteria,
        synced_gpus,
        **model_kwargs,
    ):
        """Multinomial sampling that yields ``input_ids`` after every step.

        ``input_ids`` must already be expanded for ``num_return_sequences``.
        When the loop ends the generator returns (not yields) the final ids.
        """
        eos_token_id, pad_token_id = self._stream_eos_and_pad(generation_config)
        eos_token_id_tensor = (
            torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
        )
        # 1 while a row is still generating, 0 once it produced an EOS token.
        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
        this_peer_finished = False  # used by synced_gpus only
        scores = ()  # scores are not collected; stopping criteria receive an empty tuple
        while True:
            if synced_gpus and self._stream_all_peers_finished(this_peer_finished, input_ids.device):
                break
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self(
                **model_inputs,
                return_dict=True,
            )
            if synced_gpus and this_peer_finished:
                continue  # this rank only keeps `forward` in lockstep with the others
            next_token_logits = outputs.logits[:, -1, :]
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            next_tokens = self._stream_pad_finished(
                next_tokens, unfinished_sequences, eos_token_id, pad_token_id
            )
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            yield input_ids
            unfinished_sequences = self._stream_update_unfinished(
                unfinished_sequences, next_tokens, eos_token_id_tensor
            )
            # Stop when every row is finished or a stopping criterion fires.
            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                this_peer_finished = True
        return input_ids

    def empty_cache(self):
        """Release cached, unused CUDA memory via ``torch.cuda.empty_cache``."""
        torch.cuda.empty_cache()

    def beam_sample(
        self,
        input_ids,
        logits_processor,
        logits_warper,
        stopping_criteria,
        generation_config,
        synced_gpus,
        **model_kwargs,
    ):
        """Beam sampling that yields ``input_ids`` after every step.

        Takes the un-expanded prompt and expands it here. Every requested
        return sequence gets its own group of ``num_beams`` sampled beams.
        The final value yielded is the scorer's finalized ``sequences`` tensor
        (``batch * num_return_sequences`` rows).
        """
        eos_token_id, pad_token_id = self._stream_eos_and_pad(generation_config)
        num_beams = generation_config.num_beams
        num_return_sequences = generation_config.num_return_sequences
        # One independent beam group per (prompt, return sequence) pair, so
        # the scorer and `beam_scores` match the expanded batch below.
        batch_size = input_ids.shape[0] * num_return_sequences
        cur_len = input_ids.shape[-1]
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=num_beams,
            device=input_ids.device,
            length_penalty=generation_config.length_penalty,
            do_early_stopping=generation_config.early_stopping,
            max_length=generation_config.max_length,
        )
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids=input_ids,
            expand_size=num_beams * num_return_sequences,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )
        scores = ()
        # All beams start at 0 (no -1e9 masking as in `beam_search`).
        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
        beam_scores = beam_scores.view((batch_size * num_beams,))
        this_peer_finished = False  # used by synced_gpus only
        while True:
            if synced_gpus and self._stream_all_peers_finished(this_peer_finished, input_ids.device):
                break
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self(
                **model_inputs,
                return_dict=True,
            )
            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # this rank only keeps `forward` in lockstep with the others
            next_token_logits = outputs.logits[:, -1, :]
            # Hook that lets some models (e.g. Marian) adjust logits before log_softmax.
            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
            next_token_scores = nn.functional.log_softmax(
                next_token_logits, dim=-1
            )  # (batch_size * num_beams, vocab_size)
            next_token_scores_processed = logits_processor(input_ids, next_token_scores)
            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
            # Warpers are intentionally applied after adding the running beam
            # scores; for some warpers (e.g. temperature) the order matters.
            # See https://github.com/huggingface/transformers/pull/5420#discussion_r449779867
            next_token_scores = logits_warper(input_ids, next_token_scores)
            # Flatten beams so candidates are drawn per prompt group.
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
            next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
            next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
            next_tokens = torch.gather(next_tokens, -1, _indices)
            # Split flat indices back into (source beam, token id).
            next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
            next_tokens = next_tokens % vocab_size
            beam_outputs = beam_scorer.process(
                input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                beam_indices=None,
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]
            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
            yield input_ids
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            # Keep the KV cache aligned with the reordered beams.
            if model_kwargs["past_key_values"] is not None:
                model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
            cur_len = cur_len + 1
            if beam_scorer.is_done or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                this_peer_finished = True
        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=None,
        )
        yield sequence_outputs["sequences"]

    def greedy_search(
        self,
        input_ids,
        logits_processor,
        stopping_criteria,
        generation_config,
        synced_gpus,
        **model_kwargs,
    ):
        """Greedy (argmax) decoding that yields ``input_ids`` after every step.

        After the loop ends the final ``input_ids`` is yielded once more, so
        the last value a consumer sees is always the complete output.
        """
        eos_token_id, pad_token_id = self._stream_eos_and_pad(generation_config)
        eos_token_id_tensor = (
            torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
        )
        scores = ()  # scores are not collected; stopping criteria receive an empty tuple
        # 1 while a row is still generating, 0 once it produced an EOS token.
        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
        this_peer_finished = False  # used by synced_gpus only
        while True:
            if synced_gpus and self._stream_all_peers_finished(this_peer_finished, input_ids.device):
                break
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self(
                **model_inputs,
                return_dict=True,
            )
            if synced_gpus and this_peer_finished:
                continue  # this rank only keeps `forward` in lockstep with the others
            next_token_logits = outputs.logits[:, -1, :]
            next_tokens_scores = logits_processor(input_ids, next_token_logits)
            next_tokens = torch.argmax(next_tokens_scores, dim=-1)
            next_tokens = self._stream_pad_finished(
                next_tokens, unfinished_sequences, eos_token_id, pad_token_id
            )
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            yield input_ids
            unfinished_sequences = self._stream_update_unfinished(
                unfinished_sequences, next_tokens, eos_token_id_tensor
            )
            # Stop when every row is finished or a stopping criterion fires.
            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                this_peer_finished = True
        yield input_ids

    def beam_search(
        self,
        generation_config,
        input_ids,
        logits_processor,
        stopping_criteria,
        synced_gpus,
        **model_kwargs,
    ):
        """Beam search that yields ``input_ids`` after every step.

        Takes the un-expanded prompt and expands it by ``num_beams`` here.
        The final value yielded is the scorer's finalized ``sequences``
        tensor with ``num_return_sequences`` hypotheses per prompt.

        Raises:
            ValueError: if ``num_return_sequences > num_beams`` (raised when
                the generator is first advanced).
        """
        eos_token_id, pad_token_id = self._stream_eos_and_pad(generation_config)
        num_beams = generation_config.num_beams
        if generation_config.num_return_sequences > num_beams:
            raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
        batch_size = input_ids.shape[0]
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=num_beams,
            device=input_ids.device,
            length_penalty=generation_config.length_penalty,
            do_early_stopping=generation_config.early_stopping,
            num_beam_hyps_to_keep=generation_config.num_return_sequences,
            max_length=generation_config.max_length,
        )
        # Interleave input_ids with `num_beams` copies per prompt.
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids=input_ids,
            expand_size=num_beams,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )
        batch_beam_size, _ = input_ids.shape
        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )
        # All beams of a prompt start as identical copies; masking beams 1..
        # with -1e9 makes the first step expand only beam 0, avoiding
        # num_beams duplicates of the same continuation.
        beam_scores = torch.zeros(
            (batch_size, num_beams), dtype=torch.float, device=input_ids.device
        )
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view((batch_size * num_beams,))
        this_peer_finished = False  # used by synced_gpus only
        while True:
            if synced_gpus and self._stream_all_peers_finished(this_peer_finished, input_ids.device):
                break
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )
            if synced_gpus and this_peer_finished:
                continue  # this rank only keeps `forward` in lockstep with the others
            next_token_logits = outputs.logits[:, -1, :]
            next_token_scores = nn.functional.log_softmax(
                next_token_logits, dim=-1
            )  # (batch_size * num_beams, vocab_size)
            next_token_scores_processed = logits_processor(input_ids, next_token_scores)
            next_token_scores = next_token_scores_processed + beam_scores[
                :, None
            ].expand_as(next_token_scores)
            # Flatten beams so the top-k is taken per prompt.
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(
                batch_size, num_beams * vocab_size
            )
            # Take 2 * num_beams candidates so spare tokens remain when some are EOS.
            next_token_scores, next_tokens = torch.topk(
                next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
            )
            # Split flat indices back into (source beam, token id).
            next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
            next_tokens = next_tokens % vocab_size
            beam_outputs = beam_scorer.process(
                input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                beam_indices=None,
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]
            input_ids = torch.cat(
                [input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1
            )
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            # Keep the KV cache aligned with the reordered beams.
            if model_kwargs["past_key_values"] is not None:
                model_kwargs["past_key_values"] = self._reorder_cache(
                    model_kwargs["past_key_values"], beam_idx
                )
            yield input_ids
            if beam_scorer.is_done or stopping_criteria(input_ids, None):
                if not synced_gpus:
                    break
                this_peer_finished = True
        final_result = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=None,
        )
        yield final_result["sequences"]