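"""
Factory helpers for the Kandinsky 3 pipelines: they build the UNet, the
FLAN-UL2 text condition processor/encoder and the MoVQ image decoder, and
assemble them into the text-to-image, T2I Flash and inpainting pipelines,
downloading the default weights from the ai-forever/Kandinsky3.1 Hub repo
when no local paths are supplied.
"""
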
import os
from typing import Optional, Tuple, Union

import torch
from huggingface_hub import hf_hub_download, snapshot_download

from kandinsky3.model.unet import UNet
from kandinsky3.movq import MoVQ
from kandinsky3.condition_encoders import T5TextConditionEncoder
from kandinsky3.condition_processors import T5TextConditionProcessor
from kandinsky3.model.diffusion import BaseDiffusion, get_named_beta_schedule
from .t2i_pipeline import Kandinsky3T2IPipeline
from .inpainting_pipeline import Kandinsky3InpaintingPipeline


def get_T2I_unet(
    device: Union[str, torch.device],
    weights_path: Optional[str] = None,
    dtype: Union[str, torch.dtype] = torch.float32,
) -> Tuple[UNet, Optional[torch.Tensor]]:
    """Build the text-to-image UNet and, if `weights_path` is given, load its
    weights together with the null text embedding used for classifier-free guidance."""
    unet = UNet(
        model_channels=384,
        num_channels=4,
        init_channels=192,
        time_embed_dim=1536,
        context_dim=4096,
        groups=32,
        head_dim=64,
        expansion_ratio=4,
        compression_ratio=2,
        dim_mult=(1, 2, 4, 8),
        num_blocks=(3, 3, 3, 3),
        add_cross_attention=(False, True, True, True),
        add_self_attention=(False, True, True, True),
    )

    null_embedding = None
    if weights_path:
        state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
        null_embedding = state_dict['null_embedding']
        unet.load_state_dict(state_dict['unet'])

    unet.to(device=device, dtype=dtype).eval()
    return unet, null_embedding


def get_T5encoder(
    device: Union[str, torch.device],
    weights_path: str,
    projection_name: str,
    dtype: Union[str, torch.dtype] = torch.float32,
    low_cpu_mem_usage: bool = True,
    load_in_8bit: bool = False,
    load_in_4bit: bool = False,
) -> Tuple[T5TextConditionProcessor, T5TextConditionEncoder]:
    """Build the T5 (FLAN-UL2) text condition processor and encoder, loading the
    projection head named `projection_name` from `weights_path`."""
    tokens_length = 128
    context_dim = 4096
    processor = T5TextConditionProcessor(tokens_length, weights_path)
    condition_encoder = T5TextConditionEncoder(
        weights_path, context_dim, low_cpu_mem_usage=low_cpu_mem_usage, device=device,
        dtype=dtype, load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit
    )

    if weights_path:
        projections_weights_path = os.path.join(weights_path, projection_name)
        state_dict = torch.load(projections_weights_path, map_location=torch.device('cpu'))
        condition_encoder.projection.load_state_dict(state_dict)

    condition_encoder.projection.to(device=device, dtype=dtype).eval()
    return processor, condition_encoder


def get_movq(
    device: Union[str, torch.device],
    weights_path: Optional[str] = None,
    dtype: Union[str, torch.dtype] = torch.float32,
) -> MoVQ:
    """Build the MoVQ image decoder and optionally load its weights."""
    generator_config = {
        'double_z': False,
        'z_channels': 4,
        'resolution': 256,
        'in_channels': 3,
        'out_ch': 3,
        'ch': 256,
        'ch_mult': [1, 2, 2, 4],
        'num_res_blocks': 2,
        'attn_resolutions': [32],
        'dropout': 0.0
    }
    movq = MoVQ(generator_config)

    if weights_path:
        state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
        movq.load_state_dict(state_dict)

    movq.to(device=device, dtype=dtype).eval()
    return movq


def get_inpainting_unet(
    device: Union[str, torch.device],
    weights_path: Optional[str] = None,
    dtype: Union[str, torch.dtype] = torch.float32,
) -> Tuple[UNet, Optional[torch.Tensor]]:
    """Build the inpainting UNet (9 input channels instead of 4) and, if
    `weights_path` is given, load its weights and null embedding."""
    unet = UNet(
        model_channels=384,
        num_channels=9,
        init_channels=192,
        time_embed_dim=1536,
        context_dim=4096,
        groups=32,
        head_dim=64,
        expansion_ratio=4,
        compression_ratio=2,
        dim_mult=(1, 2, 4, 8),
        num_blocks=(3, 3, 3, 3),
        add_cross_attention=(False, True, True, True),
        add_self_attention=(False, True, True, True),
    )

    null_embedding = None
    if weights_path:
        state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
        null_embedding = state_dict['null_embedding']
        unet.load_state_dict(state_dict['unet'])

    unet.to(device=device, dtype=dtype).eval()
    return unet, null_embedding


def get_T2I_pipeline(
    device_map: Union[str, torch.device, dict],
    dtype_map: Union[str, torch.dtype, dict] = torch.float32,
    low_cpu_mem_usage: bool = True,
    load_in_8bit: bool = False,
    load_in_4bit: bool = False,
    cache_dir: str = '/tmp/kandinsky3/',
    unet_path: Optional[str] = None,
    text_encoder_path: Optional[str] = None,
    movq_path: Optional[str] = None,
) -> Kandinsky3T2IPipeline:
    """Assemble the Kandinsky 3 text-to-image pipeline, downloading kandinsky3.pt,
    the FLAN-UL2 encoder and movq.pt from ai-forever/Kandinsky3.1 for any
    component whose local path is not supplied."""
    # assert ((unet_path is not None) or (text_encoder_path is not None) or (movq_path is not None))
    if not isinstance(device_map, dict):
        device_map = {
            'unet': device_map, 'text_encoder': device_map, 'movq': device_map
        }
    if not isinstance(dtype_map, dict):
        dtype_map = {
            'unet': dtype_map, 'text_encoder': dtype_map, 'movq': dtype_map
        }

    if unet_path is None:
        unet_path = hf_hub_download(
            repo_id="ai-forever/Kandinsky3.1", filename='weights/kandinsky3.pt', cache_dir=cache_dir
        )
    if text_encoder_path is None:
        text_encoder_path = snapshot_download(
            repo_id="ai-forever/Kandinsky3.1", allow_patterns="weights/flan_ul2_encoder/*", cache_dir=cache_dir
        )
        text_encoder_path = os.path.join(text_encoder_path, 'weights/flan_ul2_encoder')
    if movq_path is None:
        movq_path = hf_hub_download(
            repo_id="ai-forever/Kandinsky3.1", filename='weights/movq.pt', cache_dir=cache_dir
        )

    unet, null_embedding = get_T2I_unet(device_map['unet'], unet_path, dtype=dtype_map['unet'])
    processor, condition_encoder = get_T5encoder(
        device_map['text_encoder'], text_encoder_path, 'projection.pt', dtype=dtype_map['text_encoder'],
        low_cpu_mem_usage=low_cpu_mem_usage, load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit
    )
    movq = get_movq(device_map['movq'], movq_path, dtype=dtype_map['movq'])
    return Kandinsky3T2IPipeline(
        device_map, dtype_map, unet, null_embedding, processor, condition_encoder, movq, False
    )


def get_T2I_Flash_pipeline(
    device_map: Union[str, torch.device, dict],
    dtype_map: Union[str, torch.dtype, dict] = torch.float32,
    low_cpu_mem_usage: bool = True,
    load_in_8bit: bool = False,
    load_in_4bit: bool = False,
    cache_dir: str = '/tmp/kandinsky3/',
    unet_path: Optional[str] = None,
    text_encoder_path: Optional[str] = None,
    movq_path: Optional[str] = None,
) -> Kandinsky3T2IPipeline:
    """Assemble the Kandinsky 3 Flash text-to-image pipeline. It differs from
    get_T2I_pipeline only in the UNet weights (kandinsky3_flash.pt), the
    projection head (projection_flash.pt) and the final flag passed to
    Kandinsky3T2IPipeline."""
    # assert ((unet_path is not None) or (text_encoder_path is not None) or (movq_path is not None))
    if not isinstance(device_map, dict):
        device_map = {
            'unet': device_map, 'text_encoder': device_map, 'movq': device_map
        }
    if not isinstance(dtype_map, dict):
        dtype_map = {
            'unet': dtype_map, 'text_encoder': dtype_map, 'movq': dtype_map
        }

    if unet_path is None:
        unet_path = hf_hub_download(
            repo_id="ai-forever/Kandinsky3.1", filename='weights/kandinsky3_flash.pt', cache_dir=cache_dir
        )
    if text_encoder_path is None:
        text_encoder_path = snapshot_download(
            repo_id="ai-forever/Kandinsky3.1", allow_patterns="weights/flan_ul2_encoder/*", cache_dir=cache_dir
        )
        text_encoder_path = os.path.join(text_encoder_path, 'weights/flan_ul2_encoder')
    if movq_path is None:
        movq_path = hf_hub_download(
            repo_id="ai-forever/Kandinsky3.1", filename='weights/movq.pt', cache_dir=cache_dir
        )

    unet, null_embedding = get_T2I_unet(device_map['unet'], unet_path, dtype=dtype_map['unet'])
    processor, condition_encoder = get_T5encoder(
        device_map['text_encoder'], text_encoder_path, 'projection_flash.pt', dtype=dtype_map['text_encoder'],
        low_cpu_mem_usage=low_cpu_mem_usage, load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit
    )
    movq = get_movq(device_map['movq'], movq_path, dtype=dtype_map['movq'])
    return Kandinsky3T2IPipeline(
        device_map, dtype_map, unet, null_embedding, processor, condition_encoder, movq, True
    )


def get_inpainting_pipeline(
    device_map: Union[str, torch.device, dict],
    dtype_map: Union[str, torch.dtype, dict] = torch.float32,
    low_cpu_mem_usage: bool = True,
    load_in_8bit: bool = False,
    load_in_4bit: bool = False,
    cache_dir: str = '/tmp/kandinsky3/',
    unet_path: Optional[str] = None,
    text_encoder_path: Optional[str] = None,
    movq_path: Optional[str] = None,
) -> Kandinsky3InpaintingPipeline:
    """Assemble the Kandinsky 3 inpainting pipeline, using the inpainting UNet
    (kandinsky3_inpainting.pt) and the projection_inpainting.pt projection head."""
    # assert ((unet_path is not None) or (text_encoder_path is not None) or (movq_path is not None))
    if not isinstance(device_map, dict):
        device_map = {
            'unet': device_map, 'text_encoder': device_map, 'movq': device_map
        }
    if not isinstance(dtype_map, dict):
        dtype_map = {
            'unet': dtype_map, 'text_encoder': dtype_map, 'movq': dtype_map
        }

    if unet_path is None:
        unet_path = hf_hub_download(
            repo_id="ai-forever/Kandinsky3.1", filename='weights/kandinsky3_inpainting.pt', cache_dir=cache_dir
        )
    if text_encoder_path is None:
        text_encoder_path = snapshot_download(
            repo_id="ai-forever/Kandinsky3.1", allow_patterns="weights/flan_ul2_encoder/*", cache_dir=cache_dir
        )
        text_encoder_path = os.path.join(text_encoder_path, 'weights/flan_ul2_encoder')
    if movq_path is None:
        movq_path = hf_hub_download(
            repo_id="ai-forever/Kandinsky3.1", filename='weights/movq.pt', cache_dir=cache_dir
        )

    unet, null_embedding = get_inpainting_unet(device_map['unet'], unet_path, dtype=dtype_map['unet'])
    processor, condition_encoder = get_T5encoder(
        device_map['text_encoder'], text_encoder_path, 'projection_inpainting.pt', dtype=dtype_map['text_encoder'],
        low_cpu_mem_usage=low_cpu_mem_usage, load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit
    )
    movq = get_movq(device_map['movq'], movq_path, dtype=dtype_map['movq'])
    return Kandinsky3InpaintingPipeline(
        device_map, dtype_map, unet, null_embedding, processor, condition_encoder, movq
    )
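

# Usage sketch: build the text-to-image pipeline on a single CUDA GPU and
# render one prompt. The call signature of Kandinsky3T2IPipeline (a prompt
# string in, a list of PIL images out) is assumed here; adjust device_map,
# dtype_map or the 8-/4-bit loading flags for your hardware.
#
#     import torch
#     from kandinsky3 import get_T2I_pipeline
#
#     device_map = torch.device('cuda:0')   # broadcast to unet/text_encoder/movq
#     dtype_map = {
#         'unet': torch.float32,
#         'text_encoder': torch.float16,
#         'movq': torch.float32,
#     }
#     t2i_pipe = get_T2I_pipeline(device_map, dtype_map)
#     images = t2i_pipe('A cup of coffee on a wooden table, photorealistic')
#     images[0].save('result.png')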