<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Text generation strategies

Text generation is essential to many NLP tasks, such as open-ended text generation, summarization, translation, and
more. It also plays a role in a variety of mixed-modality applications that have text as an output, like speech-to-text
and vision-to-text. Some of the models that can generate text include
GPT2, XLNet, OpenAI GPT, CTRL, TransformerXL, XLM, Bart, T5, GIT, and Whisper.

Check out a few examples that use the [`~transformers.generation_utils.GenerationMixin.generate`] method to produce
text outputs for different tasks:

* [Text summarization](./tasks/summarization#inference)
* [Image captioning](./model_doc/git#transformers.GitForCausalLM.forward.example)
* [Audio transcription](./model_doc/whisper#transformers.WhisperForConditionalGeneration.forward.example)

Note that the inputs to the `generate()` method depend on the model's modality. They are returned by the model's preprocessor
class, such as AutoTokenizer or AutoProcessor. If a model's preprocessor creates more than one kind of input, pass all
the inputs to `generate()`. You can learn more about an individual model's preprocessor in the corresponding model's documentation.

The process of selecting output tokens to generate text is known as decoding, and you can customize the decoding strategy
that the `generate()` method will use. Modifying a decoding strategy does not change the values of any trainable parameters.
However, it can have a noticeable impact on the quality of the generated output: for example, it can help reduce repetition
in the text and make it more coherent.

This guide describes:

* default generation configuration
* common decoding strategies and their main parameters
* saving and sharing custom generation configurations with your fine-tuned model on 🤗 Hub

## Default text generation configuration

A decoding strategy for a model is defined in its generation configuration. When using pre-trained models for inference
within a [`pipeline`], the models call the `PreTrainedModel.generate()` method that applies a default generation
configuration under the hood. The default configuration is also used when no custom configuration has been saved with
the model.

When you load a model explicitly, you can inspect the generation configuration that comes with it through
`model.generation_config`:

```python
>>> from transformers import AutoModelForCausalLM

>>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
>>> model.generation_config
GenerationConfig {
  "_from_model_config": true,
  "bos_token_id": 50256,
  "eos_token_id": 50256,
  "transformers_version": "4.26.0.dev0"
}
```

Printing out the `model.generation_config` reveals only the values that are different from the default generation
configuration, and does not list any of the default values.

The default generation configuration limits the size of the output combined with the input prompt to a maximum of 20
tokens to avoid running into resource limitations. The default decoding strategy is greedy search, which is the simplest
decoding strategy: it picks the token with the highest probability as the next token. For many tasks
and small output sizes this works well. However, when used to generate longer outputs, greedy search can start
producing highly repetitive results.
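
To make the 20-token limit concrete, here is a small pure-Python sketch of how many new tokens fit under each setting. The helper `new_token_budget` is hypothetical, not part of 🤗 Transformers. It assumes a decoder-only model such as `distilgpt2`, where `max_length` counts the prompt plus the output, and it assumes that `max_new_tokens` takes precedence over `max_length` when it is set.

```python
def new_token_budget(prompt_len, max_length=20, max_new_tokens=None):
    """Hypothetical helper: upper bound on the number of generated tokens.

    Assumes a decoder-only model, where `max_length` counts prompt plus output,
    and that `max_new_tokens`, when set, takes precedence over `max_length`.
    """
    if max_new_tokens is not None:
        # Only the generated tokens count; the prompt length is irrelevant.
        return max_new_tokens
    # The prompt uses up part of the shared `max_length` budget (clamped at zero here).
    return max(max_length - prompt_len, 0)


print(new_token_budget(prompt_len=5))                      # 15
print(new_token_budget(prompt_len=12))                     # 8
print(new_token_budget(prompt_len=12, max_new_tokens=50))  # 50
```

With the default limit, a longer prompt leaves less room for the output. Setting `max_new_tokens` removes that coupling.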

## Customize text generation

You can override any parameter of `generation_config` by passing the parameters and their values directly to the [`generate`] method:

```python
>>> my_model.generate(**inputs, num_beams=4, do_sample=True)
```

Even if the default decoding strategy mostly works for your task, you can still tweak a few things. Some of the
commonly adjusted parameters include:

- `max_new_tokens`: the maximum number of tokens to generate. In other words, the size of the output sequence, not
including the tokens in the prompt.
- `num_beams`: by specifying a number of beams higher than 1, you are effectively switching from greedy search to
beam search. This strategy evaluates several hypotheses at each time step and eventually chooses the hypothesis that
has the overall highest probability for the entire sequence. This has the advantage of identifying high-probability
sequences that start with lower-probability initial tokens and would've been ignored by the greedy search.
- `do_sample`: if set to `True`, this parameter enables decoding strategies such as multinomial sampling, beam-search
multinomial sampling, Top-K sampling and Top-p sampling. All these strategies select the next token from the probability
distribution over the entire vocabulary with various strategy-specific adjustments.
- `num_return_sequences`: the number of sequence candidates to return for each input. This option is only available for
the decoding strategies that support multiple sequence candidates, e.g. variations of beam search and sampling. Decoding
strategies like greedy search and contrastive search return a single output sequence.
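
The "strategy-specific adjustments" mentioned for `do_sample` include Top-K and Top-p filtering. The following is a minimal pure-Python sketch of what those two filters do to a next-token distribution before sampling. The function names and the toy probabilities are hypothetical. This is not the library's implementation, although `generate()` does expose the corresponding `top_k` and `top_p` parameters.

```python
def top_k_filter(probs, k):
    """Keep only the k most probable tokens, then renormalize."""
    kept = dict(sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:k])
    total = sum(kept.values())
    return {tok: p / total for tok, p in kept.items()}


def top_p_filter(probs, p):
    """Keep the smallest set of most probable tokens whose total mass reaches p."""
    kept, mass = {}, 0.0
    for tok, prob in sorted(probs.items(), key=lambda kv: kv[1], reverse=True):
        kept[tok] = prob
        mass += prob
        if mass >= p:  # the token that crosses the threshold is kept
            break
    total = sum(kept.values())
    return {tok: q / total for tok, q in kept.items()}


# Toy next-token distribution (made-up numbers for illustration).
probs = {"cat": 0.5, "dog": 0.3, "car": 0.15, "sky": 0.05}
print(sorted(top_k_filter(probs, k=2)))    # ['cat', 'dog']
print(sorted(top_p_filter(probs, p=0.9)))  # ['car', 'cat', 'dog']
```

Top-K always keeps a fixed number of candidates. Top-p adapts the number of candidates to how peaked the distribution is.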

## Save a custom decoding strategy with your model

If you would like to share your fine-tuned model with a specific generation configuration, you can:

* Create a [`GenerationConfig`] class instance
* Specify the decoding strategy parameters
* Save your generation configuration with [`GenerationConfig.save_pretrained`], making sure to leave its `config_file_name` argument empty
* Set `push_to_hub` to `True` to upload your config to the model's repo

```python
>>> from transformers import AutoModelForCausalLM, GenerationConfig

>>> model = AutoModelForCausalLM.from_pretrained("my_account/my_model")
>>> generation_config = GenerationConfig(
...     max_new_tokens=50, do_sample=True, top_k=50, eos_token_id=model.config.eos_token_id
... )
>>> generation_config.save_pretrained("my_account/my_model", push_to_hub=True)
```

You can also store several generation configurations in a single directory, making use of the `config_file_name`
argument in [`GenerationConfig.save_pretrained`]. You can later instantiate them with [`GenerationConfig.from_pretrained`]. This is useful if you want to
store several generation configurations for a single model (e.g. one for creative text generation with sampling, and
one for summarization with beam search). You must have the right Hub permissions to add configuration files to a model.

```python
>>> from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig

>>> tokenizer = AutoTokenizer.from_pretrained("t5-small")
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

>>> translation_generation_config = GenerationConfig(
...     num_beams=4,
...     early_stopping=True,
...     decoder_start_token_id=0,
...     eos_token_id=model.config.eos_token_id,
...     pad_token_id=model.config.pad_token_id,
... )

>>> translation_generation_config.save_pretrained("t5-small", "translation_generation_config.json", push_to_hub=True)

>>> # You could then use the named generation config file to parameterize generation
>>> generation_config = GenerationConfig.from_pretrained("t5-small", "translation_generation_config.json")
>>> inputs = tokenizer("translate English to French: Configuration files are easy to use!", return_tensors="pt")
>>> outputs = model.generate(**inputs, generation_config=generation_config)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
['Les fichiers de configuration sont faciles à utiliser !']
```

## Streaming

The `generate()` method supports streaming through its `streamer` input. The `streamer` input is compatible with any instance
of a class that has the following methods: `put()` and `end()`. Internally, `put()` is used to push new tokens and
`end()` is used to flag the end of text generation.
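
Because the only requirement is a `put()` and an `end()` method, a custom streamer can be very small. Here is a minimal sketch of one that collects tokens into a list. The class name `ListStreamer` is hypothetical. The sketch assumes that `put()` receives a batch of new token ids; in 🤗 Transformers these arrive as tensors, so plain Python lists stand in for them here. The calls that `generate()` would make are simulated by hand.

```python
class ListStreamer:
    """Hypothetical minimal streamer: collects token ids pushed via put().

    Assumption: put() receives new token ids; real generate() calls pass
    tensors, which this toy version replaces with plain Python lists.
    """

    def __init__(self):
        self.tokens = []
        self.finished = False

    def put(self, value):
        # Called each time new token ids are available.
        self.tokens.extend(value)

    def end(self):
        # Called once, when generation is over.
        self.finished = True


# Simulate the calls generate() would make (the token ids are made up).
streamer = ListStreamer()
streamer.put([464, 3797])
streamer.put([318])
streamer.end()
print(streamer.tokens, streamer.finished)  # [464, 3797, 318] True
```

With a real model, an instance would be passed as `model.generate(**inputs, streamer=streamer)`, and the tensor values would need converting (for example with `.tolist()`) before being stored.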

<Tip warning={true}>

The API for the streamer classes is still under development and may change in the future.

</Tip>

In practice, you can craft your own streaming class for all sorts of purposes! We also have basic streaming classes
ready for you to use. For example, you can use the [`TextStreamer`] class to stream the output of `generate()` to
your screen, one word at a time:

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

>>> tok = AutoTokenizer.from_pretrained("gpt2")
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
>>> streamer = TextStreamer(tok)

>>> # Despite returning the usual output, the streamer will also print the generated text to stdout.
>>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
```

## Decoding strategies

Certain combinations of the `generate()` parameters, and ultimately `generation_config`, can be used to enable specific
decoding strategies. If you are new to this concept, we recommend reading [this blog post that illustrates how common decoding strategies work](https://huggingface.co/blog/how-to-generate).

Here, we'll show some of the parameters that control the decoding strategies and illustrate how you can use them.
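
As a compact summary of the combinations covered in the sections below, here is a hypothetical helper that names the strategy a given set of parameters selects. It is a simplified reading of this guide, not the library's actual dispatch logic. The `top_k=50` default is an assumption made for illustration.

```python
def decoding_strategy(num_beams=1, do_sample=False, num_beam_groups=1,
                      penalty_alpha=None, top_k=50):
    """Hypothetical helper: name the strategy a parameter combination selects.

    A simplified reading of this guide, not the library's dispatch code;
    top_k=50 is an assumed default for illustration.
    """
    if num_beams == 1:
        if do_sample:
            return "multinomial sampling"
        if penalty_alpha is not None and penalty_alpha > 0 and top_k > 1:
            return "contrastive search"
        return "greedy search"
    if num_beam_groups > 1:
        return "diverse beam search"
    if do_sample:
        return "beam-search multinomial sampling"
    return "beam search"


print(decoding_strategy())                                # greedy search
print(decoding_strategy(penalty_alpha=0.6, top_k=4))      # contrastive search
print(decoding_strategy(do_sample=True))                  # multinomial sampling
print(decoding_strategy(num_beams=5))                     # beam search
print(decoding_strategy(num_beams=5, do_sample=True))     # beam-search multinomial sampling
print(decoding_strategy(num_beams=5, num_beam_groups=5))  # diverse beam search
```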

### Greedy Search

[`generate`] uses greedy search decoding by default, so you don't have to pass any parameters to enable it. This means that `num_beams` is set to 1 and `do_sample=False`.

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer

>>> prompt = "I look forward to"
>>> checkpoint = "distilgpt2"

>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
>>> outputs = model.generate(**inputs)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['I look forward to seeing you all again!\n\n\n\n\n\n\n\n\n\n\n']
```
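
To show what greedy search does at each step, here is a toy sketch in pure Python. Instead of a real model, `TOY_MODEL` is a hypothetical lookup table of made-up next-token probabilities keyed by the tokens generated so far. The loop simply takes the single most probable token at every step.

```python
# Toy "model" (made-up numbers): next-token probabilities for each prefix
# of generated tokens.
TOY_MODEL = {
    (): {"nice": 0.5, "dog": 0.4, "car": 0.1},
    ("nice",): {"woman": 0.4, "house": 0.3, "guy": 0.3},
    ("dog",): {"has": 0.9, "runs": 0.05, "and": 0.05},
    ("car",): {"is": 0.3, "drives": 0.5, "turns": 0.2},
}


def greedy_decode(model, max_new_tokens):
    """Pick the single most probable next token at every step."""
    seq = ()
    for _ in range(max_new_tokens):
        next_probs = model.get(seq)
        if not next_probs:  # the toy model knows no continuation for this prefix
            break
        seq += (max(next_probs, key=next_probs.get),)
    return list(seq)


print(greedy_decode(TOY_MODEL, max_new_tokens=2))  # ['nice', 'woman']
```

The chosen sequence has a joint probability of 0.5 × 0.4 = 0.2, while "dog has" would have reached 0.4 × 0.9 = 0.36. Greedy search never considers it because "dog" loses at the first step.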

### Contrastive search

The contrastive search decoding strategy was proposed in the 2022 paper [A Contrastive Framework for Neural Text Generation](https://arxiv.org/abs/2202.06417).
The paper reports superior results for generating non-repetitive yet coherent long outputs. To learn how contrastive search
works, check out [this blog post](https://huggingface.co/blog/introducing-csearch).

The two main parameters that enable and control the behavior of contrastive search are `penalty_alpha` and `top_k`:

```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM

>>> checkpoint = "gpt2-large"
>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)

>>> prompt = "Hugging Face Company is"
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> outputs = model.generate(**inputs, penalty_alpha=0.6, top_k=4, max_new_tokens=100)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Hugging Face Company is a family owned and operated business. \
We pride ourselves on being the best in the business and our customer service is second to none.\
\n\nIf you have any questions about our products or services, feel free to contact us at any time.\
We look forward to hearing from you!']
```
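
In the linked paper, each step keeps the `top_k` most probable candidates. It then picks the one that maximizes (1 − α) · p(v | context) − α · max_j cos(h_v, h_j). Here α is `penalty_alpha`, h_v is the hidden state for candidate v, and h_j are the hidden states of the tokens already in the context. Below is a minimal pure-Python sketch of a single such step. The function names, tokens, and toy vectors are hypothetical; a real model would supply the probabilities and hidden states.

```python
import math


def cosine(a, b):
    """Cosine similarity between two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def contrastive_step(probs, cand_hidden, context_hidden, penalty_alpha, top_k):
    """Choose the next token with the contrastive search objective.

    probs: token -> model probability p(v | context)
    cand_hidden: token -> hidden state the model would produce for that token
    context_hidden: hidden states of the tokens already in the context
    """
    candidates = sorted(probs, key=probs.get, reverse=True)[:top_k]

    def score(tok):
        # Degeneration penalty: similarity to the most similar context token.
        penalty = max(cosine(cand_hidden[tok], h) for h in context_hidden)
        return (1 - penalty_alpha) * probs[tok] - penalty_alpha * penalty

    return max(candidates, key=score)


# Made-up toy values: "the" is most probable, but its representation nearly
# repeats the context, so a positive penalty_alpha steers the choice to "a".
probs = {"the": 0.6, "a": 0.3, "an": 0.1}
cand_hidden = {"the": [1.0, 0.0], "a": [0.0, 1.0], "an": [0.7, 0.7]}
context_hidden = [[1.0, 0.1], [0.9, 0.0]]

print(contrastive_step(probs, cand_hidden, context_hidden, penalty_alpha=0.0, top_k=3))  # the
print(contrastive_step(probs, cand_hidden, context_hidden, penalty_alpha=0.6, top_k=3))  # a
```

With `penalty_alpha=0.0` the score is just the model probability, so the step reduces to greedy choice among the top-k candidates.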

### Multinomial sampling

As opposed to greedy search, which always chooses the token with the highest probability as the
next token, multinomial sampling (also called ancestral sampling) randomly selects the next token based on the probability distribution over the entire
vocabulary given by the model. Every token with a non-zero probability has a chance of being selected, thus reducing the
risk of repetition.

To enable multinomial sampling, set `do_sample=True` and `num_beams=1`.

```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM

>>> checkpoint = "gpt2-large"
>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)

>>> prompt = "Today was an amazing day because"
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> outputs = model.generate(**inputs, do_sample=True, num_beams=1, max_new_tokens=100)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Today was an amazing day because we are now in the final stages of our trip to New York City which was very tough. \
It is a difficult schedule and a challenging part of the year but still worth it. I have been taking things easier and \
I feel stronger and more motivated to be out there on their tour. Hopefully, that experience is going to help them with \
their upcoming events which are currently scheduled in Australia.\n\nWe love that they are here. They want to make a \
name for themselves and become famous for what they']
```
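
The core of multinomial sampling is a single weighted random draw per step. Here is a minimal sketch using Python's standard `random.choices`. The helper name `sample_next` and the toy distribution are hypothetical. Repeating the draw many times shows that frequent tokens win often but not always, and that a zero-probability token is never picked.

```python
import random


def sample_next(probs, rng):
    """Draw one token at random, weighted by its probability."""
    tokens = list(probs)
    weights = [probs[t] for t in tokens]
    return rng.choices(tokens, weights=weights, k=1)[0]


# Toy next-token distribution (made-up numbers for illustration).
probs = {"sunny": 0.6, "rainy": 0.3, "snowy": 0.1, "purple": 0.0}
rng = random.Random(0)  # fixed seed so the draws are reproducible
draws = [sample_next(probs, rng) for _ in range(1000)]

# Counts roughly track the probabilities; "purple" (p=0) is never drawn.
print({tok: draws.count(tok) for tok in probs})
```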

### Beam-search decoding

Unlike greedy search, beam-search decoding keeps several hypotheses at each time step and eventually chooses
the hypothesis that has the overall highest probability for the entire sequence. This has the advantage of identifying high-probability
sequences that start with lower-probability initial tokens and would've been ignored by the greedy search.

To enable this decoding strategy, set `num_beams` (the number of hypotheses to keep track of) to a value greater than 1.

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer

>>> prompt = "It is astonishing how one can"
>>> checkpoint = "gpt2-medium"

>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)

>>> outputs = model.generate(**inputs, num_beams=5, max_new_tokens=50)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['It is astonishing how one can have such a profound impact on the lives of so many people in such a short period of \
time."\n\nHe added: "I am very proud of the work I have been able to do in the last few years.\n\n"I have']
```
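
The advantage described above is easiest to see on a toy example. The following pure-Python sketch runs beam search over `TOY_MODEL`, a hypothetical table of made-up next-token probabilities. Scores are summed log-probabilities, and only the `num_beams` best partial sequences survive each step. It is a simplified illustration: there is no end-of-sequence handling or length penalty.

```python
import math

# Toy "model" (made-up numbers): next-token probabilities for each prefix
# of generated tokens.
TOY_MODEL = {
    (): {"nice": 0.5, "dog": 0.4, "car": 0.1},
    ("nice",): {"woman": 0.4, "house": 0.3, "guy": 0.3},
    ("dog",): {"has": 0.9, "runs": 0.05, "and": 0.05},
    ("car",): {"is": 0.3, "drives": 0.5, "turns": 0.2},
}


def beam_search(model, num_beams, max_new_tokens):
    """Keep the `num_beams` highest-scoring partial sequences at every step."""
    beams = [((), 0.0)]  # (generated tokens, summed log-probability)
    for _ in range(max_new_tokens):
        candidates = [
            (seq + (tok,), score + math.log(p))
            for seq, score in beams
            for tok, p in model.get(seq, {}).items()
        ]
        if not candidates:  # no beam can be extended any further
            break
        # Beams without a continuation are simply dropped in this toy version.
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:num_beams]
    best_seq, best_logp = beams[0]
    return list(best_seq), math.exp(best_logp)


seq, prob = beam_search(TOY_MODEL, num_beams=2, max_new_tokens=2)
print(seq, round(prob, 2))  # ['dog', 'has'] 0.36
seq, prob = beam_search(TOY_MODEL, num_beams=1, max_new_tokens=2)
print(seq, round(prob, 2))  # ['nice', 'woman'] 0.2
```

With two beams, "dog" survives the first step despite its lower probability and leads to the better sequence. With a single beam the search behaves like greedy search.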

### Beam-search multinomial sampling

As the name implies, this decoding strategy combines beam search with multinomial sampling. To use this decoding strategy,
set `num_beams` to a value greater than 1 and set `do_sample=True`.

```python
>>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

>>> prompt = "translate English to German: The house is wonderful."
>>> checkpoint = "t5-small"

>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

>>> outputs = model.generate(**inputs, num_beams=5, do_sample=True)
>>> tokenizer.decode(outputs[0], skip_special_tokens=True)
'Das Haus ist wunderbar.'
```

### Diverse beam search decoding

The diverse beam search decoding strategy is an extension of the beam search strategy that allows for generating a more diverse
set of beam sequences to choose from. To learn how it works, refer to [Diverse Beam Search: Decoding Diverse Solutions from Neural Sequence Models](https://arxiv.org/pdf/1610.02424.pdf).
This approach has two main parameters: `num_beams` and `num_beam_groups`.
The beams are split into `num_beam_groups` groups, and regular beam search is used within each group. To keep the groups
distinct from one another, a beam's score is lowered when it picks a token that a beam from another group has already
picked at the same time step. The size of this penalty is set by the `diversity_penalty` parameter.

```python
>>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

>>> checkpoint = "google/pegasus-xsum"
>>> prompt = "The Permaculture Design Principles are a set of universal design principles \
... that can be applied to any location, climate and culture, and they allow us to design \
... the most efficient and sustainable human habitation and food production systems. \
... Permaculture is a design system that encompasses a wide variety of disciplines, such \
... as ecology, landscape design, environmental science and energy conservation, and the \
... Permaculture design principles are drawn from these various disciplines. Each individual \
... design principle itself embodies a complete conceptual framework based on sound \
... scientific principles. When we bring all these separate principles together, we can \
... create a design system that both looks at whole systems, the parts that these systems \
... consist of, and how those parts interact with each other to create a complex, dynamic, \
... living system. Each design principle serves as a tool that allows us to integrate all \
... the separate parts of a design, referred to as elements, into a functional, synergistic, \
... whole system, where the elements harmoniously interact and work together in the most \
... efficient way possible."

>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

>>> outputs = model.generate(**inputs, num_beams=5, num_beam_groups=5, max_new_tokens=30)
>>> tokenizer.decode(outputs[0], skip_special_tokens=True)
'The Design Principles are a set of universal design principles that can be applied to any location, climate and culture, and they allow us to design the most efficient and sustainable human habitation and food production systems.'
```
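
To illustrate how a diversity penalty pushes groups apart, here is a minimal pure-Python sketch of a single decoding step with one beam per group, which matches the example's `num_beams=5, num_beam_groups=5` setup. Groups are processed in order. Each group subtracts `diversity_penalty` from a token's log-probability once for every earlier group that already chose it (a Hamming-style diversity term, one of the options discussed in the paper). The function name and toy log-probabilities are hypothetical, and this is a simplification, not the library's implementation.

```python
import math


def diverse_group_step(log_probs, num_beam_groups, diversity_penalty):
    """One step of diverse beam search with a single beam per group.

    Each group's scores are lowered by diversity_penalty for every time an
    earlier group already picked that token at this step.
    """
    chosen = []
    for _ in range(num_beam_groups):
        scores = {
            tok: lp - diversity_penalty * chosen.count(tok)
            for tok, lp in log_probs.items()
        }
        chosen.append(max(scores, key=scores.get))
    return chosen


# Toy next-token log-probabilities shared by all groups (made-up numbers).
log_probs = {tok: math.log(p) for tok, p in
             {"design": 0.5, "principles": 0.3, "permaculture": 0.2}.items()}

print(diverse_group_step(log_probs, num_beam_groups=3, diversity_penalty=0.0))
# ['design', 'design', 'design']
print(diverse_group_step(log_probs, num_beam_groups=3, diversity_penalty=2.0))
# ['design', 'principles', 'permaculture']
```

Without a penalty, every group makes the same choice. With a large enough penalty, later groups are pushed toward tokens that earlier groups did not pick.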

This guide illustrates the main parameters that enable various decoding strategies. The [`generate`] method also accepts
more advanced parameters that give you even finer control over its behavior.
For the complete list of the available parameters, refer to the [API documentation](./main_classes/text_generation.mdx).