import pathlib
from dataclasses import asdict, dataclass
from enum import Enum
from typing import Any, Dict, Optional

from omegaconf import OmegaConf

from sgm.inference.helpers import (Img2ImgDiscretizationWrapper, do_img2img,
                                   do_sample)
from sgm.modules.diffusionmodules.sampling import (DPMPP2MSampler,
                                                   DPMPP2SAncestralSampler,
                                                   EulerAncestralSampler,
                                                   EulerEDMSampler,
                                                   HeunEDMSampler,
                                                   LinearMultistepSampler)
from sgm.util import load_model_from_config


class ModelArchitecture(str, Enum):
    """Identifiers of the model variants this module knows about.

    Members are the keys of ``model_specs``. Because the enum mixes in
    ``str``, a member compares equal to its plain string value.
    """

    SD_2_1 = "stable-diffusion-v2-1"
    SD_2_1_768 = "stable-diffusion-v2-1-768"
    SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base"
    SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner"
    SDXL_V1_BASE = "stable-diffusion-xl-v1-base"
    SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner"


class Sampler(str, Enum):
    """Selectable sampler implementations.

    Each value is the class name of one of the samplers imported from
    ``sgm.modules.diffusionmodules.sampling``; ``get_sampler_config`` maps
    a member to the corresponding class.
    """

    EULER_EDM = "EulerEDMSampler"
    HEUN_EDM = "HeunEDMSampler"
    EULER_ANCESTRAL = "EulerAncestralSampler"
    DPMPP2S_ANCESTRAL = "DPMPP2SAncestralSampler"
    DPMPP2M = "DPMPP2MSampler"
    LINEAR_MULTISTEP = "LinearMultistepSampler"


class Discretization(str, Enum):
    """Selectable discretization schemes.

    ``get_discretization_config`` turns a member into the config dict of
    the matching class in ``sgm.modules.diffusionmodules.discretizer``.
    """

    LEGACY_DDPM = "LegacyDDPMDiscretization"
    EDM = "EDMDiscretization"


class Guider(str, Enum):
    """Selectable guiders.

    ``get_guider_config`` turns a member into the config dict of the
    matching class in ``sgm.modules.diffusionmodules.guiders``.
    """

    VANILLA = "VanillaCFG"
    IDENTITY = "IdentityGuider"


class Thresholder(str, Enum):
    """Selectable dynamic-thresholding modes.

    ``NONE`` is the only member; note that its value is the *string*
    ``"None"``, not the ``None`` singleton. ``get_guider_config`` maps it
    to ``NoDynamicThresholding``.
    """

    NONE = "None"


@dataclass
class SamplingParams:
    """User-tunable settings for a single sampling call.

    ``SamplingPipeline`` converts an instance to a plain dict with
    ``asdict`` and hands it to the sampling helpers; the ``get_*_config``
    functions in this module read the sampler-related fields.
    """

    # Requested output size; text_to_image passes these on to do_sample.
    width: int = 1024
    height: int = 1024
    # Forwarded to every sampler as ``num_steps``.
    steps: int = 50
    # Component selection (see the enums of the same names).
    sampler: Sampler = Sampler.DPMPP2M
    discretization: Discretization = Discretization.LEGACY_DDPM
    guider: Guider = Guider.VANILLA
    thresholder: Thresholder = Thresholder.NONE
    # ``scale`` parameter of the VanillaCFG guider.
    scale: float = 6.0
    # Conditioning values; they reach the sampling helpers through the
    # value dict built from this dataclass.
    aesthetic_score: float = 5.0
    negative_aesthetic_score: float = 5.0
    # image_to_image wraps the discretization in
    # Img2ImgDiscretizationWrapper only when this is below 1.0.
    img2img_strength: float = 1.0
    orig_width: int = 1024
    orig_height: int = 1024
    crop_coords_top: int = 0
    crop_coords_left: int = 0
    # Parameters of the EDM discretization (unused for LEGACY_DDPM).
    sigma_min: float = 0.0292
    sigma_max: float = 14.6146
    rho: float = 3.0
    # Read by the Euler/Heun EDM samplers.
    s_churn: float = 0.0
    s_tmin: float = 0.0
    s_tmax: float = 999.0
    # Read by the EDM samplers and by both ancestral samplers.
    s_noise: float = 1.0
    # Read by the ancestral samplers.
    eta: float = 1.0
    # Read by LinearMultistepSampler.
    order: int = 4


@dataclass
class SamplingSpec:
    """Static description of one supported model (an entry of ``model_specs``)."""

    # Resolution recorded for the model. Not read elsewhere in this module;
    # presumably the model's native resolution -- TODO confirm.
    width: int
    height: int
    # Passed to do_sample together with the requested height/width.
    channels: int
    factor: int
    # When False, the pipeline forces zero unconditional embeddings for
    # the "txt" key; when True it forces none.
    is_legacy: bool
    # File name of the config, joined onto the pipeline's ``config_path``.
    config: str
    # File name of the checkpoint, joined onto the pipeline's ``model_path``.
    ckpt: str
    # Not read anywhere in this module.
    is_guided: bool


# Registry of supported models: one SamplingSpec per ModelArchitecture member.
# SamplingPipeline rejects any model id that is not a key of this dict.
model_specs = {
    ModelArchitecture.SD_2_1: SamplingSpec(
        height=512,
        width=512,
        channels=4,
        factor=8,
        is_legacy=True,
        config="sd_2_1.yaml",
        ckpt="v2-1_512-ema-pruned.safetensors",
        is_guided=True,
    ),
    ModelArchitecture.SD_2_1_768: SamplingSpec(
        height=768,
        width=768,
        channels=4,
        factor=8,
        is_legacy=True,
        config="sd_2_1_768.yaml",
        ckpt="v2-1_768-ema-pruned.safetensors",
        is_guided=True,
    ),
    ModelArchitecture.SDXL_V0_9_BASE: SamplingSpec(
        height=1024,
        width=1024,
        channels=4,
        factor=8,
        is_legacy=False,
        config="sd_xl_base.yaml",
        ckpt="sd_xl_base_0.9.safetensors",
        is_guided=True,
    ),
    ModelArchitecture.SDXL_V0_9_REFINER: SamplingSpec(
        height=1024,
        width=1024,
        channels=4,
        factor=8,
        is_legacy=True,
        config="sd_xl_refiner.yaml",
        ckpt="sd_xl_refiner_0.9.safetensors",
        is_guided=True,
    ),
    ModelArchitecture.SDXL_V1_BASE: SamplingSpec(
        height=1024,
        width=1024,
        channels=4,
        factor=8,
        is_legacy=False,
        config="sd_xl_base.yaml",
        ckpt="sd_xl_base_1.0.safetensors",
        is_guided=True,
    ),
    ModelArchitecture.SDXL_V1_REFINER: SamplingSpec(
        height=1024,
        width=1024,
        channels=4,
        factor=8,
        is_legacy=True,
        config="sd_xl_refiner.yaml",
        ckpt="sd_xl_refiner_1.0.safetensors",
        is_guided=True,
    ),
}


class SamplingPipeline:
    """Load one supported model and expose sampling entry points for it.

    The config and checkpoint file names are taken from ``model_specs``;
    the model is loaded eagerly by the constructor.
    """

    def __init__(
        self,
        model_id: ModelArchitecture,
        model_path="checkpoints",
        config_path="configs/inference",
        device="cuda",
        use_fp16=True,
    ) -> None:
        """Resolve the spec for *model_id* and load the model.

        Args:
            model_id: Key into ``model_specs``.
            model_path: Directory containing the checkpoint file.
            config_path: Directory containing the config file.
            device: Device the loaded model is moved to.
            use_fp16: If true, cast the conditioner and the inner model to
                half precision after loading.

        Raises:
            ValueError: If *model_id* has no entry in ``model_specs`` or
                the model loader returns ``None``.
        """
        if model_id not in model_specs:
            raise ValueError(f"Model {model_id} not supported")
        self.model_id = model_id
        self.specs = model_specs[self.model_id]
        self.config = str(pathlib.Path(config_path, self.specs.config))
        self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt))
        self.device = device
        self.model = self._load_model(device=device, use_fp16=use_fp16)

    def _load_model(self, device="cuda", use_fp16=True):
        """Load the model described by ``self.config`` / ``self.ckpt``.

        Returns:
            The loaded model, moved to *device*.

        Raises:
            ValueError: If ``load_model_from_config`` returns ``None``.
        """
        config = OmegaConf.load(self.config)
        model = load_model_from_config(config, self.ckpt)
        if model is None:
            raise ValueError(f"Model {self.model_id} could not be loaded")
        model.to(device)
        if use_fp16:
            # Only these two sub-modules are cast; the rest of the model
            # keeps its loaded precision.
            model.conditioner.half()
            model.model.half()
        return model

    def _uc_zero_embedding_keys(self):
        """Return the keys passed as ``force_uc_zero_embeddings``.

        Non-legacy specs force the "txt" key; legacy specs force none.
        """
        return [] if self.specs.is_legacy else ["txt"]

    @staticmethod
    def _build_value_dict(params, prompt, negative_prompt, target_width, target_height):
        """Return ``asdict(params)`` extended with prompts and target size."""
        value_dict = asdict(params)
        value_dict["prompt"] = prompt
        value_dict["negative_prompt"] = negative_prompt
        value_dict["target_width"] = target_width
        value_dict["target_height"] = target_height
        return value_dict

    def text_to_image(
        self,
        params: SamplingParams,
        prompt: str,
        negative_prompt: str = "",
        samples: int = 1,
        return_latents: bool = False,
    ):
        """Sample from a text prompt at ``params.height`` x ``params.width``.

        Args:
            params: Sampling settings; also forwarded as the value dict.
            prompt: Text prompt.
            negative_prompt: Negative text prompt.
            samples: Forwarded to ``do_sample`` as the number of samples.
            return_latents: Forwarded to ``do_sample`` unchanged.

        Returns:
            Whatever ``do_sample`` returns.
        """
        sampler = get_sampler_config(params)
        value_dict = self._build_value_dict(
            params, prompt, negative_prompt, params.width, params.height
        )
        return do_sample(
            self.model,
            sampler,
            value_dict,
            samples,
            params.height,
            params.width,
            self.specs.channels,
            self.specs.factor,
            force_uc_zero_embeddings=self._uc_zero_embedding_keys(),
            return_latents=return_latents,
            filter=None,
        )

    def image_to_image(
        self,
        params: SamplingParams,
        image,
        prompt: str,
        negative_prompt: str = "",
        samples: int = 1,
        return_latents: bool = False,
    ):
        """Sample starting from *image*, guided by a text prompt.

        Args:
            params: Sampling settings. When ``img2img_strength`` is below
                1.0 the sampler's discretization is wrapped in
                ``Img2ImgDiscretizationWrapper`` with that strength.
            image: Input whose ``shape[2]`` / ``shape[3]`` are used as the
                target height / width.
            prompt: Text prompt.
            negative_prompt: Negative text prompt.
            samples: Forwarded to ``do_img2img``.
            return_latents: Forwarded to ``do_img2img`` unchanged.

        Returns:
            Whatever ``do_img2img`` returns.
        """
        sampler = get_sampler_config(params)

        if params.img2img_strength < 1.0:
            sampler.discretization = Img2ImgDiscretizationWrapper(
                sampler.discretization,
                strength=params.img2img_strength,
            )
        height, width = image.shape[2], image.shape[3]
        value_dict = self._build_value_dict(
            params, prompt, negative_prompt, width, height
        )
        return do_img2img(
            image,
            self.model,
            sampler,
            value_dict,
            samples,
            force_uc_zero_embeddings=self._uc_zero_embedding_keys(),
            return_latents=return_latents,
            filter=None,
        )

    def refiner(
        self,
        params: SamplingParams,
        image,
        prompt: str,
        negative_prompt: Optional[str] = None,
        samples: int = 1,
        return_latents: bool = False,
    ):
        """Run an img2img pass on *image* with encoding skipped.

        Args:
            params: Sampling settings used to build the sampler. Only the
                sampler is derived from it; the value dict is built from
                the fixed values below.
            image: Input passed with ``skip_encode=True``; ``shape[2]`` /
                ``shape[3]`` times the spec's ``factor`` give the reported
                original/target size. Presumably a latent tensor produced
                with ``return_latents=True`` -- TODO confirm.
            prompt: Text prompt.
            negative_prompt: Negative text prompt; ``None`` is treated as
                the empty string.
            samples: Forwarded to ``do_img2img``.
            return_latents: Forwarded to ``do_img2img`` unchanged.

        Returns:
            Whatever ``do_img2img`` returns.
        """
        sampler = get_sampler_config(params)
        # NOTE(review): unlike image_to_image, params.img2img_strength is
        # not applied here -- confirm whether that is intended.

        # Use the spec's factor instead of a literal 8 (every current spec
        # uses 8, so the values are unchanged).
        factor = self.specs.factor
        full_height = image.shape[2] * factor
        full_width = image.shape[3] * factor

        # Match the "" default of the other entry points so the value dict
        # always carries a string.
        if negative_prompt is None:
            negative_prompt = ""

        value_dict = {
            "orig_width": full_width,
            "orig_height": full_height,
            "target_width": full_width,
            "target_height": full_height,
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "crop_coords_top": 0,
            "crop_coords_left": 0,
            "aesthetic_score": 6.0,
            "negative_aesthetic_score": 2.5,
        }

        return do_img2img(
            image,
            self.model,
            sampler,
            value_dict,
            samples,
            skip_encode=True,
            return_latents=return_latents,
            filter=None,
        )


def get_guider_config(params: SamplingParams) -> Dict[str, Any]:
    """Build the config dict for the guider selected by ``params.guider``.

    Args:
        params: Sampling settings; ``guider``, ``thresholder`` and
            ``scale`` are read.

    Returns:
        A dict with a ``"target"`` import path and, for the vanilla
        guider, a ``"params"`` dict holding the scale and the
        thresholding config.

    Raises:
        NotImplementedError: If the guider or the thresholder is not one
            of the supported enum members.
    """
    if params.guider == Guider.IDENTITY:
        return {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}

    if params.guider != Guider.VANILLA:
        raise NotImplementedError(f"unknown guider {params.guider}")

    if params.thresholder != Thresholder.NONE:
        raise NotImplementedError(f"unknown thresholder {params.thresholder}")

    dyn_thresh_config = {
        "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
    }
    return {
        "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
        "params": {"scale": params.scale, "dyn_thresh_config": dyn_thresh_config},
    }


def get_discretization_config(params: SamplingParams) -> Dict[str, Any]:
    """Build the config dict for ``params.discretization``.

    Args:
        params: Sampling settings; ``discretization`` is read, plus
            ``sigma_min``, ``sigma_max`` and ``rho`` for the EDM scheme.

    Returns:
        A dict with a ``"target"`` import path and, for EDM, a
        ``"params"`` dict.

    Raises:
        ValueError: If the discretization is not a supported member.
    """
    if params.discretization == Discretization.LEGACY_DDPM:
        return {
            "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
        }

    if params.discretization == Discretization.EDM:
        return {
            "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
            "params": {
                "sigma_min": params.sigma_min,
                "sigma_max": params.sigma_max,
                "rho": params.rho,
            },
        }

    raise ValueError(f"unknown discretization {params.discretization}")


def get_sampler_config(params: SamplingParams):
    """Instantiate the sampler selected by ``params.sampler``.

    The discretization and guider configs are built first (so their
    errors surface before the sampler lookup) and passed to every sampler
    together with ``num_steps`` and ``verbose=True``.

    Args:
        params: Sampling settings.

    Returns:
        An instance of the sampler class matching ``params.sampler``.

    Raises:
        ValueError: If the sampler or the discretization is unknown.
        NotImplementedError: If the guider or thresholder is unsupported.
    """
    # Keyword arguments shared by all samplers.
    common = {
        "num_steps": params.steps,
        "discretization_config": get_discretization_config(params),
        "guider_config": get_guider_config(params),
        "verbose": True,
    }
    # Extra arguments of the Euler/Heun EDM samplers.
    edm_kwargs = {
        "s_churn": params.s_churn,
        "s_tmin": params.s_tmin,
        "s_tmax": params.s_tmax,
        "s_noise": params.s_noise,
    }
    # Extra arguments of the ancestral samplers.
    ancestral_kwargs = {"eta": params.eta, "s_noise": params.s_noise}

    # Equality (not identity) comparisons keep plain string values working,
    # since the enums mix in ``str``.
    if params.sampler == Sampler.EULER_EDM:
        return EulerEDMSampler(**common, **edm_kwargs)
    if params.sampler == Sampler.HEUN_EDM:
        return HeunEDMSampler(**common, **edm_kwargs)
    if params.sampler == Sampler.EULER_ANCESTRAL:
        return EulerAncestralSampler(**common, **ancestral_kwargs)
    if params.sampler == Sampler.DPMPP2S_ANCESTRAL:
        return DPMPP2SAncestralSampler(**common, **ancestral_kwargs)
    if params.sampler == Sampler.DPMPP2M:
        return DPMPP2MSampler(**common)
    if params.sampler == Sampler.LINEAR_MULTISTEP:
        return LinearMultistepSampler(order=params.order, **common)

    raise ValueError(f"unknown sampler {params.sampler}!")