"""Kandinsky 3 model builders and pipeline factories.

Helpers that construct the UNet, T5 text encoder, and MoVQ autoencoder,
optionally loading pretrained weights, and assemble them into the
text-to-image, flash, and inpainting pipelines.
"""
import os
from typing import Optional, Union
import torch
from huggingface_hub import hf_hub_download, snapshot_download
from kandinsky3.model.unet import UNet
from kandinsky3.movq import MoVQ
from kandinsky3.condition_encoders import T5TextConditionEncoder
from kandinsky3.condition_processors import T5TextConditionProcessor
from kandinsky3.model.diffusion import BaseDiffusion, get_named_beta_schedule
from .t2i_pipeline import Kandinsky3T2IPipeline
from .inpainting_pipeline import Kandinsky3InpaintingPipeline
def get_T2I_unet(
    device: Union[str, torch.device],
    weights_path: Optional[str] = None,
    dtype: Union[str, torch.dtype] = torch.float32,
) -> tuple[UNet, Optional[torch.Tensor]]:
    """Build the text-to-image UNet, optionally loading pretrained weights.

    Args:
        device: Target device for the model.
        weights_path: Checkpoint with 'unet' and 'null_embedding' entries;
            if falsy, the model is returned randomly initialized.
        dtype: Parameter dtype after moving to ``device``.

    Returns:
        The UNet in eval mode, and the null (unconditional) text embedding
        or None when no checkpoint was loaded.
    """
    unet = UNet(
        model_channels=384,
        num_channels=4,  # plain latent channels (no inpainting mask/image)
        init_channels=192,
        time_embed_dim=1536,
        context_dim=4096,  # must match the text-encoder projection dim
        groups=32,
        head_dim=64,
        expansion_ratio=4,
        compression_ratio=2,
        dim_mult=(1, 2, 4, 8),
        num_blocks=(3, 3, 3, 3),
        add_cross_attention=(False, True, True, True),
        add_self_attention=(False, True, True, True),
    )

    null_embedding = None
    if weights_path:
        # Load on CPU first to avoid GPU OOM during deserialization.
        state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
        null_embedding = state_dict['null_embedding']
        unet.load_state_dict(state_dict['unet'])

    unet.to(device=device, dtype=dtype).eval()
    return unet, null_embedding
def get_T5encoder(
    device: Union[str, torch.device],
    weights_path: str,
    projection_name: str,
    dtype: Union[str, torch.dtype] = torch.float32,
    low_cpu_mem_usage: bool = True,
    load_in_8bit: bool = False,
    load_in_4bit: bool = False,
) -> tuple[T5TextConditionProcessor, T5TextConditionEncoder]:
    """Build the T5 text condition processor and encoder pair.

    Args:
        device: Target device for the encoder's projection head.
        weights_path: Directory holding the T5 encoder weights; also used to
            locate the projection checkpoint.
        projection_name: Filename of the projection checkpoint inside
            ``weights_path`` (differs per pipeline variant).
        dtype: Parameter dtype for the projection head.
        low_cpu_mem_usage: Forwarded to the underlying HF model loader.
        load_in_8bit: Load the T5 backbone in 8-bit quantization.
        load_in_4bit: Load the T5 backbone in 4-bit quantization.

    Returns:
        The tokenizing processor and the condition encoder.
    """
    tokens_length = 128
    context_dim = 4096  # UNet cross-attention context width
    processor = T5TextConditionProcessor(tokens_length, weights_path)
    condition_encoder = T5TextConditionEncoder(
        weights_path, context_dim, low_cpu_mem_usage=low_cpu_mem_usage, device=device,
        dtype=dtype, load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit
    )

    if weights_path:
        # The projection head is stored separately from the T5 backbone.
        projections_weights_path = os.path.join(weights_path, projection_name)
        state_dict = torch.load(projections_weights_path, map_location=torch.device('cpu'))
        condition_encoder.projection.load_state_dict(state_dict)

    condition_encoder.projection.to(device=device, dtype=dtype).eval()
    return processor, condition_encoder
def get_movq(
    device: Union[str, torch.device],
    weights_path: Optional[str] = None,
    dtype: Union[str, torch.dtype] = torch.float32,
) -> MoVQ:
    """Construct the MoVQ image autoencoder, optionally loading a checkpoint.

    Args:
        device: Target device for the model.
        weights_path: Path to a MoVQ state dict; skipped when falsy.
        dtype: Parameter dtype after moving to ``device``.

    Returns:
        The MoVQ model in eval mode.
    """
    generator_config = {
        'double_z': False,
        'z_channels': 4,
        'resolution': 256,
        'in_channels': 3,
        'out_ch': 3,
        'ch': 256,
        'ch_mult': [1, 2, 2, 4],
        'num_res_blocks': 2,
        'attn_resolutions': [32],
        'dropout': 0.0
    }
    image_autoencoder = MoVQ(generator_config)

    if weights_path:
        # Deserialize on CPU, then move to the requested device/dtype.
        checkpoint = torch.load(weights_path, map_location=torch.device('cpu'))
        image_autoencoder.load_state_dict(checkpoint)

    image_autoencoder.to(device=device, dtype=dtype).eval()
    return image_autoencoder
def get_inpainting_unet(
    device: Union[str, torch.device],
    weights_path: Optional[str] = None,
    dtype: Union[str, torch.dtype] = torch.float32,
) -> tuple[UNet, Optional[torch.Tensor]]:
    """Build the inpainting UNet, optionally loading pretrained weights.

    Identical architecture to the T2I UNet except for the input width:
    9 channels instead of 4 (extra inputs for the inpainting condition —
    presumably mask + masked-image latents; verify against the UNet impl).

    Args:
        device: Target device for the model.
        weights_path: Checkpoint with 'unet' and 'null_embedding' entries;
            if falsy, the model is returned randomly initialized.
        dtype: Parameter dtype after moving to ``device``.

    Returns:
        The UNet in eval mode, and the null (unconditional) text embedding
        or None when no checkpoint was loaded.
    """
    unet = UNet(
        model_channels=384,
        num_channels=9,  # widened input vs. the 4-channel T2I UNet
        init_channels=192,
        time_embed_dim=1536,
        context_dim=4096,  # must match the text-encoder projection dim
        groups=32,
        head_dim=64,
        expansion_ratio=4,
        compression_ratio=2,
        dim_mult=(1, 2, 4, 8),
        num_blocks=(3, 3, 3, 3),
        add_cross_attention=(False, True, True, True),
        add_self_attention=(False, True, True, True),
    )

    null_embedding = None
    if weights_path:
        # Load on CPU first to avoid GPU OOM during deserialization.
        state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
        null_embedding = state_dict['null_embedding']
        unet.load_state_dict(state_dict['unet'])

    unet.to(device=device, dtype=dtype).eval()
    return unet, null_embedding
def get_T2I_pipeline(
    device_map: Union[str, torch.device, dict],
    dtype_map: Union[str, torch.dtype, dict] = torch.float32,
    low_cpu_mem_usage: bool = True,
    load_in_8bit: bool = False,
    load_in_4bit: bool = False,
    cache_dir: str = '/tmp/kandinsky3/',
    unet_path: Optional[str] = None,
    text_encoder_path: Optional[str] = None,
    movq_path: Optional[str] = None,
) -> Kandinsky3T2IPipeline:
    """Assemble the full Kandinsky 3 text-to-image pipeline.

    Any component path left as None is downloaded from the
    ``ai-forever/Kandinsky3.1`` Hugging Face repository into ``cache_dir``.

    Args:
        device_map: One device for every component, or a dict with
            'unet', 'text_encoder' and 'movq' keys.
        dtype_map: One dtype for every component, or a per-component dict.
        low_cpu_mem_usage: Forwarded to the T5 text-encoder loader.
        load_in_8bit: Forwarded to the T5 text-encoder loader.
        load_in_4bit: Forwarded to the T5 text-encoder loader.
        cache_dir: Download cache for Hugging Face Hub files.
        unet_path: Optional local override for the UNet checkpoint.
        text_encoder_path: Optional local override for the T5 encoder dir.
        movq_path: Optional local override for the MoVQ checkpoint.

    Returns:
        A ready-to-use Kandinsky3T2IPipeline.
    """
    # Broadcast a single device / dtype to all three components.
    if not isinstance(device_map, dict):
        device_map = {
            'unet': device_map, 'text_encoder': device_map, 'movq': device_map
        }
    if not isinstance(dtype_map, dict):
        dtype_map = {
            'unet': dtype_map, 'text_encoder': dtype_map, 'movq': dtype_map
        }

    if unet_path is None:
        unet_path = hf_hub_download(
            repo_id="ai-forever/Kandinsky3.1", filename='weights/kandinsky3.pt', cache_dir=cache_dir
        )
    if text_encoder_path is None:
        text_encoder_path = snapshot_download(
            repo_id="ai-forever/Kandinsky3.1", allow_patterns="weights/flan_ul2_encoder/*", cache_dir=cache_dir
        )
        text_encoder_path = os.path.join(text_encoder_path, 'weights/flan_ul2_encoder')
    if movq_path is None:
        movq_path = hf_hub_download(
            repo_id="ai-forever/Kandinsky3.1", filename='weights/movq.pt', cache_dir=cache_dir
        )

    unet, null_embedding = get_T2I_unet(device_map['unet'], unet_path, dtype=dtype_map['unet'])
    processor, condition_encoder = get_T5encoder(
        device_map['text_encoder'], text_encoder_path, 'projection.pt', dtype=dtype_map['text_encoder'],
        low_cpu_mem_usage=low_cpu_mem_usage, load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit
    )
    movq = get_movq(device_map['movq'], movq_path, dtype=dtype_map['movq'])
    # Trailing False selects the standard (non-flash) sampling variant.
    return Kandinsky3T2IPipeline(
        device_map, dtype_map, unet, null_embedding, processor, condition_encoder, movq, False
    )
def get_T2I_Flash_pipeline(
    device_map: Union[str, torch.device, dict],
    dtype_map: Union[str, torch.dtype, dict] = torch.float32,
    low_cpu_mem_usage: bool = True,
    load_in_8bit: bool = False,
    load_in_4bit: bool = False,
    cache_dir: str = '/tmp/kandinsky3/',
    unet_path: Optional[str] = None,
    text_encoder_path: Optional[str] = None,
    movq_path: Optional[str] = None,
) -> Kandinsky3T2IPipeline:
    """Assemble the distilled "flash" Kandinsky 3 text-to-image pipeline.

    Same layout as the standard pipeline but uses the flash UNet checkpoint
    and projection, and enables the pipeline's flash mode.

    Args:
        device_map: One device for every component, or a dict with
            'unet', 'text_encoder' and 'movq' keys.
        dtype_map: One dtype for every component, or a per-component dict.
        low_cpu_mem_usage: Forwarded to the T5 text-encoder loader.
        load_in_8bit: Forwarded to the T5 text-encoder loader.
        load_in_4bit: Forwarded to the T5 text-encoder loader.
        cache_dir: Download cache for Hugging Face Hub files.
        unet_path: Optional local override for the flash UNet checkpoint.
        text_encoder_path: Optional local override for the T5 encoder dir.
        movq_path: Optional local override for the MoVQ checkpoint.

    Returns:
        A ready-to-use Kandinsky3T2IPipeline in flash mode.
    """
    # Broadcast a single device / dtype to all three components.
    if not isinstance(device_map, dict):
        device_map = {
            'unet': device_map, 'text_encoder': device_map, 'movq': device_map
        }
    if not isinstance(dtype_map, dict):
        dtype_map = {
            'unet': dtype_map, 'text_encoder': dtype_map, 'movq': dtype_map
        }

    if unet_path is None:
        unet_path = hf_hub_download(
            repo_id="ai-forever/Kandinsky3.1", filename='weights/kandinsky3_flash.pt', cache_dir=cache_dir
        )
    if text_encoder_path is None:
        text_encoder_path = snapshot_download(
            repo_id="ai-forever/Kandinsky3.1", allow_patterns="weights/flan_ul2_encoder/*", cache_dir=cache_dir
        )
        text_encoder_path = os.path.join(text_encoder_path, 'weights/flan_ul2_encoder')
    if movq_path is None:
        movq_path = hf_hub_download(
            repo_id="ai-forever/Kandinsky3.1", filename='weights/movq.pt', cache_dir=cache_dir
        )

    unet, null_embedding = get_T2I_unet(device_map['unet'], unet_path, dtype=dtype_map['unet'])
    processor, condition_encoder = get_T5encoder(
        device_map['text_encoder'], text_encoder_path, 'projection_flash.pt', dtype=dtype_map['text_encoder'],
        low_cpu_mem_usage=low_cpu_mem_usage, load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit
    )
    movq = get_movq(device_map['movq'], movq_path, dtype=dtype_map['movq'])
    # Trailing True selects the flash (few-step) sampling variant.
    return Kandinsky3T2IPipeline(
        device_map, dtype_map, unet, null_embedding, processor, condition_encoder, movq, True
    )
def get_inpainting_pipeline(
    device_map: Union[str, torch.device, dict],
    dtype_map: Union[str, torch.dtype, dict] = torch.float32,
    low_cpu_mem_usage: bool = True,
    load_in_8bit: bool = False,
    load_in_4bit: bool = False,
    cache_dir: str = '/tmp/kandinsky3/',
    unet_path: Optional[str] = None,
    text_encoder_path: Optional[str] = None,
    movq_path: Optional[str] = None,
) -> Kandinsky3InpaintingPipeline:
    """Assemble the Kandinsky 3 inpainting pipeline.

    Any component path left as None is downloaded from the
    ``ai-forever/Kandinsky3.1`` Hugging Face repository into ``cache_dir``.

    Args:
        device_map: One device for every component, or a dict with
            'unet', 'text_encoder' and 'movq' keys.
        dtype_map: One dtype for every component, or a per-component dict.
        low_cpu_mem_usage: Forwarded to the T5 text-encoder loader.
        load_in_8bit: Forwarded to the T5 text-encoder loader.
        load_in_4bit: Forwarded to the T5 text-encoder loader.
        cache_dir: Download cache for Hugging Face Hub files.
        unet_path: Optional local override for the inpainting UNet checkpoint.
        text_encoder_path: Optional local override for the T5 encoder dir.
        movq_path: Optional local override for the MoVQ checkpoint.

    Returns:
        A ready-to-use Kandinsky3InpaintingPipeline.
    """
    # Broadcast a single device / dtype to all three components.
    if not isinstance(device_map, dict):
        device_map = {
            'unet': device_map, 'text_encoder': device_map, 'movq': device_map
        }
    if not isinstance(dtype_map, dict):
        dtype_map = {
            'unet': dtype_map, 'text_encoder': dtype_map, 'movq': dtype_map
        }

    if unet_path is None:
        unet_path = hf_hub_download(
            repo_id="ai-forever/Kandinsky3.1", filename='weights/kandinsky3_inpainting.pt', cache_dir=cache_dir
        )
    if text_encoder_path is None:
        text_encoder_path = snapshot_download(
            repo_id="ai-forever/Kandinsky3.1", allow_patterns="weights/flan_ul2_encoder/*", cache_dir=cache_dir
        )
        text_encoder_path = os.path.join(text_encoder_path, 'weights/flan_ul2_encoder')
    if movq_path is None:
        movq_path = hf_hub_download(
            repo_id="ai-forever/Kandinsky3.1", filename='weights/movq.pt', cache_dir=cache_dir
        )

    # Uses the 9-channel inpainting UNet and its matching projection weights.
    unet, null_embedding = get_inpainting_unet(device_map['unet'], unet_path, dtype=dtype_map['unet'])
    processor, condition_encoder = get_T5encoder(
        device_map['text_encoder'], text_encoder_path, 'projection_inpainting.pt', dtype=dtype_map['text_encoder'],
        low_cpu_mem_usage=low_cpu_mem_usage, load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit
    )
    movq = get_movq(device_map['movq'], movq_path, dtype=dtype_map['movq'])
    return Kandinsky3InpaintingPipeline(
        device_map, dtype_map, unet, null_embedding, processor, condition_encoder, movq
    )