Spaces:
Runtime error
Runtime error
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from queue import Queue | |
from typing import TYPE_CHECKING, Optional | |
if TYPE_CHECKING: | |
from ..models.auto import AutoTokenizer | |
class BaseStreamer:
    """
    Base class from which `.generate()` streamers should inherit.

    Both methods below only raise `NotImplementedError`; a concrete streamer must override `put` and `end`.
    """

    def put(self, value):
        """Function that is called by `.generate()` to push new tokens"""
        raise NotImplementedError()

    def end(self):
        """Function that is called by `.generate()` to signal the end of generation"""
        raise NotImplementedError()
class TextStreamer(BaseStreamer):
    """
    Simple text streamer that prints the token(s) to stdout as soon as entire words are formed.

    <Tip warning={true}>

    The API for the streamer classes is still under development and may change in the future.

    </Tip>

    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenizer used to decode the tokens.
        skip_prompt (`bool`, *optional*, defaults to `False`):
            Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
        decode_kwargs (`dict`, *optional*):
            Additional keyword arguments to pass to the tokenizer's `decode` method.

    Examples:

        ```python
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

        >>> tok = AutoTokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
        >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
        >>> streamer = TextStreamer(tok)

        >>> # Despite returning the usual output, the streamer will also print the generated text to stdout.
        >>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
        An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
        ```
    """

    def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.decode_kwargs = decode_kwargs

        # variables used in the streaming process
        # token ids received since the last flush; re-decoded as a whole on every `put`
        self.token_cache = []
        # number of characters of the decoded cache that have already been emitted
        self.print_len = 0
        # True until the first `put` of a generation when `skip_prompt` is set; reset by `end`
        self.next_tokens_are_prompt = True

    def put(self, value):
        """
        Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.

        Args:
            value: Object exposing `shape` and `tolist()` (presumably a tensor of token ids -- confirm against the
                caller). A 2D value is accepted only if its first dimension is 1; its first row is then used.

        Raises:
            ValueError: If `value` has more than one dimension and its first dimension is larger than 1.
        """
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TextStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        # With `skip_prompt`, the first call after construction (or after `end`) is dropped entirely.
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        # Add the new token to the cache and decodes the entire thing.
        self.token_cache.extend(value.tolist())
        text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)

        # After the symbol for a new line, we flush the cache.
        if text.endswith("\n"):
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        # If the last token is a CJK character, we print the characters.
        elif text and self._is_chinese_char(ord(text[-1])):
            printable_text = text[self.print_len :]
            self.print_len += len(printable_text)
        # Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words,
        # which may change with the subsequent token -- there are probably smarter ways to do this!)
        else:
            # `rfind` returns -1 when there is no space, which makes the slice end at 0 and yields "".
            printable_text = text[self.print_len : text.rfind(" ") + 1]
            self.print_len += len(printable_text)

        self.on_finalized_text(printable_text)

    def end(self):
        """Flushes any remaining cache and prints a newline to stdout."""
        # Flush the cache, if it exists
        if self.token_cache:
            text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        else:
            printable_text = ""

        # Re-arm prompt skipping so the streamer can be used for another generation.
        self.next_tokens_are_prompt = True
        self.on_finalized_text(printable_text, stream_end=True)

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Prints the new text to stdout. If the stream is ending, also prints a newline."""
        # `end=None` makes `print` fall back to its default terminator (a newline).
        print(text, flush=True, end="" if not stream_end else None)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and are handled
        # like all of the other languages.
        return (
            0x4E00 <= cp <= 0x9FFF
            or 0x3400 <= cp <= 0x4DBF
            or 0x20000 <= cp <= 0x2A6DF
            or 0x2A700 <= cp <= 0x2B73F
            or 0x2B740 <= cp <= 0x2B81F
            or 0x2B820 <= cp <= 0x2CEAF
            or 0xF900 <= cp <= 0xFAFF
            or 0x2F800 <= cp <= 0x2FA1F
        )
class TextIteratorStreamer(TextStreamer):
    """
    Streamer that stores print-ready text in a queue, to be used by a downstream application as an iterator. This is
    useful for applications that benefit from accessing the generated text in a non-blocking way (e.g. in an
    interactive Gradio demo).

    <Tip warning={true}>

    The API for the streamer classes is still under development and may change in the future.

    </Tip>

    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenizer used to decode the tokens.
        skip_prompt (`bool`, *optional*, defaults to `False`):
            Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
        timeout (`float`, *optional*):
            The timeout for the text queue. If `None`, the queue will block indefinitely. Useful to handle exceptions
            in `.generate()`, when it is called in a separate thread.
        decode_kwargs (`dict`, *optional*):
            Additional keyword arguments to pass to the tokenizer's `decode` method.

    Examples:

        ```python
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
        >>> from threading import Thread

        >>> tok = AutoTokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
        >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
        >>> streamer = TextIteratorStreamer(tok)

        >>> # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
        >>> generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
        >>> thread = Thread(target=model.generate, kwargs=generation_kwargs)
        >>> thread.start()
        >>> generated_text = ""
        >>> for new_text in streamer:
        ...     generated_text += new_text
        >>> generated_text
        'An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,'
        ```
    """

    def __init__(
        self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, timeout: Optional[float] = None, **decode_kwargs
    ):
        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
        # Queue that carries finalized text chunks from the producer to the iterating consumer.
        self.text_queue = Queue()
        # Sentinel placed in the queue after the last chunk; `__next__` stops when it reads it.
        self.stop_signal = None
        self.timeout = timeout

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
        self.text_queue.put(text, timeout=self.timeout)
        if stream_end:
            self.text_queue.put(self.stop_signal, timeout=self.timeout)

    def __iter__(self):
        """Returns the streamer itself, which implements `__next__`."""
        return self

    def __next__(self):
        """
        Returns the next chunk of text from the queue, waiting up to `timeout` seconds for it.

        Raises:
            StopIteration: When the value read from the queue equals `stop_signal`.
            queue.Empty: When `timeout` is set and no item arrives in time (raised by `Queue.get`).
        """
        value = self.text_queue.get(timeout=self.timeout)
        if value == self.stop_signal:
            raise StopIteration()
        return value