alonsosilva committed on
Commit
f304ebf
1 Parent(s): 0c8ab93

Use TextIteratorStreamer instead of custom Streamer

Browse files
Files changed (1) hide show
  1. app.py +15 -145
app.py CHANGED
@@ -1,149 +1,9 @@
1
- from transformers import AutoModelForCausalLM, AutoTokenizer
 
2
 
3
  model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
4
  tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
5
-
6
- class BaseStreamer:
7
- """
8
- Base class from which `.generate()` streamers should inherit.
9
- """
10
-
11
- def put(self, value):
12
- """Function that is called by `.generate()` to push new tokens"""
13
- raise NotImplementedError()
14
-
15
- def end(self):
16
- """Function that is called by `.generate()` to signal the end of generation"""
17
- raise NotImplementedError()
18
-
19
- class TextStreamer(BaseStreamer):
20
- """
21
- Simple text streamer that prints the token(s) to stdout as soon as entire words are formed.
22
-
23
- <Tip warning={true}>
24
-
25
- The API for the streamer classes is still under development and may change in the future.
26
-
27
- </Tip>
28
-
29
- Parameters:
30
- tokenizer (`AutoTokenizer`):
31
- The tokenizer used to decode the tokens.
32
- skip_prompt (`bool`, *optional*, defaults to `False`):
33
- Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
34
- decode_kwargs (`dict`, *optional*):
35
- Additional keyword arguments to pass to the tokenizer's `decode` method.
36
-
37
- Examples:
38
-
39
- ```python
40
- >>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
41
-
42
- >>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
43
- >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
44
- >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
45
- >>> streamer = TextStreamer(tok)
46
-
47
- >>> # Despite returning the usual output, the streamer will also print the generated text to stdout.
48
- >>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
49
- An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
50
- ```
51
- """
52
-
53
- def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
54
- self.tokenizer = tokenizer
55
- self.skip_prompt = skip_prompt
56
- self.decode_kwargs = decode_kwargs
57
-
58
- # variables used in the streaming process
59
- self.token_cache = []
60
- self.print_len = 0
61
- self.next_tokens_are_prompt = True
62
-
63
- def put(self, value):
64
- """
65
- Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.
66
- """
67
- if len(value.shape) > 1 and value.shape[0] > 1:
68
- raise ValueError("TextStreamer only supports batch size 1")
69
- elif len(value.shape) > 1:
70
- value = value[0]
71
-
72
- if self.skip_prompt and self.next_tokens_are_prompt:
73
- self.next_tokens_are_prompt = False
74
- return
75
-
76
- # Add the new token to the cache and decode the entire thing.
77
- self.token_cache.extend(value.tolist())
78
- text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
79
-
80
- # After the symbol for a new line, we flush the cache.
81
- if text.endswith("\n"):
82
- printable_text = text[self.print_len :]
83
- self.token_cache = []
84
- self.print_len = 0
85
- # If the last token is a CJK character, we print the characters.
86
- elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
87
- printable_text = text[self.print_len :]
88
- self.print_len += len(printable_text)
89
- # Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words,
90
- # which may change with the subsequent token -- there are probably smarter ways to do this!)
91
- else:
92
- printable_text = text[self.print_len : text.rfind(" ") + 1]
93
- self.print_len += len(printable_text)
94
-
95
- self.on_finalized_text(printable_text)
96
-
97
- def end(self):
98
- """Flushes any remaining cache and prints a newline to stdout."""
99
- # Flush the cache, if it exists
100
- if len(self.token_cache) > 0:
101
- text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
102
- printable_text = text[self.print_len :]
103
- self.token_cache = []
104
- self.print_len = 0
105
- else:
106
- printable_text = ""
107
-
108
- self.next_tokens_are_prompt = True
109
- self.on_finalized_text(printable_text, stream_end=True)
110
-
111
- def on_finalized_text(self, text: str, stream_end: bool = False):
112
- """Prints the new text to stdout. If the stream is ending, also prints a newline."""
113
- # print(text, flush=True, end="" if not stream_end else None)
114
- messages.value = [
115
- *messages.value[:-1],
116
- {
117
- "role": "assistant",
118
- "content": messages.value[-1]["content"] + text,
119
- },
120
- ]
121
-
122
- def _is_chinese_char(self, cp):
123
- """Checks whether CP is the codepoint of a CJK character."""
124
- # This defines a "chinese character" as anything in the CJK Unicode block:
125
- # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
126
- #
127
- # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
128
- # despite its name. The modern Korean Hangul alphabet is a different block,
129
- # as is Japanese Hiragana and Katakana. Those alphabets are used to write
130
- # space-separated words, so they are not treated specially and handled
131
- like all of the other languages.
132
- if (
133
- (cp >= 0x4E00 and cp <= 0x9FFF)
134
- or (cp >= 0x3400 and cp <= 0x4DBF) #
135
- or (cp >= 0x20000 and cp <= 0x2A6DF) #
136
- or (cp >= 0x2A700 and cp <= 0x2B73F) #
137
- or (cp >= 0x2B740 and cp <= 0x2B81F) #
138
- or (cp >= 0x2B820 and cp <= 0x2CEAF) #
139
- or (cp >= 0xF900 and cp <= 0xFAFF)
140
- or (cp >= 0x2F800 and cp <= 0x2FA1F) #
141
- ): #
142
- return True
143
-
144
- return False
145
-
146
- streamer = TextStreamer(tokenizer, skip_prompt=True)
147
 
148
  import re
149
  import solara
@@ -161,7 +21,7 @@ def Page():
161
  solara.lab.theme.themes.light.secondary = "#0000ff"
162
  solara.lab.theme.themes.dark.primary = "#0000ff"
163
  solara.lab.theme.themes.dark.secondary = "#0000ff"
164
- title = "Qwen2-0.5B"
165
  with solara.Head():
166
  solara.Title(f"{title}")
167
  with solara.Column(align="center"):
@@ -176,7 +36,17 @@ def Page():
176
  add_generation_prompt=True
177
  )
178
  inputs = tokenizer(text, return_tensors="pt")
179
- _ = model.generate(**inputs, streamer=streamer, max_new_tokens=512)
 
 
 
 
 
 
 
 
 
 
180
  def result():
181
  if messages.value != []:
182
  response(messages.value[-1]["content"])
 
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
2
+ from threading import Thread
3
 
4
  model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
5
  tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
6
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  import re
9
  import solara
 
21
  solara.lab.theme.themes.light.secondary = "#0000ff"
22
  solara.lab.theme.themes.dark.primary = "#0000ff"
23
  solara.lab.theme.themes.dark.secondary = "#0000ff"
24
+ title = "Qwen2-0.5B-Instruct"
25
  with solara.Head():
26
  solara.Title(f"{title}")
27
  with solara.Column(align="center"):
 
36
  add_generation_prompt=True
37
  )
38
  inputs = tokenizer(text, return_tensors="pt")
39
+ generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=512)
40
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
41
+ thread.start()
42
+ for text in streamer:
43
+ messages.value = [
44
+ *messages.value[:-1],
45
+ {
46
+ "role": "assistant",
47
+ "content": messages.value[-1]["content"] + text,
48
+ },
49
+ ]
50
  def result():
51
  if messages.value != []:
52
  response(messages.value[-1]["content"])