Space for all-prefix-optimal coupling
- .gitignore +1 -0
- README.md +10 -4
- algorithms.py +542 -0
- app.py +159 -0
- load.py +61 -0
- requirements.txt +4 -0
- util.py +15 -0
.gitignore
ADDED
@@ -0,0 +1 @@
__pycache__
README.md
CHANGED
@@ -1,12 +1,18 @@
 ---
-title: All
-emoji:
-colorFrom:
-colorTo:
+title: All-Prefix-Optimal Coupling
+emoji: π
+colorFrom: gray
+colorTo: gray
 sdk: gradio
+python_version: 3.11
 sdk_version: 4.41.0
 app_file: app.py
+short_description: Tightly pair LLM responses
+models:
+- Qwen/Qwen2-0.5B-Instruct
 pinned: false
+preload_from_hub:
+- Qwen/Qwen2-0.5B-Instruct
 ---

 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
algorithms.py
ADDED
@@ -0,0 +1,542 @@
# Device-independent algorithms for LLM.

import logging
import time
import torch
import torch.nn.functional as F
import transformers
logger = logging.getLogger(__name__)

# Helper to pull out the response tokens.
def isolate_responses_BL(output_BL, prompt_len, eos_token_id):
    responses_BL = []
    resp_BL = output_BL[:, prompt_len:]
    for i in range(resp_BL.shape[0]):
        resp_L = resp_BL[i]
        resplen = resp_L.shape[0]
        for j in range(resplen):
            if resp_L[j] == eos_token_id:
                resplen = j+1
                break
        response_L = resp_L[:resplen].cpu().detach().numpy()
        responses_BL.append(response_L)
    return responses_BL

def tokenize_prompt(device, tokenizer, chat, quiet=False):
    # Tokenize the prompt.
    prompt_BL = tokenizer.apply_chat_template(
        [chat],
        tokenize=True, add_generation_prompt=True, return_tensors='pt'
    ).to(device)
    if not quiet:
        print('PROMPT:')
        print(tokenizer.decode(prompt_BL[0]))
    return prompt_BL

def generate(device, model, tokenizer, chat):
    """Generate a response using huggingface's generation."""

    prompt_BL = tokenize_prompt(device, tokenizer, chat)
    prompt_len = prompt_BL.shape[1]

    # Generate response.

    # Unfortunately, huggingface's generation code uses 'cumsum',
    # which doesn't have a deterministic implementation.
    torch.use_deterministic_algorithms(False)
    generation_output = model.generate(
        inputs=prompt_BL, max_new_tokens=512, do_sample=True,
        return_dict_in_generate=True,
    )
    torch.use_deterministic_algorithms(True)

    output_BL = generation_output.sequences

    for response_L in isolate_responses_BL(output_BL, prompt_len, tokenizer.eos_token_id):
        print('RESPONSE:')
        print(tokenizer.decode(response_L))

    return tokenizer.decode(response_L)

def generate_with_logits(device, model, tokenizer, chat, seed=None):
    # Huggingface generation that returns logits too.

    if seed is not None:
        transformers.set_seed(seed)

    prompt_BL = tokenize_prompt(device, tokenizer, chat)
    prompt_len = prompt_BL.shape[1]

    # Unfortunately, huggingface's generation code uses 'cumsum',
    # which doesn't have a deterministic implementation.
    torch.use_deterministic_algorithms(False)
    generation_output = model.generate(
        inputs=prompt_BL, max_new_tokens=16, do_sample=True,
        return_dict_in_generate=True, output_logits=True,
    )
    torch.use_deterministic_algorithms(True)

    output_BL = generation_output.sequences
    logits_BLV = torch.stack(generation_output.logits, axis=1)

    for response_L in isolate_responses_BL(output_BL, prompt_len, tokenizer.eos_token_id):
        print('RESPONSE:')
        print(tokenizer.decode(response_L))

    return response_L, logits_BLV[0]

def response_logits(device, model, tokenizer, chat, response_L):
    # Calculate logits using a single pass.

    prompt_BL = tokenize_prompt(device, tokenizer, chat)
    response_pt_L = torch.from_numpy(response_L[:-1])

    # Concatenate along axis 1.
    input_ids = torch.cat((prompt_BL, response_pt_L[None]), dim=1)

    outputs = model(input_ids)
    logits_BLV = outputs.logits

    return logits_BLV[0][-len(response_L):]

def generate_custom(device, model, tokenizer, chat, max_tokens=512, seed=None, return_tokens=False, quiet=False, return_lnprobs=False):
    """Generate a response using custom generation."""

    if seed is not None:
        torch.manual_seed(seed)

    prompt_BL = tokenize_prompt(device, tokenizer, chat, quiet=quiet)

    # Generate response.

    input_ids = prompt_BL
    past_key_values = None

    response_L = []
    lnprobs_L = []
    while 1:
        outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)
        past_key_values = outputs.past_key_values

        logits = outputs.logits[0, -1, :]
        p_V = F.softmax(logits, dim=-1)
        token = torch.multinomial(p_V, num_samples=1).item()
        response_L.append(token)
        lnprobs_L.append(F.log_softmax(logits, dim=-1)[token].item())

        p_token = p_V[token].item()
        if not quiet:
            print(f' Sampled token {format_token(tokenizer, token)} ({p_token*100.0:.3f}%)')

        if token == tokenizer.eos_token_id or len(response_L) >= max_tokens:
            break

        input_ids = torch.tensor([[token]], device=device)

    if not quiet:
        print('RESPONSE:')
        print(tokenizer.decode(response_L))

    if return_tokens:
        if return_lnprobs:
            return response_L, lnprobs_L
        else:
            return response_L
    else:
        assert not return_lnprobs
        return tokenizer.decode(response_L)

def format_token(tokenizer, token_id):
    return repr(tokenizer.decode(token_id))

def apoc(device, model_x, model_y, tokenizer, chat_x, chat_y, max_tokens=512, seed=None, return_tokens=False, quiet=False):
    """Generate a response using APOC unconditional sampling."""

    if seed is not None:
        torch.manual_seed(seed)

    # This early implementation of the algorithm is numerically non-robust,
    # so reduce problems by using high-precision floating-point.
    logit_dtype = torch.float64

    prompt_x_BL = tokenize_prompt(device, tokenizer, chat_x, quiet=quiet)
    prompt_y_BL = tokenize_prompt(device, tokenizer, chat_y, quiet=quiet)

    # Four variables are needed, since in the first iteration it depends on prompt (X vs Y),
    # whereas in later iterations it depends on response (a vs b).
    input_ids_xa = prompt_x_BL
    input_ids_ya = prompt_y_BL
    input_ids_xb = prompt_x_BL
    input_ids_yb = prompt_y_BL

    past_key_values_xa = None
    past_key_values_ya = None
    past_key_values_xb = None
    past_key_values_yb = None

    equal = True
    a_eos = False
    b_eos = False
    def zero():
        return torch.zeros(1, dtype=logit_dtype, device=device)
    ln_pya_m_ln_pxa = zero()
    ln_pxb_m_ln_pyb = zero()

    response_a_L = []
    response_b_L = []
    i = 0
    while 1:
        if i >= max_tokens or (a_eos and b_eos): break
        if not quiet:
            print(f'Generating response token {i}')
        i += 1

        forward_passes_start = time.perf_counter()

        if not a_eos:
            outputs = model_x(input_ids_xa, past_key_values=past_key_values_xa, use_cache=True)
            past_key_values_xa = outputs.past_key_values
            logits = outputs.logits[0, -1, :].to(logit_dtype)
            ln_pxa_V = F.log_softmax(logits, dim=-1)

            outputs = model_y(input_ids_ya, past_key_values=past_key_values_ya, use_cache=True)
            past_key_values_ya = outputs.past_key_values
            logits = outputs.logits[0, -1, :].to(logit_dtype)
            ln_pya_V = F.log_softmax(logits, dim=-1)

        if not b_eos:
            if equal:
                # In equal mode, neither input_ids nor past_key_values depends on a vs b,
                # so we can reuse the forward pass results for a 50% time savings.
                assert not a_eos
                past_key_values_xb = past_key_values_xa
                ln_pxb_V = ln_pxa_V
                past_key_values_yb = past_key_values_ya
                ln_pyb_V = ln_pya_V
            else:
                outputs = model_x(input_ids_xb, past_key_values=past_key_values_xb, use_cache=True)
                past_key_values_xb = outputs.past_key_values
                logits = outputs.logits[0, -1, :].to(logit_dtype)
                ln_pxb_V = F.log_softmax(logits, dim=-1)

                outputs = model_y(input_ids_yb, past_key_values=past_key_values_yb, use_cache=True)
                past_key_values_yb = outputs.past_key_values
                logits = outputs.logits[0, -1, :].to(logit_dtype)
                ln_pyb_V = F.log_softmax(logits, dim=-1)

        forward_passes_end = time.perf_counter()
        if not quiet:
            print(f' Forward passes took {(forward_passes_end - forward_passes_start)*1000:.0f} ms')

        if equal:
            ln_pmeet_V = torch.minimum(
                ln_pxa_V + torch.maximum(zero(), -ln_pya_m_ln_pxa),
                ln_pya_V + torch.maximum(zero(), ln_pya_m_ln_pxa),
            )
            pmeet_V = torch.exp(ln_pmeet_V)
            pmeet = torch.sum(pmeet_V)

            if not quiet:
                print(f' Equal mode (pmeet={pmeet*100.0:.3f}%)')

            if torch.rand_like(pmeet) < pmeet:
                token_a = token_b = torch.multinomial(pmeet_V, num_samples=1).item()

                p_token_a = (pmeet_V[token_a] / pmeet).item()
                if not quiet:
                    print(f' Sampled {format_token(tokenizer, token_a)} ({p_token_a*100.0:.1f}%)')
            else:
                if not quiet:
                    print(' Exited equal mode')
                equal = False

        if not equal:
            if not a_eos:
                wxt_V = torch.maximum(zero(), torch.exp(ln_pxa_V) - torch.exp(ln_pya_V + ln_pya_m_ln_pxa))
                token_a = torch.multinomial(wxt_V, num_samples=1).item()

                p_token_a = (wxt_V[token_a] / torch.sum(wxt_V)).item()
                if not quiet:
                    print(f' Sampled token_a {format_token(tokenizer, token_a)} ({p_token_a*100.0:.3f}%)')
            if not b_eos:
                wyt_V = torch.maximum(zero(), torch.exp(ln_pyb_V) - torch.exp(ln_pxb_V + ln_pxb_m_ln_pyb))
                token_b = torch.multinomial(wyt_V, num_samples=1).item()

                p_token_b = (wyt_V[token_b] / torch.sum(wyt_V)).item()
                if not quiet:
                    print(f' Sampled token_b {format_token(tokenizer, token_b)} ({p_token_b*100.0:.3f}%)')

        if not a_eos:
            response_a_L.append(token_a)
            input_ids_xa = input_ids_ya = torch.tensor([[token_a]], device=device)
            ln_pya_m_ln_pxa += ln_pya_V[token_a] - ln_pxa_V[token_a]
            if token_a == tokenizer.eos_token_id:
                a_eos = True
        if not b_eos:
            response_b_L.append(token_b)
            input_ids_xb = input_ids_yb = torch.tensor([[token_b]], device=device)
            ln_pxb_m_ln_pyb += ln_pxb_V[token_b] - ln_pyb_V[token_b]
            if token_b == tokenizer.eos_token_id:
                b_eos = True

    if not quiet:
        print('RESPONSE X:')
        print(tokenizer.decode(response_a_L))
        print('RESPONSE Y:')
        print(tokenizer.decode(response_b_L))

    if return_tokens:
        return response_a_L, response_b_L
    else:
        return tokenizer.decode(response_a_L), tokenizer.decode(response_b_L)

# Alternative implementation.
@torch.no_grad()
def apoc_alt(device, model_x, model_y, tokenizer, chat_x, chat_y, max_tokens=512, seed=None):
    if seed is not None:
        torch.manual_seed(seed)

    prompt_x_BL = tokenize_prompt(device, tokenizer, chat_x, quiet=True)
    prompt_y_BL = tokenize_prompt(device, tokenizer, chat_y, quiet=True)
    model_pair = ModelPair(model_x, model_y, prompt_x_BL, prompt_y_BL)

    logger.debug('PROMPT X:')
    logger.debug(tokenizer.decode(prompt_x_BL[0]))
    logger.debug('PROMPT Y:')
    logger.debug(tokenizer.decode(prompt_y_BL[0]))

    return _apoc_impl(model_pair, tokenizer, max_tokens)

LOGIT_DTYPE = torch.float64

class ModelPair:
    def __init__(self, model_x, model_y, prompt_x_BL, prompt_y_BL):
        self._model_x = model_x
        self._model_y = model_y
        self._prompt_x_BL = prompt_x_BL
        self._prompt_y_BL = prompt_y_BL
        self._is_swapped = False

    def start(self):
        # Return logprobs for the initial token.

        outputs = self._model_x(self._prompt_x_BL, use_cache=True)
        self._past_key_values_x = outputs.past_key_values
        logits = outputs.logits[0, -1, :].to(LOGIT_DTYPE)
        lnpx_V = F.log_softmax(logits, dim=-1)

        outputs = self._model_y(self._prompt_y_BL, use_cache=True)
        self._past_key_values_y = outputs.past_key_values
        logits = outputs.logits[0, -1, :].to(LOGIT_DTYPE)
        lnpy_V = F.log_softmax(logits, dim=-1)

        return self._maybe_swap(lnpx_V, lnpy_V)

    def step(self, token):
        # Append the given token, then return logprobs for the next token.

        forward_passes_start = time.perf_counter()

        input_ids = torch.tensor([[token]], device=self._prompt_x_BL.device)

        outputs = self._model_x(input_ids, past_key_values=self._past_key_values_x, use_cache=True)
        self._past_key_values_x = outputs.past_key_values
        logits = outputs.logits[0, -1, :].to(LOGIT_DTYPE)
        lnpx_V = F.log_softmax(logits, dim=-1)

        outputs = self._model_y(input_ids, past_key_values=self._past_key_values_y, use_cache=True)
        self._past_key_values_y = outputs.past_key_values
        logits = outputs.logits[0, -1, :].to(LOGIT_DTYPE)
        lnpy_V = F.log_softmax(logits, dim=-1)

        forward_passes_end = time.perf_counter()
        logger.debug(f'Incremental forward passes took {(forward_passes_end - forward_passes_start)*1000:.0f} ms')

        return self._maybe_swap(lnpx_V, lnpy_V)

    def get_position(self):
        # Return a position that can be rewound to.
        return self._past_key_values_x, self._past_key_values_y

    def rewind_to(self, position):
        # Rewind the KV cache.
        self._past_key_values_x, self._past_key_values_y = position

    def swap_models(self):
        # Exchange the order of the models.
        self._is_swapped = not self._is_swapped

    def _maybe_swap(self, a, b):
        if self._is_swapped:
            return b, a
        else:
            return a, b

def _apoc_impl(model_pair, tokenizer, max_tokens):
    prefix = []
    lnpx_V, lnpy_V = model_pair.start()
    lnpy_m_lnpx = torch.zeros(1, dtype=lnpx_V.dtype, device=lnpx_V.device)

    while 1:
        # Probability that both responses agree on the next token, given that they
        # have agreed on the prefix so far:
        # peq_V[t] = min(p_x(prefix+t), p_y(prefix+t)) / min(p_x(prefix), p_y(prefix)).
        ln_peq_V = torch.minimum(
            lnpx_V + F.relu(-lnpy_m_lnpx),
            lnpy_V + F.relu(lnpy_m_lnpx),
        )
        peq_V = torch.exp(ln_peq_V)
        peq = torch.sum(peq_V)

        if torch.rand_like(peq) > peq:
            logger.debug(f'Completed common prefix ({(1-peq)*100.:.3f}%)')
            break
        logger.debug(f'Extending common prefix ({peq*100.:.3f}%)')

        token = torch.multinomial(peq_V, 1).item()
        prefix.append(token)
        lnpy_m_lnpx += lnpy_V[token] - lnpx_V[token]

        p_token = (peq_V[token] / peq).item()
        logger.debug(f'Sampled prefix token {format_token(tokenizer, token)} ({p_token*100.0:.3f}%)')

        if token == tokenizer.eos_token_id or len(prefix) >= max_tokens:
            return prefix, prefix

        lnpx_V, lnpy_V = model_pair.step(token)

    remaining_tokens = max_tokens - len(prefix)
    split_pos = model_pair.get_position()
    response_a = prefix + _apoc_gen_suffix(model_pair, tokenizer, remaining_tokens, lnpx_V, lnpy_V, lnpy_m_lnpx)
    logger.debug('First suffix complete; rewinding')
    model_pair.rewind_to(split_pos)
    model_pair.swap_models()
    response_b = prefix + _apoc_gen_suffix(model_pair, tokenizer, remaining_tokens, lnpy_V, lnpx_V, -lnpy_m_lnpx)

    return response_a, response_b

def _apoc_gen_suffix(model_pair, tokenizer, max_tokens, lnpx_V, lnpy_V, lnpy_m_lnpx):
    lnpy_m_lnpx = lnpy_m_lnpx.clone()
    suffix = []
    while 1:
        # Sample from X's residual: the part of p_x not matched by p_y on this prefix.
        wx_V = F.relu(torch.exp(lnpx_V) - torch.exp(lnpy_V + lnpy_m_lnpx))
        token = torch.multinomial(wx_V, 1).item()
        suffix.append(token)
        lnpy_m_lnpx += lnpy_V[token] - lnpx_V[token]

        p_token = (wx_V[token] / torch.sum(wx_V)).item()
        logger.debug(f'Sampled suffix token {format_token(tokenizer, token)} ({p_token*100.0:.3f}%)')

        if token == tokenizer.eos_token_id or len(suffix) >= max_tokens:
            return suffix

        lnpx_V, lnpy_V = model_pair.step(token)

def generate_streaming(device, model, tokenizer, chat, max_tokens=512, seed=None):
    """Stream a response using custom generation."""

    prompt_BL = tokenize_prompt(device, tokenizer, chat, quiet=True)
    logger.debug('PROMPT:')
    logger.debug(tokenizer.decode(prompt_BL[0]))

    if seed is not None:
        torch.manual_seed(seed)

    return _generate_streaming_impl(device, model, tokenizer, prompt_BL, max_tokens)

def _generate_streaming_impl(device, model, tokenizer, prompt_BL, max_tokens):
    input_ids = prompt_BL
    past_key_values = None

    n_tokens = 0
    while 1:
        outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)
        past_key_values = outputs.past_key_values

        logits = outputs.logits[0, -1, :]
        p_V = F.softmax(logits, dim=-1)
        token = torch.multinomial(p_V, num_samples=1).item()

        p_token = p_V[token].item()
        logger.debug(f' Sampled token {format_token(tokenizer, token)} ({p_token*100.0:.3f}%)')

        yield token
        n_tokens += 1

        if token == tokenizer.eos_token_id or n_tokens >= max_tokens:
            break

        input_ids = torch.tensor([[token]], device=device)

# APOC unconditional streaming
@torch.no_grad()
def apoc_streaming(device, model_x, model_y, tokenizer, chat_x, chat_y, max_tokens=512, seed=None):
    if seed is not None:
        torch.manual_seed(seed)

    prompt_x_BL = tokenize_prompt(device, tokenizer, chat_x, quiet=True)
    prompt_y_BL = tokenize_prompt(device, tokenizer, chat_y, quiet=True)
    model_pair = ModelPair(model_x, model_y, prompt_x_BL, prompt_y_BL)

    logger.debug('PROMPT X:')
    logger.debug(tokenizer.decode(prompt_x_BL[0]))
    logger.debug('PROMPT Y:')
    logger.debug(tokenizer.decode(prompt_y_BL[0]))

    return _apoc_streaming_impl(model_pair, tokenizer, max_tokens)

def _apoc_streaming_impl(model_pair, tokenizer, max_tokens):
    # Same coupling step as in _apoc_impl, but yielding token pairs as they are produced.
    remaining_tokens = max_tokens
    lnpx_V, lnpy_V = model_pair.start()
    lnpy_m_lnpx = torch.zeros(1, dtype=lnpx_V.dtype, device=lnpx_V.device)

    while 1:
        ln_peq_V = torch.minimum(
            lnpx_V + F.relu(-lnpy_m_lnpx),
            lnpy_V + F.relu(lnpy_m_lnpx),
        )
        peq_V = torch.exp(ln_peq_V)
        peq = torch.sum(peq_V)

        if torch.rand_like(peq) > peq:
            logger.debug(f'Completed common prefix ({(1-peq)*100.:.3f}%)')
            break
        logger.debug(f'Extending common prefix ({peq*100.:.3f}%)')

        token = torch.multinomial(peq_V, 1).item()
        remaining_tokens -= 1
        yield token, token
        lnpy_m_lnpx += lnpy_V[token] - lnpx_V[token]

        p_token = (peq_V[token] / peq).item()
        logger.debug(f'Sampled prefix token {format_token(tokenizer, token)} ({p_token*100.0:.3f}%)')

        if token == tokenizer.eos_token_id or remaining_tokens == 0:
            return

        lnpx_V, lnpy_V = model_pair.step(token)

    split_pos = model_pair.get_position()
    for token_a in _apoc_streaming_gen_suffix(model_pair, tokenizer, remaining_tokens, lnpx_V, lnpy_V, lnpy_m_lnpx):
        yield token_a, None
    logger.debug('Suffix a complete; rewinding')
    model_pair.rewind_to(split_pos)
    model_pair.swap_models()
    for token_b in _apoc_streaming_gen_suffix(model_pair, tokenizer, remaining_tokens, lnpy_V, lnpx_V, -lnpy_m_lnpx):
        yield None, token_b
    logger.debug('Suffix b complete')

def _apoc_streaming_gen_suffix(model_pair, tokenizer, max_tokens, lnpx_V, lnpy_V, lnpy_m_lnpx):
    remaining_tokens = max_tokens
    lnpy_m_lnpx = lnpy_m_lnpx.clone()
    while 1:
        wx_V = F.relu(torch.exp(lnpx_V) - torch.exp(lnpy_V + lnpy_m_lnpx))
        token = torch.multinomial(wx_V, 1).item()
        remaining_tokens -= 1
        yield token
        lnpy_m_lnpx += lnpy_V[token] - lnpx_V[token]

        p_token = (wx_V[token] / torch.sum(wx_V)).item()
        logger.debug(f'Sampled suffix token {format_token(tokenizer, token)} ({p_token*100.0:.3f}%)')

        if token == tokenizer.eos_token_id or remaining_tokens == 0:
            return

        lnpx_V, lnpy_V = model_pair.step(token)
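The coupling step that apoc and _apoc_impl repeat for every token is easiest to see on a toy example. While the two responses are still identical ("equal mode"), they emit the same next token with probability equal to the overlap of the two next-token distributions (reweighted by the accumulated probability ratio of the shared prefix, which is zero at the first step); once they diverge, each response samples from its own residual distribution, so each one remains an exact sample from its own model. The sketch below is illustrative only and not part of this commit; the three-token vocabulary and the distributions p_x, p_y are made up.

# Toy sketch of the per-token coupling step, shown for the first token
# (accumulated log-ratio lnpy_m_lnpx == 0).  Not part of the Space.
import numpy as np

p_x = np.array([0.70, 0.20, 0.10])   # model X's next-token distribution (made up)
p_y = np.array([0.60, 0.10, 0.30])   # model Y's next-token distribution (made up)

# 'Equal mode': both responses emit the same token with probability peq.
peq_V = np.minimum(p_x, p_y)         # overlap of the two distributions
peq = peq_V.sum()                    # = 1 - total variation distance = 0.8 here

# Residual distributions used once the two responses have diverged.
wx_V = np.maximum(0.0, p_x - p_y)    # what remains of p_x outside the overlap
wy_V = np.maximum(0.0, p_y - p_x)    # what remains of p_y outside the overlap

# Sanity checks: the pieces recombine to the original marginals.
assert np.allclose(peq_V + wx_V, p_x)
assert np.allclose(peq_V + wy_V, p_y)
assert np.allclose(wx_V.sum(), 1.0 - peq)
assert np.allclose(wy_V.sum(), 1.0 - peq)

rng = np.random.default_rng(0)
if rng.random() < peq:
    token_a = token_b = rng.choice(3, p=peq_V / peq)   # shared prefix grows
else:
    token_a = rng.choice(3, p=wx_V / wx_V.sum())       # responses split here
    token_b = rng.choice(3, p=wy_V / wy_V.sum())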
app.py
ADDED
@@ -0,0 +1,159 @@
# Gradio demo of streaming generation of multiple LLM response pairs.

import logging
import time
import html
import numpy as np
import gradio as gr
import util

# gr.DataFrame is currently bugged for updating values,
# so we must use raw HTML.
# https://github.com/gradio-app/gradio/issues/8160
def make_html_table(headers, data):
    rows = ['<tr>' + ''.join(f'<th style="width: 50%">{h}</th>' for h in headers) + '</tr>\n']
    for row in data:
        rows.append('<tr>' + ''.join(f'<td style="width: 50%; font-family: monospace; white-space: pre-wrap;">{v}</td>' for v in row) + '</tr>\n')
    return '<table style="width: 100%; table-layout: fixed">\n' + ''.join(rows) + '</table>\n'

def highlight_prefix(tokens, prefix_len):
    prefix_tokens = tokens[:prefix_len]

    s = tokenizer.decode(tokens, skip_special_tokens=True)
    prefix_s = tokenizer.decode(prefix_tokens, skip_special_tokens=True)

    s_lcp_len = util.longest_common_prefix(np.array(list(s)), np.array(list(prefix_s)))

    prefix_html = html.escape(s[:s_lcp_len])
    suffix_html = html.escape(s[s_lcp_len:])

    #highlight_style = 'background-color: #FFFFAE;'
    #highlight_style = 'text-decoration: underline;'
    highlight_style = 'background-color: #90FF90;'

    return f'<span style="{highlight_style}">{prefix_html}</span>{suffix_html}'

def format_response_pair(tokens_a, tokens_b):
    # This is slightly convoluted, so as to properly handle grapheme clusters that span token boundaries.
    token_lcp_len = util.longest_common_prefix(tokens_a, tokens_b)
    return highlight_prefix(tokens_a, token_lcp_len), highlight_prefix(tokens_b, token_lcp_len)

HEADERS = ['Response (Left)', 'Response (Right)']
repo_id = "Qwen/Qwen2-0.5B-Instruct"

DRY_RUN = False

if DRY_RUN:
    from load import load_tokenizer

    tokenizer = load_tokenizer(repo_id)

    def fn(max_tokens, num_responses, prompt_x, prompt_y):
        rows = [['']*2 for i in range(num_responses)]

        yield make_html_table(HEADERS, rows)

        for j in range(num_responses):
            response_raw_a = 'Sure!\n\n1 2 3 4 & 5.'
            response_raw_b = 'Sure!\n\n1 2 3 4 5 & 6.'

            response_tok_a = tokenizer.encode(response_raw_a, add_special_tokens=False, return_tensors='np')[0]
            response_tok_b = tokenizer.encode(response_raw_b, add_special_tokens=False, return_tensors='np')[0]

            steps = 1 + max(len(response_tok_a), len(response_tok_b))

            for i in range(steps):
                time.sleep(0.1)
                prefix_tok_a = response_tok_a[:i]
                prefix_tok_b = response_tok_b[:i]

                content_a, content_b = format_response_pair(prefix_tok_a, prefix_tok_b)

                rows[j][0] = content_a
                rows[j][1] = content_b

                yield make_html_table(HEADERS, rows)
else:
    from load import load_model
    import algorithms

    logging.basicConfig(format='%(levelname)s:%(name)s: %(message)s')
    algorithms.logger.setLevel(logging.INFO)

    model, tokenizer = load_model(repo_id)

    def make_chat(system_msg, prompt):
        chat = [
            {
                'role': 'system',
                'content': system_msg,
            },
            {
                'role': 'user',
                'content': prompt,
            },
        ]
        return chat

    def fn(max_tokens, num_responses, prompt_x, prompt_y):
        rows = [['']*2 for i in range(num_responses)]
        yield make_html_table(HEADERS, rows)

        for j in range(num_responses):
            system_msg = "You are a helpful assistant."

            chat_x = make_chat(system_msg, prompt_x)
            chat_y = make_chat(system_msg, prompt_y)

            gen = algorithms.apoc_streaming(
                'cpu',
                model,
                model,
                tokenizer,
                chat_x,
                chat_y,
                max_tokens=max_tokens,
            )
            response_a_L = []
            response_b_L = []
            for token_a, token_b in gen:
                dirty = False
                if token_a is not None:
                    response_a_L.append(token_a)
                    dirty = True
                if token_b is not None:
                    response_b_L.append(token_b)
                    dirty = True

                if dirty:
                    content_a, content_b = format_response_pair(np.array(response_a_L), np.array(response_b_L))

                    rows[j][0] = content_a
                    rows[j][1] = content_b

                    yield make_html_table(HEADERS, rows)

demo = gr.Interface(
    fn=fn,
    inputs=[
        gr.Slider(1, 512, label='Max Tokens', value=48),
        gr.Slider(1, 16, step=1, label='Num Responses', value=8),
        gr.Textbox(label='Prompt (Left)'),
        gr.Textbox(label='Prompt (Right)'),
    ],
    outputs=[
        gr.HTML(),
    ],
    title='All-Prefix-Optimal Coupling',
    description='Try similar prompts to see the effect of the difference between them. '
        f'Model: `{repo_id}`.',
    examples=[
        [48, 8, 'Count from 1 to 5.', 'Count from 1 to 6.'],
        [48, 8, 'Tell me a joke.', 'Tell me a funny joke.'],
        [48, 8, 'Calculate 3 + 4', 'Calculate 3 + 5'],
        [48, 8, "What's the capital of Canada?", "What's the capital of France?"],
    ],
)

demo.launch()
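The reason highlight_prefix re-computes a character-level longest common prefix on decoded strings, rather than decoding the common token prefix and slicing by its length, is the grapheme-cluster issue mentioned in the comment above: a multi-byte character can be split across token boundaries, and decoding the incomplete half typically yields a replacement character. A minimal illustration, not part of this commit; the strings below stand in for tokenizer.decode() output, so no tokenizer is needed.

# Illustrative only: why the highlight boundary is found by comparing decoded strings.
import numpy as np
from util import longest_common_prefix

s        = 'Voilà!'      # decode(all tokens): the bytes completing 'à' have arrived
prefix_s = 'Voil\ufffd'  # decode(prefix tokens): 'à' is still split across tokens,
                         # so the decoder emits U+FFFD for the dangling bytes

# The character-level LCP stops before the half-formed character, so only 'Voil'
# gets highlighted rather than a spurious replacement character.
lcp = longest_common_prefix(np.array(list(s)), np.array(list(prefix_s)))
assert s[:lcp] == 'Voil'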
load.py
ADDED
@@ -0,0 +1,61 @@
# Code to load a model.

import os
import warnings
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

def load_model(repo_id, device_map=None, bnb=None, torch_dtype='auto'):
    # Try our best to get deterministic results.
    if device_map is not None:
        # For determinism with CUDA >= 10.2, PyTorch says to use one of these.
        os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
        #os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':16:8'

    torch.use_deterministic_algorithms(True)

    # Ignore a spurious warning from huggingface_hub:
    # https://github.com/huggingface/transformers/issues/30618
    warnings.filterwarnings('ignore', message="`resume_download` is deprecated")

    # Ignore a spurious warning from bitsandbytes.
    warnings.filterwarnings('ignore', message="MatMul8bitLt: inputs will be cast from")

    print(f'Loading model "{repo_id}" (bnb = "{bnb}")...')

    # Ignore a spurious warning "Special tokens have been added..."
    transformers.logging.set_verbosity_error()
    tokenizer = AutoTokenizer.from_pretrained(repo_id, use_fast=True)
    transformers.logging.set_verbosity_warning()

    bnb_config = None
    if bnb == 'nf8':
        bnb_config = BitsAndBytesConfig(load_in_8bit=True)
    if bnb == 'nf4':
        bnb_config = BitsAndBytesConfig(load_in_4bit=True)

    model = AutoModelForCausalLM.from_pretrained(
        repo_id,
        torch_dtype=torch_dtype,
        device_map=device_map,
        quantization_config=bnb_config,
    )

    # Disable gradients to save memory.
    for param in model.parameters():
        param.requires_grad = False

    # Try our best to get deterministic results.
    model.eval()

    print('Done loading model.')

    return model, tokenizer

def load_tokenizer(repo_id):
    # Ignore a spurious warning "Special tokens have been added..."
    transformers.logging.set_verbosity_error()
    tokenizer = AutoTokenizer.from_pretrained(repo_id, use_fast=True)
    transformers.logging.set_verbosity_warning()
    return tokenizer
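For reference, a minimal sketch of calling load_model outside the Space. The device_map and bnb values here are assumptions for a GPU machine with bitsandbytes installed (it is not listed in requirements.txt); app.py itself simply calls load_model(repo_id) with the defaults and runs on CPU.

# Illustrative only: load the same model the Space uses, optionally quantized.
from load import load_model

model, tokenizer = load_model(
    "Qwen/Qwen2-0.5B-Instruct",
    device_map="auto",   # non-None device_map enables the CUBLAS determinism setting above
    bnb="nf4",           # 4-bit quantization via bitsandbytes (assumed installed)
)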
requirements.txt
ADDED
@@ -0,0 +1,4 @@
huggingface_hub==0.22.2
numpy==1.26.4
torch==2.2.2
transformers==4.40.2
util.py
ADDED
@@ -0,0 +1,15 @@
import numpy as np

def longest_common_prefix(xs, ys):
    min_len = min(len(xs), len(ys))
    idxs = (xs[:min_len] != ys[:min_len]).nonzero()[0]
    if len(idxs) > 0:
        return idxs[0]
    else:
        return min_len

# Like np.cumsum, but with a leading zero.
def cumsum0(x, axis):
    pad_width = len(x.shape) * [(0,0)]
    pad_width[axis] = (1,0)
    return np.cumsum(np.pad(x, pad_width), axis=axis)
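A couple of worked examples of the helpers above (illustrative, not part of the commit):

import numpy as np
from util import longest_common_prefix, cumsum0

# Length of the shared prefix of two token arrays (used by app.py for highlighting).
assert longest_common_prefix(np.array([5, 7, 9, 2]), np.array([5, 7, 8])) == 2
# If one array is a prefix of the other, the shorter length is returned.
assert longest_common_prefix(np.array([5, 7]), np.array([5, 7, 8])) == 2

# cumsum0 keeps a leading zero, so result[i] is the sum of the first i elements.
assert np.array_equal(cumsum0(np.array([1, 2, 3]), axis=0), np.array([0, 1, 3, 6]))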