# Provenance (scraped page header): uploaded by customdiffusion360,
# commit ad7bc89 ("first commit"), file size 5.39 kB.
import logging
from abc import ABC, abstractmethod
from typing import Dict, List, Literal, Optional, Tuple, Union
import torch
from einops import rearrange, repeat
from ...util import append_dims, default
# Module-level logger, named after this module.
logpy: logging.Logger = logging.getLogger(__name__)
class Guider(ABC):
    """Abstract interface for a sampling-time guidance strategy.

    Subclasses implement two steps: ``prepare_inputs`` builds the batched
    ``(x, sigma, cond)`` inputs for the model from the conditional (``c``)
    and unconditional (``uc``) condition dicts, and ``__call__`` folds the
    model's stacked outputs back into a single prediction.
    """

    @abstractmethod
    def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
        """Combine the stacked model outputs ``x`` into one prediction."""
        ...

    def prepare_inputs(
        self, x: torch.Tensor, s: float, c: Dict, uc: Dict
    ) -> Tuple[torch.Tensor, float, Dict]:
        """Build the model inputs for this guider.

        The base implementation does nothing and returns ``None``; concrete
        guiders override it.
        """
        ...
class VanillaCFG(Guider):
    """Standard classifier-free guidance with a single scalar scale.

    The model runs on ``[unconditional; conditional]`` stacked along the batch
    dimension, and the two output halves are blended as
    ``x_u + scale * (x_c - x_u)``.
    """

    def __init__(self, scale: float):
        self.scale = scale

    def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        # First half of the batch is the unconditional output, second half
        # the conditional one (matches the order built in prepare_inputs).
        uncond_out, cond_out = x.chunk(2)
        delta = cond_out - uncond_out
        return uncond_out + self.scale * delta

    def prepare_inputs(self, x, s, c, uc):
        """Duplicate ``x``/``s`` and stack ``uc`` before ``c`` for the batched keys.

        Entries under any other key must be equal in ``c`` and ``uc`` and are
        passed through unchanged.
        """
        batched_keys = ("vector", "crossattn", "concat")
        merged = {}
        for key, cond_val in c.items():
            if key not in batched_keys:
                assert cond_val == uc[key]
                merged[key] = cond_val
                continue
            merged[key] = torch.cat((uc[key], cond_val), 0)
        return torch.cat([x, x]), torch.cat([s, s]), merged
class IdentityGuider(Guider):
    """Pass-through guider that applies no guidance.

    Only the conditional inputs are prepared (``uc`` is ignored), and the
    model output is returned unchanged.
    """

    def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
        return x

    def prepare_inputs(
        self, x: torch.Tensor, s: float, c: Dict, uc: Dict
    ) -> Tuple[torch.Tensor, float, Dict]:
        # Shallow-copy ``c`` so the returned dict is a new object.
        return x, s, {key: c[key] for key in c}
class LinearPredictionGuider(Guider):
    """Classifier-free guidance whose scale ramps linearly across frames.

    The batch is treated as ``b`` clips of ``num_frames`` frames flattened
    into the leading dimension (``(b t)``). Frame ``i`` of every clip is
    guided with ``scale[i]``, where ``scale`` goes linearly from
    ``min_scale`` (first frame) to ``max_scale`` (last frame).

    Args:
        max_scale: guidance scale applied to the last frame.
        num_frames: number of frames per clip; the per-half batch size must
            be divisible by it.
        min_scale: guidance scale applied to the first frame.
        additional_cond_keys: extra condition keys, besides ``vector``,
            ``crossattn`` and ``concat``, that are stacked ``[uc; c]`` in
            ``prepare_inputs``. Accepts a single string, a list, or any other
            iterable of strings.
    """

    def __init__(
        self,
        max_scale: float,
        num_frames: int,
        min_scale: float = 1.0,
        additional_cond_keys: Optional[Union[List[str], str]] = None,
    ):
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.num_frames = num_frames
        # Shape (1, num_frames): one guidance scale per frame.
        self.scale = torch.linspace(min_scale, max_scale, num_frames).unsqueeze(0)
        if additional_cond_keys is None:
            additional_cond_keys = []
        elif isinstance(additional_cond_keys, str):
            additional_cond_keys = [additional_cond_keys]
        # Store a private list copy. This accepts tuples and other iterables
        # (``list + tuple`` used to raise in prepare_inputs) and is immune to
        # later mutation of the caller's list.
        self.additional_cond_keys = list(additional_cond_keys)

    def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        """Blend ``[uncond; cond]`` outputs with the per-frame scale."""
        x_u, x_c = x.chunk(2)
        # Unflatten "(b t) ..." -> "b t ..."; reshape raises if the batch is
        # not divisible by num_frames.
        clip_shape = (-1, self.num_frames) + tuple(x_u.shape[1:])
        x_u = x_u.reshape(clip_shape)
        x_c = x_c.reshape(clip_shape)
        # (1, t, 1, ..., 1) broadcasts over clips and all trailing dims.
        # Only the device is matched; dtype promotion is left to torch.
        scale = self.scale.to(x_u.device).reshape(
            (1, self.num_frames) + (1,) * (x_u.ndim - 2)
        )
        out = x_u + scale * (x_c - x_u)
        # Flatten back to "(b t) ...".
        return out.reshape((-1,) + tuple(out.shape[2:]))

    def prepare_inputs(
        self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict
    ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        """Duplicate ``x``/``s`` and stack ``[uc; c]`` for the batched keys.

        Entries under any other key must be equal in ``c`` and ``uc`` and are
        passed through unchanged.
        """
        batched_keys = ("vector", "crossattn", "concat", *self.additional_cond_keys)
        c_out = dict()
        for k in c:
            if k in batched_keys:
                c_out[k] = torch.cat((uc[k], c[k]), 0)
            else:
                assert c[k] == uc[k]
                c_out[k] = c[k]
        return torch.cat([x] * 2), torch.cat([s] * 2), c_out
class ScheduledCFGImgTextRef(Guider):
    """Two-term classifier-free guidance, following InstructPix2Pix.

    The latents are evaluated as three stacked copies (``x_u``, ``x_ic``,
    ``x_c``), and the outputs are combined as::

        x_u + scale * (x_c - x_ic) + scale_im * (x_ic - x_u)

    Per ``prepare_inputs``, the copies differ in which condition rows they
    get:

    * ``x_u`` gets unconditional rows in both segments.
    * ``x_ic`` gets unconditional rows for the latent segment and conditional
      rows for the trailing segment.
    * ``x_c`` gets conditional rows in both segments.

    NOTE(review): by the class name, ``scale`` is presumably the text weight
    and ``scale_im`` the reference-image weight. Confirm against the caller.
    """

    def __init__(self, scale: float, scale_im: float):
        self.scale = scale
        self.scale_im = scale_im

    def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        x_u, x_ic, x_c = x.chunk(3)
        x_pred = x_u + self.scale * (x_c - x_ic) + self.scale_im * (x_ic - x_u)
        return x_pred

    def prepare_inputs(self, x, s, c, uc):
        """Triplicate ``x``/``s`` and lay out batched conditions per copy.

        For each batched key, rows ``[0, n)`` (``n = x.size(0)``) form the
        latent segment and the remaining rows form the trailing segment.
        ``c[k]`` must have the same row count as ``uc[k]``. Entries under any
        other key must be equal in ``c`` and ``uc`` and are passed through.
        """
        n = x.size(0)
        c_out = dict()
        for k in c:
            if k in ["vector", "crossattn", "concat"]:
                # The old code had separate "crossattn" and other-key branches
                # with identical bodies; one path is enough.
                rest = uc[k].shape[0] - n
                uc1, uc2 = uc[k].split([n, rest])
                c1, c2 = c[k].split([n, rest])
                # Latent segment for the (x_u, x_ic, x_c) copies, then the
                # trailing segment for the same three copies. Assumes the model
                # pairs the i-th block of each segment; TODO confirm.
                c_out[k] = torch.cat((uc1, uc1, c1, uc2, c2, c2), 0)
            else:
                assert c[k] == uc[k]
                c_out[k] = c[k]
        return torch.cat([x] * 3), torch.cat([s] * 3), c_out
class VanillaCFGImgRef(Guider):
    """
    implements parallelized CFG

    Standard classifier-free guidance (``x_u + scale * (x_c - x_u)``) where
    each batched condition has a latent segment (the first ``x.size(0)``
    rows) followed by a trailing segment (presumably reference-view
    conditions; confirm against the caller). Both segments are split into
    unconditional and conditional halves in ``prepare_inputs``.
    """

    def __init__(self, scale: float):
        self.scale = scale

    def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        x_u, x_c = x.chunk(2)
        x_pred = x_u + self.scale * (x_c - x_u)
        return x_pred

    def prepare_inputs(self, x, s, c, uc):
        """Duplicate ``x``/``s`` and lay out batched conditions per copy.

        For each batched key the output rows are
        ``[uc_latent, c_latent, uc_trailing, c_trailing]``. ``c[k]`` must have
        the same row count as ``uc[k]``. Entries under any other key must be
        equal in ``c`` and ``uc`` and are passed through.
        """
        n = x.size(0)
        c_out = dict()
        for k in c:
            if k in ["vector", "crossattn", "concat"]:
                # The old code had separate "crossattn" and other-key branches
                # with identical bodies; one path is enough.
                rest = uc[k].shape[0] - n
                uc1, uc2 = uc[k].split([n, rest])
                c1, c2 = c[k].split([n, rest])
                c_out[k] = torch.cat((uc1, c1, uc2, c2), 0)
            else:
                assert c[k] == uc[k]
                c_out[k] = c[k]
        return torch.cat([x] * 2), torch.cat([s] * 2), c_out