masharpe committed on
Commit bfcf71e
1 Parent(s): 4d21997

Space for all-prefix-optimal coupling

Files changed (7)
  1. .gitignore +1 -0
  2. README.md +10 -4
  3. algorithms.py +542 -0
  4. app.py +159 -0
  5. load.py +61 -0
  6. requirements.txt +4 -0
  7. util.py +15 -0
.gitignore ADDED
@@ -0,0 +1 @@
+__pycache__
README.md CHANGED
@@ -1,12 +1,18 @@
 ---
-title: All Prefix Optimal Coupling Demo 1
-emoji: 📊
-colorFrom: yellow
-colorTo: red
+title: All-Prefix-Optimal Coupling
+emoji: 🔗
+colorFrom: gray
+colorTo: gray
 sdk: gradio
+python_version: 3.11
 sdk_version: 4.41.0
 app_file: app.py
+short_description: Tightly pair LLM responses
+models:
+- Qwen/Qwen2-0.5B-Instruct
 pinned: false
+preload_from_hub:
+- Qwen/Qwen2-0.5B-Instruct
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
algorithms.py ADDED
@@ -0,0 +1,542 @@
+# Device-independent algorithms for LLM.
+
+import logging
+import time
+import torch
+import torch.nn.functional as F
+import transformers
+logger = logging.getLogger(__name__)
+
+# Helper to pull out the response tokens.
+def isolate_responses_BL(output_BL, prompt_len, eos_token_id):
+    responses_BL = []
+    resp_BL = output_BL[:, prompt_len:]
+    for i in range(resp_BL.shape[0]):
+        resp_L = resp_BL[i]
+        resplen = resp_L.shape[0]
+        for j in range(resplen):
+            if resp_L[j] == eos_token_id:
+                resplen = j+1
+                break
+        response_L = resp_L[:resplen].cpu().detach().numpy()
+        responses_BL.append(response_L)
+    return responses_BL
+
+def tokenize_prompt(device, tokenizer, chat, quiet=False):
+    # Tokenize the prompt.
+    prompt_BL = tokenizer.apply_chat_template(
+        [chat],
+        tokenize=True, add_generation_prompt=True, return_tensors='pt'
+    ).to(device)
+    if not quiet:
+        print('PROMPT:')
+        print(tokenizer.decode(prompt_BL[0]))
+    return prompt_BL
+
+def generate(device, model, tokenizer, chat):
+    """Generate a response using huggingface's generation."""
+
+    prompt_BL = tokenize_prompt(device, tokenizer, chat)
+    prompt_len = prompt_BL.shape[1]
+
+    # Generate response.
+
+    # Unfortunately, huggingface's generation code uses 'cumsum',
+    # which doesn't have a deterministic implementation.
+    torch.use_deterministic_algorithms(False)
+    generation_output = model.generate(
+        inputs=prompt_BL, max_new_tokens=512, do_sample=True,
+        return_dict_in_generate=True,
+    )
+    torch.use_deterministic_algorithms(True)
+
+    output_BL = generation_output.sequences
+
+    for response_L in isolate_responses_BL(output_BL, prompt_len, tokenizer.eos_token_id):
+        print(f'RESPONSE:')
+        print(tokenizer.decode(response_L))
+
+    return tokenizer.decode(response_L)
+
+def generate_with_logits(device, model, tokenizer, chat, seed=None):
+    # Huggingface generation that returns logits too.
+
+    if seed is not None:
+        transformers.set_seed(seed)
+
+    prompt_BL = tokenize_prompt(device, tokenizer, chat)
+    prompt_len = prompt_BL.shape[1]
+
+    # Unfortunately, huggingface's generation code uses 'cumsum',
+    # which doesn't have a deterministic implementation.
+    torch.use_deterministic_algorithms(False)
+    generation_output = model.generate(
+        inputs=prompt_BL, max_new_tokens=16, do_sample=True,
+        return_dict_in_generate=True, output_logits=True,
+    )
+    torch.use_deterministic_algorithms(True)
+
+    output_BL = generation_output.sequences
+    logits_BLV = torch.stack(generation_output.logits, axis=1)
+
+    for response_L in isolate_responses_BL(output_BL, prompt_len, tokenizer.eos_token_id):
+        print(f'RESPONSE:')
+        print(tokenizer.decode(response_L))
+
+    return response_L, logits_BLV[0]
+
+def response_logits(device, model, tokenizer, chat, response_L):
+    # Calculate logits using a single pass.
+
+    prompt_BL = tokenize_prompt(device, tokenizer, chat)
+    response_pt_L = torch.from_numpy(response_L[:-1])
+
+    # Concatenate along axis 1.
+    input_ids = torch.cat((prompt_BL, response_pt_L[None]), dim=1)
+
+    outputs = model(input_ids)
+    logits_BLV = outputs.logits
+
+    return logits_BLV[0][-len(response_L):]
+
+def generate_custom(device, model, tokenizer, chat, max_tokens=512, seed=None, return_tokens=False, quiet=False, return_lnprobs=False):
+    """Generate a response using custom generation."""
+
+    if seed is not None:
+        torch.manual_seed(seed)
+
+    prompt_BL = tokenize_prompt(device, tokenizer, chat, quiet=quiet)
+
+    # Generate response.
+
+    input_ids = prompt_BL
+    past_key_values = None
+
+    response_L = []
+    lnprobs_L = []
+    while 1:
+        outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)
+        past_key_values = outputs.past_key_values
+
+        logits = outputs.logits[0, -1, :]
+        p_V = F.softmax(logits, dim=-1)
+        token = torch.multinomial(p_V, num_samples=1).item()
+        response_L.append(token)
+        lnprobs_L.append(F.log_softmax(logits, dim=-1)[token].item())
+
+        p_token = p_V[token].item()
+        if not quiet:
+            print(f' Sampled token {format_token(tokenizer, token)} ({p_token*100.0:.3f}%)')
+
+        if token == tokenizer.eos_token_id or len(response_L) >= max_tokens:
+            break
+
+        input_ids = torch.tensor([[token]], device=device)
+
+    if not quiet:
+        print(f'RESPONSE:')
+        print(tokenizer.decode(response_L))
+
+    if return_tokens:
+        if return_lnprobs:
+            return response_L, lnprobs_L
+        else:
+            return response_L
+    else:
+        assert not return_lnprobs
+        return tokenizer.decode(response_L)
+
+def format_token(tokenizer, token_id):
+    return repr(tokenizer.decode(token_id))
+
+def apoc(device, model_x, model_y, tokenizer, chat_x, chat_y, max_tokens=512, seed=None, return_tokens=False, quiet=False):
+    """Generate a response using APOC unconditional sampling."""
+
+    if seed is not None:
+        torch.manual_seed(seed)
+
+    # This early implementation of the algorithm is numerically non-robust,
+    # so reduce problems by using high-precision floating-point.
+    logit_dtype = torch.float64
+
+    prompt_x_BL = tokenize_prompt(device, tokenizer, chat_x, quiet=quiet)
+    prompt_y_BL = tokenize_prompt(device, tokenizer, chat_y, quiet=quiet)
+
+    # Four variables are needed, since in the first iteration it depends on prompt (X vs Y),
+    # whereas in later iterations it depends on response (a vs b).
+    input_ids_xa = prompt_x_BL
+    input_ids_ya = prompt_y_BL
+    input_ids_xb = prompt_x_BL
+    input_ids_yb = prompt_y_BL
+
+    past_key_values_xa = None
+    past_key_values_ya = None
+    past_key_values_xb = None
+    past_key_values_yb = None
+
+    equal = True
+    a_eos = False
+    b_eos = False
+    def zero():
+        return torch.zeros(1, dtype=logit_dtype, device=device)
+    ln_pya_m_ln_pxa = zero()
+    ln_pxb_m_ln_pyb = zero()
+
+    response_a_L = []
+    response_b_L = []
+    i = 0
+    while 1:
+        if i >= max_tokens or (a_eos and b_eos): break
+        if not quiet:
+            print(f'Generating response token {i}')
+        i += 1
+
+        forward_passes_start = time.perf_counter()
+
+        if not a_eos:
+            outputs = model_x(input_ids_xa, past_key_values=past_key_values_xa, use_cache=True)
+            past_key_values_xa = outputs.past_key_values
+            logits = outputs.logits[0, -1, :].to(logit_dtype)
+            ln_pxa_V = F.log_softmax(logits, dim=-1)
+
+            outputs = model_y(input_ids_ya, past_key_values=past_key_values_ya, use_cache=True)
+            past_key_values_ya = outputs.past_key_values
+            logits = outputs.logits[0, -1, :].to(logit_dtype)
+            ln_pya_V = F.log_softmax(logits, dim=-1)
+
+        if not b_eos:
+            if equal:
+                # In equal mode, neither input_ids nor past_key_values depends on a vs b,
+                # so we can reuse the forward pass results for a 50% time savings.
+                assert not a_eos
+                past_key_values_xb = past_key_values_xa
+                ln_pxb_V = ln_pxa_V
+                past_key_values_yb = past_key_values_ya
+                ln_pyb_V = ln_pya_V
+            else:
+                outputs = model_x(input_ids_xb, past_key_values=past_key_values_xb, use_cache=True)
+                past_key_values_xb = outputs.past_key_values
+                logits = outputs.logits[0, -1, :].to(logit_dtype)
+                ln_pxb_V = F.log_softmax(logits, dim=-1)
+
+                outputs = model_y(input_ids_yb, past_key_values=past_key_values_yb, use_cache=True)
+                past_key_values_yb = outputs.past_key_values
+                logits = outputs.logits[0, -1, :].to(logit_dtype)
+                ln_pyb_V = F.log_softmax(logits, dim=-1)
+
+        forward_passes_end = time.perf_counter()
+        if not quiet:
+            print(f' Forward passes took {(forward_passes_end - forward_passes_start)*1000:.0f} ms')
+
+        if equal:
+            ln_pmeet_V = torch.minimum(
+                ln_pxa_V + torch.maximum(zero(), -ln_pya_m_ln_pxa),
+                ln_pya_V + torch.maximum(zero(), ln_pya_m_ln_pxa),
+            )
+            pmeet_V = torch.exp(ln_pmeet_V)
+            pmeet = torch.sum(pmeet_V)
+
+            if not quiet:
+                print(f' Equal mode (pmeet={pmeet*100.0:.3f}%)')
+
+            if torch.rand_like(pmeet) < pmeet:
+                token_a = token_b = torch.multinomial(pmeet_V, num_samples=1).item()
+
+                p_token_a = (pmeet_V[token_a] / pmeet).item()
+                if not quiet:
+                    print(f' Sampled {format_token(tokenizer, token_a)} ({p_token_a*100.0:.1f}%)')
+            else:
+                if not quiet:
+                    print(' Exited equal mode')
+                equal = False
+
+        if not equal:
+            if not a_eos:
+                wxt_V = torch.maximum(zero(), torch.exp(ln_pxa_V) - torch.exp(ln_pya_V + ln_pya_m_ln_pxa))
+                token_a = torch.multinomial(wxt_V, num_samples=1).item()
+
+                p_token_a = (wxt_V[token_a] / torch.sum(wxt_V)).item()
+                if not quiet:
+                    print(f' Sampled token_a {format_token(tokenizer, token_a)} ({p_token_a*100.0:.3f}%)')
+            if not b_eos:
+                wyt_V = torch.maximum(zero(), torch.exp(ln_pyb_V) - torch.exp(ln_pxb_V + ln_pxb_m_ln_pyb))
+                token_b = torch.multinomial(wyt_V, num_samples=1).item()
+
+                p_token_b = (wyt_V[token_b] / torch.sum(wyt_V)).item()
+                if not quiet:
+                    print(f' Sampled token_b {format_token(tokenizer, token_b)} ({p_token_b*100.0:.3f}%)')
+
+        if not a_eos:
+            response_a_L.append(token_a)
+            input_ids_xa = input_ids_ya = torch.tensor([[token_a]], device=device)
+            ln_pya_m_ln_pxa += ln_pya_V[token_a] - ln_pxa_V[token_a]
+            if token_a == tokenizer.eos_token_id:
+                a_eos = True
+        if not b_eos:
+            response_b_L.append(token_b)
+            input_ids_xb = input_ids_yb = torch.tensor([[token_b]], device=device)
+            ln_pxb_m_ln_pyb += ln_pxb_V[token_b] - ln_pyb_V[token_b]
+            if token_b == tokenizer.eos_token_id:
+                b_eos = True
+
+    if not quiet:
+        print(f'RESPONSE X:')
+        print(tokenizer.decode(response_a_L))
+        print(f'RESPONSE Y:')
+        print(tokenizer.decode(response_b_L))
+
+    if return_tokens:
+        return response_a_L, response_b_L
+    else:
+        return tokenizer.decode(response_a_L), tokenizer.decode(response_b_L)
+
+# Alternative implementation.
+@torch.no_grad()
+def apoc_alt(device, model_x, model_y, tokenizer, chat_x, chat_y, max_tokens=512, seed=None):
+    if seed is not None:
+        torch.manual_seed(seed)
+
+    prompt_x_BL = tokenize_prompt(device, tokenizer, chat_x, quiet=True)
+    prompt_y_BL = tokenize_prompt(device, tokenizer, chat_y, quiet=True)
+    model_pair = ModelPair(model_x, model_y, prompt_x_BL, prompt_y_BL)
+
+    logger.debug('PROMPT X:')
+    logger.debug(tokenizer.decode(prompt_x_BL[0]))
+    logger.debug('PROMPT Y:')
+    logger.debug(tokenizer.decode(prompt_y_BL[0]))
+
+    return _apoc_impl(model_pair, tokenizer, max_tokens)
+
+LOGIT_DTYPE = torch.float64
+
+class ModelPair:
+    def __init__(self, model_x, model_y, prompt_x_BL, prompt_y_BL):
+        self._model_x = model_x
+        self._model_y = model_y
+        self._prompt_x_BL = prompt_x_BL
+        self._prompt_y_BL = prompt_y_BL
+        self._is_swapped = False
+
+    def start(self):
+        # Return logprobs for the initial token.
+
+        outputs = self._model_x(self._prompt_x_BL, use_cache=True)
+        self._past_key_values_x = outputs.past_key_values
+        logits = outputs.logits[0, -1, :].to(LOGIT_DTYPE)
+        lnpx_V = F.log_softmax(logits, dim=-1)
+
+        outputs = self._model_y(self._prompt_y_BL, use_cache=True)
+        self._past_key_values_y = outputs.past_key_values
+        logits = outputs.logits[0, -1, :].to(LOGIT_DTYPE)
+        lnpy_V = F.log_softmax(logits, dim=-1)
+
+        return self._maybe_swap(lnpx_V, lnpy_V)
+
+    def step(self, token):
+        # Append the given token, then return logprobs for the next token.
+
+        forward_passes_start = time.perf_counter()
+
+        input_ids = torch.tensor([[token]], device=self._prompt_x_BL.device)
+
+        outputs = self._model_x(input_ids, past_key_values=self._past_key_values_x, use_cache=True)
+        self._past_key_values_x = outputs.past_key_values
+        logits = outputs.logits[0, -1, :].to(LOGIT_DTYPE)
+        lnpx_V = F.log_softmax(logits, dim=-1)
+
+        outputs = self._model_y(input_ids, past_key_values=self._past_key_values_y, use_cache=True)
+        self._past_key_values_y = outputs.past_key_values
+        logits = outputs.logits[0, -1, :].to(LOGIT_DTYPE)
+        lnpy_V = F.log_softmax(logits, dim=-1)
+
+        forward_passes_end = time.perf_counter()
+        logger.debug(f'Incremental forward passes took {(forward_passes_end - forward_passes_start)*1000:.0f} ms')
+
+        return self._maybe_swap(lnpx_V, lnpy_V)
+
+    def get_position(self):
+        # Return a position that can be rewound to.
+        return self._past_key_values_x, self._past_key_values_y
+
+    def rewind_to(self, position):
+        # Rewind the KV cache.
+        self._past_key_values_x, self._past_key_values_y = position
+
+    def swap_models(self):
+        # Exchange the order of the models.
+        self._is_swapped = not self._is_swapped
+
+    def _maybe_swap(self, a, b):
+        if self._is_swapped:
+            return b, a
+        else:
+            return a, b
+
+def _apoc_impl(model_pair, tokenizer, max_tokens):
+    prefix = []
+    lnpx_V, lnpy_V = model_pair.start()
+    lnpy_m_lnpx = torch.zeros(1, dtype=lnpx_V.dtype, device=lnpx_V.device)
+
+    while 1:
+        ln_peq_V = torch.minimum(
+            lnpx_V + F.relu(-lnpy_m_lnpx),
+            lnpy_V + F.relu(lnpy_m_lnpx),
+        )
+        peq_V = torch.exp(ln_peq_V)
+        peq = torch.sum(peq_V)
+
+        if torch.rand_like(peq) > peq:
+            logger.debug(f'Completed common prefix ({(1-peq)*100.:.3f}%)')
+            break
+        logger.debug(f'Extending common prefix ({peq*100.:.3f}%)')
+
+        token = torch.multinomial(peq_V, 1).item()
+        prefix.append(token)
+        lnpy_m_lnpx += lnpy_V[token] - lnpx_V[token]
+
+        p_token = (peq_V[token] / peq).item()
+        logger.debug(f'Sampled prefix token {format_token(tokenizer, token)} ({p_token*100.0:.3f}%)')
+
+        if token == tokenizer.eos_token_id or len(prefix) >= max_tokens:
+            return prefix, prefix
+
+        lnpx_V, lnpy_V = model_pair.step(token)
+
+    remaining_tokens = max_tokens - len(prefix)
+    split_pos = model_pair.get_position()
+    response_a = prefix + _apoc_gen_suffix(model_pair, tokenizer, remaining_tokens, lnpx_V, lnpy_V, lnpy_m_lnpx)
+    logger.debug('First suffix complete; rewinding')
+    model_pair.rewind_to(split_pos)
+    model_pair.swap_models()
+    response_b = prefix + _apoc_gen_suffix(model_pair, tokenizer, remaining_tokens, lnpy_V, lnpx_V, -lnpy_m_lnpx)
+
+    return response_a, response_b
+
+def _apoc_gen_suffix(model_pair, tokenizer, max_tokens, lnpx_V, lnpy_V, lnpy_m_lnpx):
+    lnpy_m_lnpx = lnpy_m_lnpx.clone()
+    suffix = []
+    while 1:
+        wx_V = F.relu(torch.exp(lnpx_V) - torch.exp(lnpy_V + lnpy_m_lnpx))
+        token = torch.multinomial(wx_V, 1).item()
+        suffix.append(token)
+        lnpy_m_lnpx += lnpy_V[token] - lnpx_V[token]
+
+        p_token = (wx_V[token] / torch.sum(wx_V)).item()
+        logger.debug(f'Sampled suffix token {format_token(tokenizer, token)} ({p_token*100.0:.3f}%)')
+
+        if token == tokenizer.eos_token_id or len(suffix) >= max_tokens:
+            return suffix
+
+        lnpx_V, lnpy_V = model_pair.step(token)
+
+def generate_streaming(device, model, tokenizer, chat, max_tokens=512, seed=None):
+    """Stream a response using custom generation."""
+
+    prompt_BL = tokenize_prompt(device, tokenizer, chat, quiet=True)
+    logger.debug('PROMPT:')
+    logger.debug(tokenizer.decode(prompt_BL[0]))
+
+    if seed is not None:
+        torch.manual_seed(seed)
+
+    return _generate_streaming_impl(device, model, tokenizer, prompt_BL, max_tokens)
+
+def _generate_streaming_impl(device, model, tokenizer, prompt_BL, max_tokens):
+    input_ids = prompt_BL
+    past_key_values = None
+
+    n_tokens = 0
+    while 1:
+        outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)
+        past_key_values = outputs.past_key_values
+
+        logits = outputs.logits[0, -1, :]
+        p_V = F.softmax(logits, dim=-1)
+        token = torch.multinomial(p_V, num_samples=1).item()
+
+        p_token = p_V[token].item()
+        logger.debug(f' Sampled token {format_token(tokenizer, token)} ({p_token*100.0:.3f}%)')
+
+        yield token
+        n_tokens += 1
+
+        if token == tokenizer.eos_token_id or n_tokens >= max_tokens:
+            break
+
+        input_ids = torch.tensor([[token]], device=device)
+
+# APOC unconditional streaming
+@torch.no_grad()
+def apoc_streaming(device, model_x, model_y, tokenizer, chat_x, chat_y, max_tokens=512, seed=None):
+    if seed is not None:
+        torch.manual_seed(seed)
+
+    prompt_x_BL = tokenize_prompt(device, tokenizer, chat_x, quiet=True)
+    prompt_y_BL = tokenize_prompt(device, tokenizer, chat_y, quiet=True)
+    model_pair = ModelPair(model_x, model_y, prompt_x_BL, prompt_y_BL)
+
+    logger.debug('PROMPT X:')
+    logger.debug(tokenizer.decode(prompt_x_BL[0]))
+    logger.debug('PROMPT Y:')
+    logger.debug(tokenizer.decode(prompt_y_BL[0]))
+
+    return _apoc_streaming_impl(model_pair, tokenizer, max_tokens)
+
+def _apoc_streaming_impl(model_pair, tokenizer, max_tokens):
+    remaining_tokens = max_tokens
+    lnpx_V, lnpy_V = model_pair.start()
+    lnpy_m_lnpx = torch.zeros(1, dtype=lnpx_V.dtype, device=lnpx_V.device)
+
+    while 1:
+        ln_peq_V = torch.minimum(
+            lnpx_V + F.relu(-lnpy_m_lnpx),
+            lnpy_V + F.relu(lnpy_m_lnpx),
+        )
+        peq_V = torch.exp(ln_peq_V)
+        peq = torch.sum(peq_V)
+
+        if torch.rand_like(peq) > peq:
+            logger.debug(f'Completed common prefix ({(1-peq)*100.:.3f}%)')
+            break
+        logger.debug(f'Extending common prefix ({peq*100.:.3f}%)')
+
+        token = torch.multinomial(peq_V, 1).item()
+        remaining_tokens -= 1
+        yield token, token
+        lnpy_m_lnpx += lnpy_V[token] - lnpx_V[token]
+
+        p_token = (peq_V[token] / peq).item()
+        logger.debug(f'Sampled prefix token {format_token(tokenizer, token)} ({p_token*100.0:.3f}%)')
+
+        if token == tokenizer.eos_token_id or remaining_tokens == 0:
+            return
+
+        lnpx_V, lnpy_V = model_pair.step(token)
+
+    split_pos = model_pair.get_position()
+    for token_a in _apoc_streaming_gen_suffix(model_pair, tokenizer, remaining_tokens, lnpx_V, lnpy_V, lnpy_m_lnpx):
+        yield token_a, None
+    logger.debug('Suffix a complete; rewinding')
+    model_pair.rewind_to(split_pos)
+    model_pair.swap_models()
+    for token_b in _apoc_streaming_gen_suffix(model_pair, tokenizer, remaining_tokens, lnpy_V, lnpx_V, -lnpy_m_lnpx):
+        yield None, token_b
+    logger.debug('Suffix b complete')
+
+def _apoc_streaming_gen_suffix(model_pair, tokenizer, max_tokens, lnpx_V, lnpy_V, lnpy_m_lnpx):
+    remaining_tokens = max_tokens
+    lnpy_m_lnpx = lnpy_m_lnpx.clone()
+    while 1:
+        wx_V = F.relu(torch.exp(lnpx_V) - torch.exp(lnpy_V + lnpy_m_lnpx))
+        token = torch.multinomial(wx_V, 1).item()
+        remaining_tokens -= 1
+        yield token
+        lnpy_m_lnpx += lnpy_V[token] - lnpx_V[token]
+
+        p_token = (wx_V[token] / torch.sum(wx_V)).item()
+        logger.debug(f'Sampled suffix token {format_token(tokenizer, token)} ({p_token*100.0:.3f}%)')
+
+        if token == tokenizer.eos_token_id or remaining_tokens == 0:
+            return
+
+        lnpx_V, lnpy_V = model_pair.step(token)
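The per-step arithmetic above is easiest to see on toy distributions. The following sketch is not part of the commit (the name coupled_step is introduced here for illustration); it reproduces one step of _apoc_impl / _apoc_gen_suffix on two small categorical distributions. With a zero running log-ratio this is the classic maximal coupling, so the two samples agree with probability 1 - TV(p_x, p_y).

# Toy sketch of a single APOC step (illustration only; the Space uses _apoc_impl above).
import torch
import torch.nn.functional as F

def coupled_step(lnpx_V, lnpy_V, lnpy_m_lnpx):
    # Mass on which the two chains can still agree, as in the common-prefix loop.
    ln_peq_V = torch.minimum(lnpx_V + F.relu(-lnpy_m_lnpx),
                             lnpy_V + F.relu(lnpy_m_lnpx))
    peq_V = torch.exp(ln_peq_V)
    peq = torch.sum(peq_V)
    if torch.rand(()) < peq:
        token = torch.multinomial(peq_V, 1).item()
        return token, token  # the common prefix continues
    # Otherwise sample each side from its residual weights, as in _apoc_gen_suffix
    # (assumes the two distributions differ, so both residuals have positive mass).
    wx_V = F.relu(torch.exp(lnpx_V) - torch.exp(lnpy_V + lnpy_m_lnpx))
    wy_V = F.relu(torch.exp(lnpy_V) - torch.exp(lnpx_V - lnpy_m_lnpx))
    return torch.multinomial(wx_V, 1).item(), torch.multinomial(wy_V, 1).item()

lnpx_V = torch.log(torch.tensor([0.7, 0.2, 0.1], dtype=torch.float64))
lnpy_V = torch.log(torch.tensor([0.5, 0.4, 0.1], dtype=torch.float64))
print(coupled_step(lnpx_V, lnpy_V, torch.zeros(1, dtype=torch.float64)))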
app.py ADDED
@@ -0,0 +1,159 @@
+# Gradio demo of streaming generation of multiple LLM response pairs.
+
+import logging
+import time
+import html
+import numpy as np
+import gradio as gr
+import util
+
+# gr.DataFrame is currently bugged for updating values,
+# so we must use raw HTML.
+# https://github.com/gradio-app/gradio/issues/8160
+def make_html_table(headers, data):
+    rows = ['<tr>' + ''.join(f'<th style="width: 50%">{h}</th>' for h in headers) + '</tr>\n']
+    for row in data:
+        rows.append('<tr>' + ''.join(f'<td style="width: 50%; font-family: monospace; white-space: pre-wrap;">{v}</td>' for v in row) + '</tr>\n')
+    return '<table style="width: 100%; table-layout: fixed">\n' + ''.join(rows) + '</table>\n'
+
+def highlight_prefix(tokens, prefix_len):
+    prefix_tokens = tokens[:prefix_len]
+
+    s = tokenizer.decode(tokens, skip_special_tokens=True)
+    prefix_s = tokenizer.decode(prefix_tokens, skip_special_tokens=True)
+
+    s_lcp_len = util.longest_common_prefix(np.array(list(s)), np.array(list(prefix_s)))
+
+    prefix_html = html.escape(s[:s_lcp_len])
+    suffix_html = html.escape(s[s_lcp_len:])
+
+    #highlight_style = 'background-color: #FFFFAE;'
+    #highlight_style = 'text-decoration: underline;'
+    highlight_style = 'background-color: #90FF90;'
+
+    return f'<span style="{highlight_style}">{prefix_html}</span>{suffix_html}'
+
+def format_response_pair(tokens_a, tokens_b):
+    # This is slightly convoluted, so as to properly handle grapheme clusters that span token boundaries.
+    token_lcp_len = util.longest_common_prefix(tokens_a, tokens_b)
+    return highlight_prefix(tokens_a, token_lcp_len), highlight_prefix(tokens_b, token_lcp_len)
+
+HEADERS = ['Response (Left)', 'Response (Right)']
+repo_id = "Qwen/Qwen2-0.5B-Instruct"
+
+DRY_RUN = False
+
+if DRY_RUN:
+    from load import load_tokenizer
+
+    tokenizer = load_tokenizer(repo_id)
+
+    def fn(max_tokens, num_responses, prompt_x, prompt_y):
+        rows = [['']*2 for i in range(num_responses)]
+
+        yield make_html_table(HEADERS, rows)
+
+        for j in range(num_responses):
+            response_raw_a = f'Sure!\n\n1 2 3 4 & 5.'
+            response_raw_b = f'Sure!\n\n1 2 3 4 5 & 6.'
+
+            response_tok_a = tokenizer.encode(response_raw_a, add_special_tokens=False, return_tensors='np')[0]
+            response_tok_b = tokenizer.encode(response_raw_b, add_special_tokens=False, return_tensors='np')[0]
+
+            steps = 1 + max(len(response_tok_a), len(response_tok_b))
+
+            for i in range(steps):
+                time.sleep(0.1)
+                prefix_tok_a = response_tok_a[:i]
+                prefix_tok_b = response_tok_b[:i]
+
+                content_a, content_b = format_response_pair(prefix_tok_a, prefix_tok_b)
+
+                rows[j][0] = content_a
+                rows[j][1] = content_b
+
+                yield make_html_table(HEADERS, rows)
+else:
+    from load import load_model
+    import algorithms
+
+    logging.basicConfig(format='%(levelname)s:%(name)s: %(message)s')
+    algorithms.logger.setLevel(logging.INFO)
+
+    model, tokenizer = load_model(repo_id)
+
+    def make_chat(system_msg, prompt):
+        chat = [
+            {
+                'role': 'system',
+                'content': system_msg,
+            },
+            {
+                'role': 'user',
+                'content': prompt,
+            },
+        ]
+        return chat
+
+    def fn(max_tokens, num_responses, prompt_x, prompt_y):
+        rows = [['']*2 for i in range(num_responses)]
+        yield make_html_table(HEADERS, rows)
+
+        for j in range(num_responses):
+            system_msg = "You are a helpful assistant."
+
+            chat_x = make_chat(system_msg, prompt_x)
+            chat_y = make_chat(system_msg, prompt_y)
+
+            gen = algorithms.apoc_streaming(
+                'cpu',
+                model,
+                model,
+                tokenizer,
+                chat_x,
+                chat_y,
+                max_tokens=max_tokens,
+            )
+            response_a_L = []
+            response_b_L = []
+            for token_a, token_b in gen:
+                dirty = False
+                if token_a is not None:
+                    response_a_L.append(token_a)
+                    dirty = True
+                if token_b is not None:
+                    response_b_L.append(token_b)
+                    dirty = True
+
+                if dirty:
+                    content_a, content_b = format_response_pair(np.array(response_a_L), np.array(response_b_L))
+
+                    rows[j][0] = content_a
+                    rows[j][1] = content_b
+
+                    yield make_html_table(HEADERS, rows)
+
+demo = gr.Interface(
+    fn=fn,
+    inputs=[
+        gr.Slider(1, 512, label='Max Tokens', value=48),
+        gr.Slider(1, 16, step=1, label='Num Responses', value=8),
+        gr.Textbox(label='Prompt (Left)'),
+        gr.Textbox(label='Prompt (Right)'),
+    ],
+    outputs=[
+        gr.HTML(),
+    ],
+    title='All-Prefix-Optimal Coupling',
+    description='Try similar prompts to see the effect of the difference between them. '
+        f'Model: `{repo_id}`.'
+    ,
+    examples=[
+        [48, 8, 'Count from 1 to 5.', 'Count from 1 to 6.'],
+        [48, 8, 'Tell me a joke.', 'Tell me a funny joke.'],
+        [48, 8, 'Calculate 3 + 4', 'Calculate 3 + 5'],
+        [48, 8, "What's the capital of Canada?", "What's the capital of France?"],
+    ],
+)
+
+demo.launch()
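The prefix-highlighting helpers above can be exercised without running a model. A small sketch (illustrative only; highlight_common and the sample strings are not part of the commit) applies the same green-span styling directly to strings:

# Illustration only: the highlight_prefix idea applied to plain strings.
import html
import numpy as np
from util import longest_common_prefix

def highlight_common(a, b, style='background-color: #90FF90;'):
    n = longest_common_prefix(np.array(list(a)), np.array(list(b)))
    def wrap(s):
        return f'<span style="{style}">{html.escape(s[:n])}</span>{html.escape(s[n:])}'
    return wrap(a), wrap(b)

left, right = highlight_common('Sure! 1 2 3 4 & 5.', 'Sure! 1 2 3 4 5 & 6.')
print(left)
print(right)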
load.py ADDED
@@ -0,0 +1,61 @@
+# Code to load a model.
+
+import os
+import warnings
+import torch
+import transformers
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+
+def load_model(repo_id, device_map=None, bnb=None, torch_dtype='auto'):
+    # Try our best to get deterministic results.
+    if device_map is not None:
+        # For determinism with CUDA >= 10.2, PyTorch says to use one of these.
+        os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
+        #os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':16:8'
+
+    torch.use_deterministic_algorithms(True)
+
+    # Ignore a spurious warning from huggingface_hub:
+    # https://github.com/huggingface/transformers/issues/30618
+    warnings.filterwarnings('ignore', message="`resume_download` is deprecated")
+
+    # Ignore a spurious warning from bitsandbytes.
+    warnings.filterwarnings('ignore', message="MatMul8bitLt: inputs will be cast from")
+
+    print(f'Loading model "{repo_id}" (bnb = "{bnb}")...')
+
+    # Ignore a spurious warning "Special tokens have been added..."
+    transformers.logging.set_verbosity_error()
+    tokenizer = AutoTokenizer.from_pretrained(repo_id, use_fast=True)
+    transformers.logging.set_verbosity_warning()
+
+    bnb_config = None
+    if bnb == 'nf8':
+        bnb_config = BitsAndBytesConfig(load_in_8bit=True)
+    if bnb == 'nf4':
+        bnb_config = BitsAndBytesConfig(load_in_4bit=True)
+
+    model = AutoModelForCausalLM.from_pretrained(
+        repo_id,
+        torch_dtype=torch_dtype,
+        device_map=device_map,
+        quantization_config=bnb_config,
+    )
+
+    # Disable gradients to save memory.
+    for param in model.parameters():
+        param.requires_grad = False
+
+    # Try our best to get deterministic results.
+    model.eval()
+
+    print('Done loading model.')
+
+    return model, tokenizer
+
+def load_tokenizer(repo_id):
+    # Ignore a spurious warning "Special tokens have been added..."
+    transformers.logging.set_verbosity_error()
+    tokenizer = AutoTokenizer.from_pretrained(repo_id, use_fast=True)
+    transformers.logging.set_verbosity_warning()
+    return tokenizer
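For reference, a usage sketch (the call patterns below are assumptions, not part of the commit): app.py calls load_model(repo_id) with the defaults, which loads on CPU with torch_dtype='auto'; device_map and bnb are only needed for GPU or quantized runs.

# Usage sketch; only the first call mirrors what app.py actually does.
from load import load_model, load_tokenizer

model, tokenizer = load_model("Qwen/Qwen2-0.5B-Instruct")
# model, tokenizer = load_model("Qwen/Qwen2-0.5B-Instruct", device_map="auto", bnb="nf4")
tokenizer_only = load_tokenizer("Qwen/Qwen2-0.5B-Instruct")  # the DRY_RUN path in app.py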
requirements.txt ADDED
@@ -0,0 +1,4 @@
+huggingface_hub==0.22.2
+numpy==1.26.4
+torch==2.2.2
+transformers==4.40.2
util.py ADDED
@@ -0,0 +1,15 @@
+import numpy as np
+
+def longest_common_prefix(xs, ys):
+    min_len = min(len(xs), len(ys))
+    idxs = (xs[:min_len] != ys[:min_len]).nonzero()[0]
+    if len(idxs) > 0:
+        return idxs[0]
+    else:
+        return min_len
+
+# Like np.cumsum, but with a leading zero.
+def cumsum0(x, axis):
+    pad_width = len(x.shape) * [(0,0)]
+    pad_width[axis] = (1,0)
+    return np.cumsum(np.pad(x, pad_width), axis=axis)
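A quick illustration of the two helpers (expected outputs in comments; not part of the commit):

import numpy as np
from util import longest_common_prefix, cumsum0

print(longest_common_prefix(np.array([5, 7, 9]), np.array([5, 7, 8])))  # 2
print(longest_common_prefix(np.array([5, 7]), np.array([5, 7, 8])))     # 2 (length of the shorter array)
print(cumsum0(np.array([1, 2, 3]), axis=0))                             # [0 1 3 6]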