wondervictor commited on
Commit
ba81d1b
1 Parent(s): 8a79548

Update autoregressive/models/generate.py

Browse files
Files changed (1) hide show
  1. autoregressive/models/generate.py +1 -0
autoregressive/models/generate.py CHANGED
@@ -209,6 +209,7 @@ def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_int
209
  input_pos = torch.arange(0, T, device=device)
210
  next_token = prefill(model, cond_combined, input_pos, cfg_scale, condition_combined, control_strength,**sampling_kwargs)
211
  seq[:, T:T+1] = next_token
 
212
 
213
  input_pos = torch.tensor([T], device=device, dtype=torch.int)
214
  generated_tokens, _ = decode_n_tokens(model, next_token, input_pos, max_new_tokens-1, cfg_scale, cfg_interval, condition=condition_combined, **sampling_kwargs)
 
209
  input_pos = torch.arange(0, T, device=device)
210
  next_token = prefill(model, cond_combined, input_pos, cfg_scale, condition_combined, control_strength,**sampling_kwargs)
211
  seq[:, T:T+1] = next_token
212
+ print(model.control_strength)
213
 
214
  input_pos = torch.tensor([T], device=device, dtype=torch.int)
215
  generated_tokens, _ = decode_n_tokens(model, next_token, input_pos, max_new_tokens-1, cfg_scale, cfg_interval, condition=condition_combined, **sampling_kwargs)