# OminiControl / src / condition.py
# Header residue from the web file viewer this file was copied from
# (uploader avatar: "Yuanshi"; commit "add all", 6ed1db6; size 4.05 kB).
import torch
from typing import Optional, Union, List, Tuple
from diffusers.pipelines import FluxPipeline
from PIL import Image, ImageFilter
import numpy as np
import cv2
# Maps each supported condition name to its integer type id.
# The ids are deliberately kept as given (they are not contiguous).
condition_dict = dict(
    depth=0,
    canny=1,
    subject=4,
    coloring=6,
    deblurring=7,
    fill=9,
)
class Condition(object):
    """
    An image condition together with the type label used to encode it.

    The condition is either derived from ``raw_img`` according to
    ``condition_type`` (see ``get_condition``) or supplied ready-made through
    ``condition``. ``encode`` turns the stored condition into tokens, latent
    image ids and a per-token type id using a ``FluxPipeline``.
    """

    def __init__(
        self,
        condition_type: str,
        raw_img: Optional[Union[Image.Image, torch.Tensor]] = None,
        condition: Optional[Union[Image.Image, torch.Tensor]] = None,
        mask=None,
    ) -> None:
        """
        Builds the condition from ``raw_img`` or stores ``condition`` as is.

        Args:
            condition_type: Name of the condition; ``type_id`` looks it up in
                ``condition_dict``.
            raw_img: Source image the condition is derived from. Takes
                precedence over ``condition`` when both are given.
            condition: Pre-computed condition, stored unchanged when
                ``raw_img`` is None.
            mask: Must be None; masks are not supported yet.

        Raises:
            AssertionError: If both ``raw_img`` and ``condition`` are None, or
                if ``mask`` is not None.
            NotImplementedError: If ``raw_img`` is given for a
                ``condition_type`` that ``get_condition`` does not handle.
        """
        self.condition_type = condition_type
        assert (
            raw_img is not None or condition is not None
        ), "Either raw_img or condition must be provided"
        # TODO: Add mask support
        # Checked before any image processing so that an unsupported mask is
        # rejected without first paying for e.g. the depth model.
        assert mask is None, "Mask not supported yet"
        if raw_img is not None:
            self.condition = self.get_condition(condition_type, raw_img)
        else:
            self.condition = condition

    def get_condition(
        self, condition_type: str, raw_img: Union[Image.Image, torch.Tensor]
    ) -> Union[Image.Image, torch.Tensor]:
        """
        Returns the condition image derived from ``raw_img``.

        Args:
            condition_type: Which transformation to apply: ``"depth"``,
                ``"canny"``, ``"subject"``, ``"coloring"``, ``"deblurring"``
                or ``"fill"``.
            raw_img: The source image. Every branch except ``"canny"`` and
                ``"subject"`` calls ``raw_img.convert``.

        Returns:
            The condition image. For any other ``condition_type`` the
            condition already stored on this instance is returned.

        Raises:
            NotImplementedError: If ``condition_type`` is not handled and no
                condition has been stored on the instance yet.
        """
        if condition_type == "depth":
            from transformers import pipeline

            # Fall back to CPU so depth conditions also work without a GPU.
            device = "cuda" if torch.cuda.is_available() else "cpu"
            # NOTE: the depth pipeline is rebuilt on every call.
            depth_pipe = pipeline(
                task="depth-estimation",
                model="LiheYoung/depth-anything-small-hf",
                device=device,
            )
            source_image = raw_img.convert("RGB")
            return depth_pipe(source_image)["depth"].convert("RGB")

        if condition_type == "canny":
            # Edge map with fixed hysteresis thresholds (100, 200), expanded
            # back to three channels.
            edges = cv2.Canny(np.array(raw_img), 100, 200)
            return Image.fromarray(edges).convert("RGB")

        if condition_type == "subject":
            # The subject image is used unchanged.
            return raw_img

        if condition_type == "coloring":
            # Drop the colour information but keep a three-channel image.
            return raw_img.convert("L").convert("RGB")

        if condition_type == "deblurring":
            return (
                raw_img.convert("RGB")
                .filter(ImageFilter.GaussianBlur(10))
                .convert("RGB")
            )

        if condition_type == "fill":
            return raw_img.convert("RGB")

        # Unknown type: keep returning an already stored condition, but when
        # there is none (e.g. during __init__) raise a clear error instead of
        # an AttributeError on the not-yet-assigned attribute.
        try:
            return self.condition
        except AttributeError:
            raise NotImplementedError(
                f"Condition type {condition_type} not implemented"
            ) from None

    @property
    def type_id(self) -> int:
        """
        Returns the type id of this condition, looked up in ``condition_dict``.

        Raises:
            KeyError: If ``self.condition_type`` is not in ``condition_dict``.
        """
        return self.get_type_id(self.condition_type)

    @classmethod
    def get_type_id(cls, condition_type: str) -> int:
        """
        Returns the type id registered for ``condition_type``.

        Raises:
            KeyError: If ``condition_type`` is not in ``condition_dict``.
        """
        return condition_dict[condition_type]

    def _encode_image(
        self, pipe: FluxPipeline, cond_img: Image.Image
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encodes an image condition into tokens using the pipeline.

        The image is preprocessed by the pipeline's image processor, moved to
        the pipeline's device and dtype, encoded with the VAE (a sample is
        drawn from the latent distribution), shifted and scaled with the VAE
        config factors and finally packed into tokens.

        Returns:
            A ``(cond_tokens, cond_ids)`` tuple: the packed latents and the
            latent image ids prepared for them.
        """
        # NOTE(review): relies on the private FluxPipeline helpers
        # `_pack_latents` / `_prepare_latent_image_ids`; their signatures may
        # differ between diffusers versions - confirm against the pinned one.
        pixels = pipe.image_processor.preprocess(cond_img)
        pixels = pixels.to(pipe.device).to(pipe.dtype)
        latents = pipe.vae.encode(pixels).latent_dist.sample()
        latents = (
            latents - pipe.vae.config.shift_factor
        ) * pipe.vae.config.scaling_factor
        cond_tokens = pipe._pack_latents(latents, *latents.shape)
        cond_ids = pipe._prepare_latent_image_ids(
            latents.shape[0],
            latents.shape[2],
            latents.shape[3],
            pipe.device,
            pipe.dtype,
        )
        return cond_tokens, cond_ids

    def encode(
        self, pipe: FluxPipeline
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Encodes the condition into tokens, ids and type_id.

        Returns:
            ``(tokens, ids, type_id)`` where ``type_id`` is a tensor shaped
            like ``ids[:, :1]`` with every entry set to this condition's
            type id.

        Raises:
            NotImplementedError: If ``self.condition_type`` is not one of the
                image condition types handled here.
        """
        if self.condition_type in (
            "depth",
            "canny",
            "subject",
            "coloring",
            "deblurring",
            "fill",
        ):
            tokens, ids = self._encode_image(pipe, self.condition)
        else:
            raise NotImplementedError(
                f"Condition type {self.condition_type} not implemented"
            )
        type_id = torch.ones_like(ids[:, :1]) * self.type_id
        return tokens, ids, type_id