import warnings

import torch


def get_fast_schedule(original_timesteps, fast_after_steps, fast_rate):
    # Keep the first `fast_after_steps` timesteps at full resolution, then keep
    # only every `fast_rate`-th of the remaining ones to speed up later steps.
    if fast_after_steps >= len(original_timesteps) - 1:
        return original_timesteps
    new_timesteps = torch.cat((original_timesteps[:fast_after_steps], original_timesteps[fast_after_steps + 1::fast_rate]), dim=0)
    return new_timesteps
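

# Hedged usage sketch (the values are illustrative, not from the original file):
# with 50 DDIM-style timesteps at stride 20, keep the first 25 steps dense, then
# take every 4th remaining timestep, yielding 25 + 6 = 31 steps.
#
#     timesteps = torch.arange(980, -1, -20)      # tensor([980, 960, ..., 0])
#     fast = get_fast_schedule(timesteps, 25, 4)  # len(fast) == 31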


def dynamically_adjust_inference_steps(scheduler, index, t):
    # Derive the scheduler's effective step size from the gap between this
    # timestep and the next one, so `scheduler.step()` strides correctly even
    # on a sparsified fast schedule.
    prev_t = scheduler.timesteps[index + 1] if index + 1 < len(scheduler.timesteps) else -1
    scheduler.num_inference_steps = scheduler.config.num_train_timesteps // (t - prev_t)
    if index + 1 < len(scheduler.timesteps):
        if scheduler.config.num_train_timesteps // scheduler.num_inference_steps != t - prev_t:
            warnings.warn(f"({scheduler.config.num_train_timesteps} // {scheduler.num_inference_steps}) != ({t} - {prev_t}), so the step sizes may not be accurate")
    else:
        # As long as we hit the final alpha cumprod, it should be fine.
        if scheduler.config.num_train_timesteps // scheduler.num_inference_steps > t - prev_t:
            warnings.warn(f"({scheduler.config.num_train_timesteps} // {scheduler.num_inference_steps}) > ({t} - {prev_t}), so the step sizes may not be accurate")