lhoestq HF staff committed on
Commit
2e2e9ca
1 Parent(s): 82dc3c2

batched generation

Browse files
Files changed (2) hide show
  1. generate.py +27 -6
  2. gradio_app.py +3 -3
generate.py CHANGED
@@ -3,6 +3,7 @@ import json
3
  import logging
4
  import regex
5
  import time
 
6
  from pathlib import Path
7
  from typing import Annotated, Iterator
8
 
@@ -22,14 +23,16 @@ logger = logging.getLogger(__name__)
22
 
23
 
24
  logger.warning("Loading model...")
25
- model_id = "google/gemma-2b-it"
26
- # model_id = "Qwen/Qwen1.5-0.5B-Chat"
27
  if torch.backends.mps.is_available():
28
  device = "mps"
29
- model = models.transformers(model_id, device=device)
 
30
  else:
31
  device = "cuda"
32
- model = models.transformers(model_id, device=device)
 
 
 
33
 
34
  tokenizer = AutoTokenizer.from_pretrained(model_id)
35
  sampler = PenalizedMultinomialSampler()
@@ -95,6 +98,23 @@ def samples_prommpt(filename: str, prompt: str, columns: str):
95
  {{ prompt }}
96
  """
97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  def stream_jsonl_file(filename: str, prompt: str, columns: list[str], seed: int, size: int) -> Iterator[str]:
99
  filename = Path(filename).stem
100
  logger.warning(f"stream_response({filename=}, {prompt=}, {columns=})")
@@ -134,7 +154,8 @@ def stream_jsonl_file(filename: str, prompt: str, columns: list[str], seed: int,
134
  tokenize=False,
135
  add_generation_prompt=True
136
  )
137
- samples_generator_tokens = samples_generator.stream(text, rng=rng)
138
- for _, sample in zip(range(size), ijson.items(StringIteratorIO(samples_generator_tokens), "data.item", buf_size=4)):
 
139
  yield json.dumps(sample, ensure_ascii=False) + "\n"
140
  logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) - Generating samples... DONE (total={time.time() - _start:.02f}s)")
 
3
  import logging
4
  import regex
5
  import time
6
+ from itertools import chain, islice
7
  from pathlib import Path
8
  from typing import Annotated, Iterator
9
 
 
23
 
24
 
25
  logger.warning("Loading model...")
 
 
26
  if torch.backends.mps.is_available():
27
  device = "mps"
28
+ model_id = "Qwen/Qwen1.5-0.5B-Chat"
29
+ batch_size = 4
30
  else:
31
  device = "cuda"
32
+ model_id = "google/gemma-2b-it"
33
+ batch_size = 20
34
+
35
+ model = models.transformers(model_id, device=device)
36
 
37
  tokenizer = AutoTokenizer.from_pretrained(model_id)
38
  sampler = PenalizedMultinomialSampler()
 
98
  {{ prompt }}
99
  """
100
 
101
+
102
def stream_json_objects_from_batched_tokens_generator(batched_tokens_generator: Iterator[list[str]], json_field: str) -> Iterator[dict]:
    """Yield JSON objects as soon as they are complete in any of the batched token streams.

    Each element produced by ``batched_tokens_generator`` is a list holding one
    new token per stream. The tokens of every stream are accumulated into a
    growing string, which is re-parsed with ``ijson`` after each token; items
    found under ``<json_field>.item`` that were not yielded before are yielded.

    Args:
        batched_tokens_generator: Iterator of token batches; the length of the
            first batch fixes the number of streams.
        json_field: Name of the top-level JSON field whose array items are
            extracted from every stream.

    Yields:
        The parsed items, interleaved across streams in the order in which
        they become complete.
    """
    first_batch = next(batched_tokens_generator, None)
    if first_batch is None:
        # No tokens at all: end cleanly. A bare next() here would raise
        # StopIteration inside a generator, which surfaces as RuntimeError.
        return
    batch_size = len(first_batch)
    streams = [""] * batch_size  # text accumulated so far, per stream
    skips = [0] * batch_size  # number of items already yielded, per stream
    items_prefix = json_field + ".item"
    for tokens_batch in chain([first_batch], batched_tokens_generator):
        for stream_idx, token in enumerate(tokens_batch):
            streams[stream_idx] += token
            try:
                # The accumulated text is re-parsed from the start, so items
                # already yielded for this stream are skipped.
                # NOTE(review): StringIteratorIO is defined elsewhere in the
                # file; presumably it wraps an iterator of strings as a
                # file-like object - confirm.
                parsed_items = ijson.items(StringIteratorIO(iter(streams[stream_idx])), items_prefix, buf_size=1)
                for stream_sample in islice(parsed_items, skips[stream_idx], None):
                    yield stream_sample
                    # Must increment (the original assigned the constant +1),
                    # otherwise later items are yielded again on every token.
                    skips[stream_idx] += 1
            except ijson.IncompleteJSONError:
                # Expected while the stream is still being generated: the
                # text ends in the middle of a JSON document.
                pass
116
+
117
+
118
  def stream_jsonl_file(filename: str, prompt: str, columns: list[str], seed: int, size: int) -> Iterator[str]:
119
  filename = Path(filename).stem
120
  logger.warning(f"stream_response({filename=}, {prompt=}, {columns=})")
 
154
  tokenize=False,
155
  add_generation_prompt=True
156
  )
157
+ batched_samples_generator_tokens = samples_generator.stream([text] * batch_size, rng=rng)
158
+ json_field = list(Dataset.model_fields)[0]
159
+ for _, sample in zip(range(size), stream_json_objects_from_batched_tokens_generator(batched_samples_generator_tokens, json_field=json_field)):
160
  yield json.dumps(sample, ensure_ascii=False) + "\n"
161
  logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) - Generating samples... DONE (total={time.time() - _start:.02f}s)")
gradio_app.py CHANGED
@@ -6,11 +6,11 @@ import io
6
  import pandas as pd
7
  import spaces
8
 
9
- from generate import model_id, stream_jsonl_file
10
 
11
- MAX_SIZE = 20
12
  DEFAULT_SEED = 42
13
- DEFAULT_SIZE = 3
14
 
15
  @spaces.GPU(duration=120)
16
  def stream_output(query: str, continue_content: str = ""):
 
6
  import pandas as pd
7
  import spaces
8
 
9
+ from generate import model_id, stream_jsonl_file, batch_size
10
 
11
+ MAX_SIZE = 20 * batch_size
12
  DEFAULT_SEED = 42
13
+ DEFAULT_SIZE = 5 * batch_size
14
 
15
  @spaces.GPU(duration=120)
16
  def stream_output(query: str, continue_content: str = ""):