from typing import Any, Optional, Union, Tuple, Dict, List | |
import os | |
import random | |
import math | |
import time | |
import numpy as np | |
from tqdm.auto import tqdm, trange | |
import torch | |
from torch.utils.data import DataLoader | |
import jax | |
import jax.numpy as jnp | |
import optax | |
from flax import jax_utils, traverse_util | |
from flax.core.frozen_dict import FrozenDict | |
from flax.training.train_state import TrainState | |
from flax.training.common_utils import shard | |
# convert 2D -> 3D | |
from diffusers import FlaxUNet2DConditionModel | |
# inference test, run on these on cpu | |
from diffusers import AutoencoderKL | |
from diffusers.schedulers.scheduling_ddim_flax import FlaxDDIMScheduler, DDIMSchedulerState | |
from transformers import CLIPTextModel, CLIPTokenizer | |
from PIL import Image | |
from .flax_unet_pseudo3d_condition import UNetPseudo3DConditionModel | |
def seed_all(seed: int) -> jax.Array:
    """Seed the python, numpy and torch RNGs and return a JAX PRNG key.

    Args:
        seed: seed value applied to every RNG.

    Returns:
        A JAX PRNG key derived from ``seed`` (``jax.random.PRNGKey``).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # NOTE: annotated as jax.Array because jax.random.PRNGKeyArray was
    # deprecated and removed in newer JAX releases.
    rng = jax.random.PRNGKey(seed)
    return rng
def count_params(
        params: Union[Dict[str, Any],
        FrozenDict[str, Any]],
        filter_name: Optional[str] = None
) -> int:
    """Count scalar parameters in a (nested) flax parameter tree.

    Args:
        params: nested parameter dict / FrozenDict.
        filter_name: if given, only leaves whose flattened key path
            (segments joined with spaces) contains this substring count.

    Returns:
        Total number of scalar elements over the selected leaves.
    """
    p: Dict[Tuple[str], jax.Array] = traverse_util.flatten_dict(params)
    # .size avoids materializing a flattened copy of every array, which
    # len(arr.flatten()) would do, while returning the same count
    return sum(
        int(v.size) for k, v in p.items()
        if filter_name is None or filter_name in ' '.join(k)
    )
def map_2d_to_pseudo3d(
        params2d: Dict[str, Any],
        params3d: Dict[str, Any],
        verbose: bool = True
) -> Dict[str, Any]:
    """Build a Pseudo3D parameter tree from a 2D UNet checkpoint.

    Spatial-conv leaves are copied from the 2D tree (same path minus the
    first 'spatial_conv' segment); leaves absent from the 2D tree keep
    their freshly initialised 3D values; everything else comes from 2D.
    """
    params2d = traverse_util.flatten_dict(params2d)
    params3d = traverse_util.flatten_dict(params3d)
    new_params = dict()
    for key, value3d in params3d.items():
        if 'spatial_conv' in key:
            # drop only the first 'spatial_conv' segment to find the 2D path
            trimmed = list(key)
            trimmed.remove('spatial_conv')
            trimmed = tuple(trimmed)
            if verbose:
                tqdm.write(f'Spatial: {key} <- {trimmed}')
            leaf = params2d[trimmed]
        elif key not in params2d:
            # new temporal/3D-only parameter: keep the initialised weights
            if verbose:
                tqdm.write(f'Missing: {key}')
            leaf = value3d
        else:
            leaf = params2d[key]
        assert leaf.shape == value3d.shape, f'shape mismatch: {key}: {leaf.shape} != {value3d.shape}'
        new_params[key] = leaf
    return traverse_util.unflatten_dict(new_params)
class FlaxTrainerUNetPseudo3D:
    """Trainer for a Pseudo3D (video) conditioned UNet using JAX/Flax.

    Loads a pretrained UNet (optionally converting 2D weights to the
    Pseudo3D layout), marks which parameters are trainable, and exposes
    train() / sample() plus checkpointing helpers.
    """
    def __init__(self,
            model_path: str,
            from_pt: bool = True,
            convert2d: bool = False,
            sample_size: Tuple[int, int] = (64, 64),
            seed: int = 0,
            dtype: str = 'float32',
            param_dtype: str = 'float32',
            only_temporal: bool = True,
            use_memory_efficient_attention: bool = False,
            verbose: bool = True
    ) -> None:
        """Initialise the trainer and load the model.

        Args:
            model_path: path / hub id of the pretrained diffusers model.
            from_pt: load weights from a PyTorch checkpoint.
            convert2d: initialise a Pseudo3D model from 2D UNet weights.
            sample_size: latent spatial size (height, width).
            seed: seed for python/numpy/torch RNGs and the JAX PRNG key.
            dtype: computation dtype ('float32' | 'bfloat16' | 'float16').
            param_dtype: parameter storage dtype (same choices as dtype).
            only_temporal: train only layers whose path contains 'temporal'.
            use_memory_efficient_attention: enable memory efficient attention.
            verbose: print progress via tqdm on the main process.
        """
        self.verbose = verbose
        # wandb run handle; created lazily in _setup_wandb. The string
        # annotation keeps wandb an optional import.
        self.tracker: Optional['wandb.sdk.wandb_run.Run'] = None
        self._use_wandb: bool = False
        # timing metadata for steps-per-second logging: t00 = start of
        # training, t0/step0 = start of the current logging window
        self._tracker_meta: Dict[str, Union[float, int]] = {
            't00': 0.0,
            't0': 0.0,
            'step0': 0
        }
        self.log('Init JAX')
        self.num_devices = jax.device_count()
        self.log(f'Device count: {self.num_devices}')
        self.seed = seed
        self.rng: jax.Array = seed_all(self.seed)
        self.sample_size = sample_size
        if dtype == 'float32':
            self.dtype = jnp.float32
        elif dtype == 'bfloat16':
            self.dtype = jnp.bfloat16
        elif dtype == 'float16':
            self.dtype = jnp.float16
        else:
            raise ValueError(f'unknown type: {dtype}')
        self.dtype_str: str = dtype
        # param_dtype stays a string; _load_models dispatches on it
        if param_dtype not in ['float32', 'bfloat16', 'float16']:
            raise ValueError(f'unknown parameter type: {param_dtype}')
        self.param_dtype = param_dtype
        self._load_models(
            model_path = model_path,
            convert2d = convert2d,
            from_pt = from_pt,
            use_memory_efficient_attention = use_memory_efficient_attention
        )
        self._mark_parameters(only_temporal = only_temporal)
        # inference-only models (tokenizer / text encoder / vae / scheduler)
        # are loaded lazily by _load_inference_models for validation + sampling
        self.tokenizer: Optional[CLIPTokenizer] = None
        self.text_encoder: Optional[CLIPTextModel] = None
        self.vae: Optional[AutoencoderKL] = None
        # stored as a (scheduler, scheduler state) tuple
        self.ddim: Optional[Tuple[FlaxDDIMScheduler, DDIMSchedulerState]] = None
def log(self, message: Any) -> None: | |
if self.verbose and jax.process_index() == 0: | |
tqdm.write(str(message)) | |
def log_metrics(self, metrics: dict, step: int, epoch: int) -> None: | |
if jax.process_index() > 0 or (not self.verbose and self.tracker is None): | |
return | |
now = time.monotonic() | |
log_data = { | |
'train/step': step, | |
'train/epoch': epoch, | |
'train/steps_per_sec': (step - self._tracker_meta['step0']) / (now - self._tracker_meta['t0']), | |
**{ f'train/{k}': v for k, v in metrics.items() } | |
} | |
self._tracker_meta['t0'] = now | |
self._tracker_meta['step0'] = step | |
self.log(log_data) | |
if self.tracker is not None: | |
self.tracker.log(log_data, step = step) | |
    def enable_wandb(self, enable: bool = True) -> None:
        """Enable or disable wandb tracking for subsequent train() calls."""
        self._use_wandb = enable
def _setup_wandb(self, config: Dict[str, Any] = dict()) -> None: | |
import wandb | |
import wandb.sdk | |
self.tracker: wandb.sdk.wandb_run.Run = wandb.init( | |
config = config, | |
settings = wandb.sdk.Settings( | |
username = 'anon', | |
host = 'anon', | |
email = 'anon', | |
root_dir = 'anon', | |
_executable = 'anon', | |
_disable_stats = True, | |
_disable_meta = True, | |
disable_code = True, | |
disable_git = True | |
) # pls don't log sensitive data like system user names. also, fuck you for even trying. | |
) | |
def _init_tracker_meta(self) -> None: | |
now = time.monotonic() | |
self._tracker_meta = { | |
't00': now, | |
't0': now, | |
'step0': 0 | |
} | |
    def _load_models(self,
            model_path: str,
            convert2d: bool,
            from_pt: bool,
            use_memory_efficient_attention: bool
    ) -> None:
        """Load the UNet and its parameters, optionally converting 2D -> Pseudo3D.

        Sets self.pretrained_model, self.model and self.params.
        """
        self.log(f'Load pretrained from {model_path}')
        if convert2d:
            # build a freshly initialised Pseudo3D model from the 2D config,
            # then copy every matching 2D weight into it
            self.log(' Convert 2D model to Pseudo3D')
            self.log(' Initiate Pseudo3D model')
            config = UNetPseudo3DConditionModel.load_config(model_path, subfolder = 'unet')
            model = UNetPseudo3DConditionModel.from_config(
                config,
                sample_size = self.sample_size,
                dtype = self.dtype,
                param_dtype = self.param_dtype,
                use_memory_efficient_attention = use_memory_efficient_attention
            )
            params: Dict[str, Any] = model.init_weights(self.rng).unfreeze()
            self.log(' Load 2D model')
            model2d, params2d = FlaxUNet2DConditionModel.from_pretrained(
                model_path,
                subfolder = 'unet',
                dtype = self.dtype,
                from_pt = from_pt
            )
            self.log(' Map 2D -> 3D')
            params = map_2d_to_pseudo3d(params2d, params, verbose = self.verbose)
            # free the 2D model and weights as soon as they are mapped
            del params2d
            del model2d
            del config
        else:
            model, params = UNetPseudo3DConditionModel.from_pretrained(
                model_path,
                subfolder = 'unet',
                from_pt = from_pt,
                sample_size = self.sample_size,
                dtype = self.dtype,
                param_dtype = self.param_dtype,
                use_memory_efficient_attention = use_memory_efficient_attention
            )
        # model.param_dtype is a string ('float32' | 'float16' | 'bfloat16')
        self.log(f'Cast parameters to {model.param_dtype}')
        if model.param_dtype == 'float32':
            params = model.to_fp32(params)
        elif model.param_dtype == 'float16':
            params = model.to_fp16(params)
        elif model.param_dtype == 'bfloat16':
            params = model.to_bf16(params)
        self.pretrained_model = model_path
        self.model: UNetPseudo3DConditionModel = model
        self.params: FrozenDict[str, Any] = FrozenDict(params)
def _mark_parameters(self, only_temporal: bool) -> None: | |
self.log('Mark training parameters') | |
if only_temporal: | |
self.log('Only training temporal layers') | |
if only_temporal: | |
param_partitions = traverse_util.path_aware_map( | |
lambda path, _: 'trainable' if 'temporal' in ' '.join(path) else 'frozen', self.params | |
) | |
else: | |
param_partitions = traverse_util.path_aware_map( | |
lambda *_: 'trainable', self.params | |
) | |
self.only_temporal = only_temporal | |
self.param_partitions: FrozenDict[str, Any] = FrozenDict(param_partitions) | |
self.log(f'Total parameters: {count_params(self.params)}') | |
self.log(f'Temporal parameters: {count_params(self.params, "temporal")}') | |
def _load_inference_models(self) -> None: | |
assert jax.process_index() == 0, 'not main process' | |
if self.text_encoder is None: | |
self.log('Load text encoder') | |
self.text_encoder = CLIPTextModel.from_pretrained( | |
self.pretrained_model, | |
subfolder = 'text_encoder' | |
) | |
if self.tokenizer is None: | |
self.log('Load tokenizer') | |
self.tokenizer = CLIPTokenizer.from_pretrained( | |
self.pretrained_model, | |
subfolder = 'tokenizer' | |
) | |
if self.vae is None: | |
self.log('Load vae') | |
self.vae = AutoencoderKL.from_pretrained( | |
self.pretrained_model, | |
subfolder = 'vae' | |
) | |
if self.ddim is None: | |
self.log('Load ddim scheduler') | |
# tuple(scheduler , scheduler state) | |
self.ddim = FlaxDDIMScheduler.from_pretrained( | |
self.pretrained_model, | |
subfolder = 'scheduler', | |
from_pt = True | |
) | |
def _unload_inference_models(self) -> None: | |
self.text_encoder = None | |
self.tokenizer = None | |
self.vae = None | |
self.ddim = None | |
    def sample(self,
            params: Union[Dict[str, Any], FrozenDict[str, Any]],
            prompt: str,
            image_path: str,
            num_frames: int,
            replicate_params: bool = True,
            neg_prompt: str = '',
            steps: int = 50,
            cfg: float = 9.0,
            unload_after_usage: bool = False
    ) -> List[Image.Image]:
        """Sample video frames conditioned on a text prompt and a hint image.

        Args:
            params: UNet parameters; replicated here when replicate_params
                is True, otherwise assumed already replicated.
            prompt: text prompt.
            image_path: image whose VAE latent conditions every frame.
            num_frames: number of frames to generate.
            replicate_params: replicate params across devices before sampling.
            neg_prompt: negative prompt for classifier-free guidance.
            steps: number of DDIM steps.
            cfg: guidance scale; 1.0 disables the negative-prompt pass.
            unload_after_usage: free the inference models afterwards.

        Returns:
            List of PIL images (all frames of all decoded device samples).
        """
        assert jax.process_index() == 0, 'not main process'
        self.log('Sample')
        self._load_inference_models()
        # text encoding and hint-image encoding run in torch on CPU
        with torch.no_grad():
            tokens = self.tokenizer(
                [ prompt ],
                truncation = True,
                return_overflowing_tokens = False,
                padding = 'max_length',
                return_tensors = 'pt'
            ).input_ids
            neg_tokens = self.tokenizer(
                [ neg_prompt ],
                truncation = True,
                return_overflowing_tokens = False,
                padding = 'max_length',
                return_tensors = 'pt'
            ).input_ids
            encoded_prompt = self.text_encoder(input_ids = tokens).last_hidden_state
            encoded_neg_prompt = self.text_encoder(input_ids = neg_tokens).last_hidden_state
            # hint image -> float tensor in [-1, 1] with batch dim (1, c, h, w)
            hint_latent = torch.tensor(np.asarray(Image.open(image_path))).permute(2,0,1).to(torch.float32).div(255).mul(2).sub(1).unsqueeze(0)
            # deterministic: use the latent distribution mean (not a sample)
            hint_latent = self.vae.encode(hint_latent).latent_dist.mean * self.vae.config.scaling_factor #0.18215 # deterministic
            # repeat the single hint latent across a new frame axis: (1, c, f, h, w)
            hint_latent = hint_latent.unsqueeze(2).repeat_interleave(num_frames, 2)
            mask = torch.zeros_like(hint_latent[:,0:1,:,:,:]) # zero mask, e.g. skip masking for now
            init_latent = torch.randn_like(hint_latent)
        # move to devices
        encoded_prompt = jnp.array(encoded_prompt.numpy())
        encoded_neg_prompt = jnp.array(encoded_neg_prompt.numpy())
        hint_latent = jnp.array(hint_latent.numpy())
        mask = jnp.array(mask.numpy())
        # one initial latent per device; sharded right before the loop
        init_latent = init_latent.repeat(jax.device_count(), 1, 1, 1, 1)
        init_latent = jnp.array(init_latent.numpy())
        # self.ddim is (scheduler, state); set_timesteps returns a new state
        self.ddim = (self.ddim[0], self.ddim[0].set_timesteps(self.ddim[1], steps))
        timesteps = self.ddim[1].timesteps
        if replicate_params:
            params = jax_utils.replicate(params)
        ddim_state = jax_utils.replicate(self.ddim[1])
        encoded_prompt = jax_utils.replicate(encoded_prompt)
        encoded_neg_prompt = jax_utils.replicate(encoded_neg_prompt)
        hint_latent = jax_utils.replicate(hint_latent)
        mask = jax_utils.replicate(mask)
        # sampling fun: one pmapped DDIM step
        def sample_loop(init_latent, ddim_state, t, params, encoded_prompt, encoded_neg_prompt, hint_latent, mask):
            # concat noise latent + mask + hint along the channel axis
            latent_model_input = jnp.concatenate([init_latent, mask, hint_latent], axis = 1)
            pred = self.model.apply(
                { 'params': params },
                latent_model_input,
                t,
                encoded_prompt
            ).sample
            if cfg != 1.0:
                # classifier-free guidance: blend conditional/unconditional preds
                neg_pred = self.model.apply(
                    { 'params': params },
                    latent_model_input,
                    t,
                    encoded_neg_prompt
                ).sample
                pred = neg_pred + cfg * (pred - neg_pred)
            # TODO check if noise is added at the right dimension
            init_latent, ddim_state = self.ddim[0].step(ddim_state, pred, t, init_latent).to_tuple()
            return init_latent, ddim_state
        p_sample_loop = jax.pmap(sample_loop, 'sample', donate_argnums = ())
        pbar_sample = trange(len(timesteps), desc = 'Sample', dynamic_ncols = True, smoothing = 0.1, disable = not self.verbose)
        init_latent = shard(init_latent)
        for i in pbar_sample:
            # broadcast the scalar timestep to one entry per device
            t = timesteps[i].repeat(self.num_devices)
            t = shard(t)
            init_latent, ddim_state = p_sample_loop(init_latent, ddim_state, t, params, encoded_prompt, encoded_neg_prompt, hint_latent, mask)
        # decode back in torch on CPU
        self.log('Decode')
        init_latent = torch.tensor(np.array(init_latent))
        init_latent = init_latent / self.vae.config.scaling_factor
        # d:0 b:1 c:2 f:3 h:4 w:5 -> d b f c h w
        init_latent = init_latent.permute(0, 1, 3, 2, 4, 5)
        images = []
        pbar_decode = trange(len(init_latent), desc = 'Decode', dynamic_ncols = True)
        for sample in init_latent:
            # decode all frames of one device's sample, then to uint8 HWC
            ims = self.vae.decode(sample.squeeze()).sample
            ims = ims.add(1).div(2).mul(255).round().clamp(0, 255).to(torch.uint8).permute(0,2,3,1).numpy()
            ims = [ Image.fromarray(x) for x in ims ]
            for im in ims:
                images.append(im)
            pbar_decode.update(1)
        if unload_after_usage:
            self._unload_inference_models()
        return images
def get_params_from_state(self, state: TrainState) -> FrozenDict[Any, str]: | |
return FrozenDict(jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))) | |
    def train(self,
            dataloader: DataLoader,
            lr: float,
            num_frames: int,
            log_every_step: int = 10,
            save_every_epoch: int = 1,
            sample_every_epoch: int = 1,
            output_dir: str = 'output',
            warmup: float = 0,
            decay: float = 0,
            epochs: int = 10,
            weight_decay: float = 1e-2
    ) -> None:
        """Train the UNet with pmapped AdamW, logging/saving/sampling periodically.

        Args:
            dataloader: yields batches with keys 'latent_model_input',
                'encoder_hidden_states', 'timesteps' and 'noise'.
            lr: peak learning rate.
            num_frames: frames per sample (logged; used for validation sampling).
            log_every_step: log averaged metrics every N global steps.
            save_every_epoch: save a checkpoint every N epochs.
            sample_every_epoch: run a validation sample every N epochs.
            output_dir: root directory for checkpoints.
            warmup: fraction of total steps used for LR warmup (0 disables).
            decay: fraction of total steps for cosine decay (0 disables).
            epochs: number of epochs to train.
            weight_decay: AdamW weight decay.
        """
        eps = 1e-8
        total_steps = len(dataloader) * epochs
        warmup_steps = math.ceil(warmup * total_steps) if warmup > 0 else 0
        # the +1 keeps decay_steps > warmup_steps when there is no decay phase
        decay_steps = math.ceil(decay * total_steps) + warmup_steps if decay > 0 else warmup_steps + 1
        self.log(f'Total steps: {total_steps}')
        self.log(f'Warmup steps: {warmup_steps}')
        self.log(f'Decay steps: {decay_steps - warmup_steps}')
        if warmup > 0 or decay > 0:
            if not decay > 0:
                # only warmup, keep peak lr until end
                self.log('Warmup schedule')
                end_lr = lr
            else:
                # warmup + annealing to end lr
                self.log('Warmup + cosine annealing schedule')
                end_lr = eps
            lr_schedule = optax.warmup_cosine_decay_schedule(
                init_value = 0.0,
                peak_value = lr,
                warmup_steps = warmup_steps,
                decay_steps = decay_steps,
                end_value = end_lr
            )
        else:
            # no warmup or decay -> constant lr
            self.log('constant schedule')
            lr_schedule = optax.constant_schedule(value = lr)
        adamw = optax.adamw(
            learning_rate = lr_schedule,
            b1 = 0.9,
            b2 = 0.999,
            eps = eps,
            weight_decay = weight_decay
        )
        # clip the global gradient norm before the optimizer update
        optim = optax.chain(
            optax.clip_by_global_norm(max_norm = 1.0),
            adamw
        )
        # frozen parameters get zero updates via set_to_zero
        partition_optimizers = {
            'trainable': optim,
            'frozen': optax.set_to_zero()
        }
        tx = optax.multi_transform(partition_optimizers, self.param_partitions)
        state = TrainState.create(
            apply_fn = self.model.__call__,
            params = self.params,
            tx = tx
        )
        validation_rng, train_rngs = jax.random.split(self.rng)
        # one rng per local device for pmap
        train_rngs = jax.random.split(train_rngs, jax.local_device_count())
        def train_step(state: TrainState, batch: Dict[str, jax.Array], train_rng: jax.Array):
            """One pmapped optimization step; returns (new_state, metrics, rng)."""
            def compute_loss(
                params: Dict[str, Any],
                batch: Dict[str, jax.Array],
                sample_rng: jax.Array # unused, dataloader provides everything
            ) -> jax.Array:
                # batch keys:
                # 'latent_model_input': latent_model_input
                # 'encoder_hidden_states': encoder_hidden_states
                # 'timesteps': timesteps
                # 'noise': noise
                latent_model_input = batch['latent_model_input']
                encoder_hidden_states = batch['encoder_hidden_states']
                timesteps = batch['timesteps']
                noise = batch['noise']
                model_pred = self.model.apply(
                    { 'params': params },
                    latent_model_input,
                    timesteps,
                    encoder_hidden_states
                ).sample
                # MSE between predicted and target noise
                loss = (noise - model_pred) ** 2
                loss = loss.mean()
                return loss
            grad_fn = jax.value_and_grad(compute_loss)
            def loss_and_grad(
                train_rng: jax.Array
            ) -> Tuple[jax.Array, Any, jax.Array]:
                # split so each step advances the rng chain even though
                # sample_rng is currently unused
                sample_rng, train_rng = jax.random.split(train_rng, 2)
                loss, grad = grad_fn(state.params, batch, sample_rng)
                return loss, grad, train_rng
            loss, grad, new_train_rng = loss_and_grad(train_rng)
            # self.log(grad) # NOTE uncomment to visualize gradient
            # average gradients and metrics across devices
            grad = jax.lax.pmean(grad, axis_name = 'batch')
            new_state = state.apply_gradients(grads = grad)
            metrics: Dict[str, Any] = { 'loss': loss }
            metrics = jax.lax.pmean(metrics, axis_name = 'batch')
            def l2(xs) -> jax.Array:
                # global l2 norm over all gradient leaves
                return jnp.sqrt(sum([jnp.vdot(x, x) for x in jax.tree_util.tree_leaves(xs)]))
            metrics['l2_grads'] = l2(jax.tree_util.tree_leaves(grad))
            return new_state, metrics, new_train_rng
        p_train_step = jax.pmap(fun = train_step, axis_name = 'batch', donate_argnums = (0, ))
        state = jax_utils.replicate(state)
        train_metrics = []
        train_metric = None
        global_step: int = 0
        if jax.process_index() == 0:
            self._init_tracker_meta()
            hyper_params = {
                'lr': lr,
                'lr_warmup': warmup,
                'lr_decay': decay,
                'weight_decay': weight_decay,
                'total_steps': total_steps,
                'batch_size': dataloader.batch_size // self.num_devices,
                'num_frames': num_frames,
                'sample_size': self.sample_size,
                'num_devices': self.num_devices,
                'seed': self.seed,
                'use_memory_efficient_attention': self.model.use_memory_efficient_attention,
                'only_temporal': self.only_temporal,
                'dtype': self.dtype_str,
                'param_dtype': self.param_dtype,
                'pretrained_model': self.pretrained_model,
                'model_config': self.model.config
            }
            if self._use_wandb:
                self.log('Setting up wandb')
                self._setup_wandb(hyper_params)
                self.log(hyper_params)
            # save the initial (step 0) checkpoint before training starts
            output_path = os.path.join(output_dir, str(global_step), 'unet')
            self.log(f'saving checkpoint to {output_path}')
            self.model.save_pretrained(
                save_directory = output_path,
                params = self.get_params_from_state(state),
                is_main_process = True
            )
        pbar_epoch = tqdm(
            total = epochs,
            desc = 'Epochs',
            smoothing = 1,
            position = 0,
            dynamic_ncols = True,
            leave = True,
            disable = jax.process_index() > 0
        )
        steps_per_epoch = len(dataloader) # TODO dataloader
        for epoch in range(epochs):
            pbar_steps = tqdm(
                total = steps_per_epoch,
                desc = 'Steps',
                position = 1,
                smoothing = 0.1,
                dynamic_ncols = True,
                leave = True,
                disable = jax.process_index() > 0
            )
            for batch in dataloader:
                # keep input + gt as float32, results in fp32 loss and grad
                # otherwise uncomment the following to cast to the model dtype
                # batch = { k: (v.astype(self.dtype) if v.dtype == np.float32 else v) for k,v in batch.items() }
                batch = shard(batch)
                state, train_metric, train_rngs = p_train_step(
                    state, batch, train_rngs
                )
                train_metrics.append(train_metric)
                if global_step % log_every_step == 0 and jax.process_index() == 0:
                    # average the collected per-step metrics, then reset the list
                    train_metrics = jax_utils.unreplicate(train_metrics)
                    train_metrics = jax.tree_util.tree_map(lambda *m: jnp.array(m).mean(), *train_metrics)
                    if global_step == 0:
                        # one-time dtype sanity log
                        self.log(f'grad dtype: {train_metrics["l2_grads"].dtype}')
                        self.log(f'loss dtype: {train_metrics["loss"].dtype}')
                    train_metrics_dict = { k: v.item() for k, v in train_metrics.items() }
                    train_metrics_dict['lr'] = lr_schedule(global_step).item()
                    self.log_metrics(train_metrics_dict, step = global_step, epoch = epoch)
                    train_metrics = []
                pbar_steps.update(1)
                global_step += 1
            if epoch % save_every_epoch == 0 and jax.process_index() == 0:
                output_path = os.path.join(output_dir, str(global_step), 'unet')
                self.log(f'saving checkpoint to {output_path}')
                self.model.save_pretrained(
                    save_directory = output_path,
                    params = self.get_params_from_state(state),
                    is_main_process = True
                )
                self.log(f'checkpoint saved ')
            if epoch % sample_every_epoch == 0 and jax.process_index() == 0:
                # periodic validation sample; state.params are already replicated
                images = self.sample(
                    params = state.params,
                    replicate_params = False,
                    prompt = 'dancing person',
                    image_path = 'testimage.png',
                    num_frames = num_frames,
                    steps = 50,
                    cfg = 9.0,
                    unload_after_usage = False
                )
                os.makedirs(os.path.join('image_output', str(epoch)), exist_ok = True)
                for i, im in enumerate(images):
                    im.save(os.path.join('image_output', str(epoch), str(i).zfill(5) + '.png'), optimize = True)
            pbar_epoch.update(1)