Spaces:
Running
on
Zero
Running
on
Zero
wondervictor
commited on
Commit
•
ba81d1b
1
Parent(s):
8a79548
Update autoregressive/models/generate.py
Browse files
autoregressive/models/generate.py
CHANGED
@@ -209,6 +209,7 @@ def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_int
|
|
209 |
input_pos = torch.arange(0, T, device=device)
|
210 |
next_token = prefill(model, cond_combined, input_pos, cfg_scale, condition_combined, control_strength,**sampling_kwargs)
|
211 |
seq[:, T:T+1] = next_token
|
|
|
212 |
|
213 |
input_pos = torch.tensor([T], device=device, dtype=torch.int)
|
214 |
generated_tokens, _ = decode_n_tokens(model, next_token, input_pos, max_new_tokens-1, cfg_scale, cfg_interval, condition=condition_combined, **sampling_kwargs)
|
|
|
209 |
input_pos = torch.arange(0, T, device=device)
|
210 |
next_token = prefill(model, cond_combined, input_pos, cfg_scale, condition_combined, control_strength,**sampling_kwargs)
|
211 |
seq[:, T:T+1] = next_token
|
212 |
+
print(model.control_strength)
|
213 |
|
214 |
input_pos = torch.tensor([T], device=device, dtype=torch.int)
|
215 |
generated_tokens, _ = decode_n_tokens(model, next_token, input_pos, max_new_tokens-1, cfg_scale, cfg_interval, condition=condition_combined, **sampling_kwargs)
|