# Scraped from a Hugging Face Space file view (Space status: "Runtime error").
# File size: 25,717 bytes; revision 149cc2d.
from typing import Any, Optional, Union, Tuple, Dict, List
import os
import random
import math
import time
import numpy as np
from tqdm.auto import tqdm, trange
import torch
from torch.utils.data import DataLoader
import jax
import jax.numpy as jnp
import optax
from flax import jax_utils, traverse_util
from flax.core.frozen_dict import FrozenDict
from flax.training.train_state import TrainState
from flax.training.common_utils import shard
# convert 2D -> 3D
from diffusers import FlaxUNet2DConditionModel
# inference test, run on these on cpu
from diffusers import AutoencoderKL
from diffusers.schedulers.scheduling_ddim_flax import FlaxDDIMScheduler, DDIMSchedulerState
from transformers import CLIPTextModel, CLIPTokenizer
from PIL import Image
from .flax_unet_pseudo3d_condition import UNetPseudo3DConditionModel
def seed_all(seed: int) -> jax.Array:
    """Seed the Python, NumPy and PyTorch global RNGs and return a JAX PRNG key.

    Args:
        seed: seed applied to every RNG.

    Returns:
        ``jax.random.PRNGKey(seed)``.

    The return annotation is ``jax.Array`` on purpose: return annotations are
    evaluated when the ``def`` executes. ``jax.random.PRNGKeyArray`` was
    deprecated and later removed from JAX, so referencing it makes importing
    this module fail on newer JAX versions.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    return jax.random.PRNGKey(seed)
def count_params(
    params: Union[Dict[str, Any], FrozenDict[str, Any]],
    filter_name: Optional[str] = None
) -> int:
    """Count the scalar elements in a nested parameter tree.

    Args:
        params: nested dict / ``FrozenDict`` whose leaves are arrays.
        filter_name: if given, only count leaves whose key path, joined with
            spaces, contains this substring (e.g. ``'temporal'``).

    Returns:
        Total number of elements across the selected leaves (0 for an empty tree).
    """
    flat: Dict[Tuple[str, ...], jax.Array] = traverse_util.flatten_dict(params)
    # ``.size`` reads the element count directly; ``len(x.flatten())`` allocated
    # a full copy of every leaf just to measure it.
    return sum(
        leaf.size
        for path, leaf in flat.items()
        if filter_name is None or filter_name in ' '.join(path)
    )
def map_2d_to_pseudo3d(
    params2d: Dict[str, Any],
    params3d: Dict[str, Any],
    verbose: bool = True
) -> Dict[str, Any]:
    """Fill a pseudo-3D parameter tree with weights taken from a 2D tree.

    For every key path in ``params3d``:
      * if the path contains a ``'spatial_conv'`` component, the 2D weight is
        looked up under the same path with that (first) component removed;
      * otherwise, if the path also exists in ``params2d``, that weight is used;
      * otherwise the ``params3d`` value is kept (presumably the freshly
        initialised layers that have no 2D counterpart).

    Args:
        params2d: nested parameter tree of the 2D model.
        params3d: nested parameter tree of the pseudo-3D model (defines the keys
            and expected shapes of the result).
        verbose: print every spatial remapping and every missing key.

    Returns:
        A nested (unfrozen) dict with exactly the keys of ``params3d``.

    Raises:
        KeyError: a ``spatial_conv`` parameter has no matching 2D parameter.
        ValueError: a copied weight's shape differs from the 3D parameter's.
            An explicit raise is used instead of ``assert`` so the check also
            runs under ``python -O``.
    """
    flat2d = traverse_util.flatten_dict(params2d)
    flat3d = traverse_util.flatten_dict(params3d)
    new_params = {}
    for key, value3d in flat3d.items():
        if 'spatial_conv' in key:
            key2d = list(key)
            key2d.remove('spatial_conv')
            key2d = tuple(key2d)
            if verbose:
                tqdm.write(f'Spatial: {key} <- {key2d}')
            if key2d not in flat2d:
                raise KeyError(f'no 2D parameter {key2d} for spatial parameter {key}')
            value = flat2d[key2d]
        elif key not in flat2d:
            if verbose:
                tqdm.write(f'Missing: {key}')
            value = value3d
        else:
            value = flat2d[key]
        if value.shape != value3d.shape:
            raise ValueError(f'shape mismatch: {key}: {value.shape} != {value3d.shape}')
        new_params[key] = value
    return traverse_util.unflatten_dict(new_params)
class FlaxTrainerUNetPseudo3D:
    """Data-parallel JAX/Flax trainer for ``UNetPseudo3DConditionModel``.

    On construction the UNet is loaded: either directly, or by initialising a
    pseudo-3D model and copying weights from a 2D ``FlaxUNet2DConditionModel``.
    Its parameters are then cast to ``param_dtype`` and partitioned into
    'trainable' / 'frozen' groups.

    ``train`` runs a ``jax.pmap`` training loop over the local devices.
    ``sample`` runs DDIM sampling; it uses PyTorch CLIP and VAE models, loaded
    lazily, to build the conditioning and to decode the result.
    """

    # accepted dtype names, used for both the compute dtype and the parameter dtype
    _DTYPES: Dict[str, Any] = {
        'float32': jnp.float32,
        'bfloat16': jnp.bfloat16,
        'float16': jnp.float16,
    }

    def __init__(self,
        model_path: str,
        from_pt: bool = True,
        convert2d: bool = False,
        sample_size: Tuple[int, int] = (64, 64),
        seed: int = 0,
        dtype: str = 'float32',
        param_dtype: str = 'float32',
        only_temporal: bool = True,
        use_memory_efficient_attention: bool = False,
        verbose: bool = True
    ) -> None:
        """Load the UNet and mark its trainable parameters.

        Args:
            model_path: pretrained model location. The UNet is read from its
                ``unet`` subfolder. Sampling additionally reads the
                ``text_encoder``, ``tokenizer``, ``vae`` and ``scheduler``
                subfolders.
            from_pt: forwarded as ``from_pt`` to the UNet ``from_pretrained`` call.
            convert2d: build the pseudo-3D model from a 2D UNet checkpoint.
            sample_size: forwarded to the model as ``sample_size``.
            seed: seeds Python/NumPy/PyTorch and creates the JAX key.
            dtype: compute dtype name ('float32', 'bfloat16' or 'float16').
            param_dtype: parameter dtype name (same choices as ``dtype``).
            only_temporal: train only parameters whose path contains 'temporal'.
            use_memory_efficient_attention: forwarded to the model.
            verbose: print progress messages (main process only).

        Raises:
            ValueError: if ``dtype`` or ``param_dtype`` is not a known name.
        """
        self.verbose = verbose
        self.tracker: Optional['wandb.sdk.wandb_run.Run'] = None
        self._use_wandb: bool = False
        self._tracker_meta: Dict[str, Union[float, int]] = {
            't00': 0.0,
            't0': 0.0,
            'step0': 0
        }
        self.log('Init JAX')
        self.num_devices = jax.device_count()
        self.log(f'Device count: {self.num_devices}')
        self.seed = seed
        self.rng: jax.Array = seed_all(self.seed)
        self.sample_size = sample_size
        if dtype not in self._DTYPES:
            raise ValueError(f'unknown type: {dtype}')
        self.dtype = self._DTYPES[dtype]
        self.dtype_str: str = dtype
        if param_dtype not in self._DTYPES:
            raise ValueError(f'unknown parameter type: {param_dtype}')
        self.param_dtype = param_dtype
        self._load_models(
            model_path = model_path,
            convert2d = convert2d,
            from_pt = from_pt,
            use_memory_efficient_attention = use_memory_efficient_attention
        )
        self._mark_parameters(only_temporal = only_temporal)
        # optionally for validation + sampling; loaded lazily by _load_inference_models
        self.tokenizer: Optional[CLIPTokenizer] = None
        self.text_encoder: Optional[CLIPTextModel] = None
        self.vae: Optional[AutoencoderKL] = None
        self.ddim: Optional[Tuple[FlaxDDIMScheduler, DDIMSchedulerState]] = None

    def log(self, message: Any) -> None:
        """Print ``message`` via ``tqdm.write`` when verbose, on process 0 only."""
        if self.verbose and jax.process_index() == 0:
            tqdm.write(str(message))

    def log_metrics(self, metrics: dict, step: int, epoch: int) -> None:
        """Log training metrics (prefixed ``train/``) plus throughput.

        Prints the metrics when verbose and forwards them to wandb when a
        tracker is active. It is a no-op on non-zero processes. Throughput is
        measured since the previous call.
        """
        if jax.process_index() > 0 or (not self.verbose and self.tracker is None):
            return
        now = time.monotonic()
        elapsed = now - self._tracker_meta['t0']
        # guard against a zero interval (two calls within the clock resolution)
        steps_per_sec = (step - self._tracker_meta['step0']) / elapsed if elapsed > 0 else 0.0
        log_data = {
            'train/step': step,
            'train/epoch': epoch,
            'train/steps_per_sec': steps_per_sec,
            **{ f'train/{k}': v for k, v in metrics.items() }
        }
        self._tracker_meta['t0'] = now
        self._tracker_meta['step0'] = step
        self.log(log_data)
        if self.tracker is not None:
            self.tracker.log(log_data, step = step)

    def enable_wandb(self, enable: bool = True) -> None:
        """Enable (or disable) wandb tracking; the run is created in ``train``."""
        self._use_wandb = enable

    def _setup_wandb(self, config: Optional[Dict[str, Any]] = None) -> None:
        """Create the wandb run with ``config`` as its config (default: empty)."""
        import wandb
        import wandb.sdk
        self.tracker = wandb.init(
            config = config if config is not None else {},
            settings = wandb.sdk.Settings(
                username = 'anon',
                host = 'anon',
                email = 'anon',
                root_dir = 'anon',
                _executable = 'anon',
                _disable_stats = True,
                _disable_meta = True,
                disable_code = True,
                disable_git = True
            )  # don't log sensitive data such as system user names
        )

    def _init_tracker_meta(self) -> None:
        """Reset the throughput reference point used by ``log_metrics``."""
        now = time.monotonic()
        self._tracker_meta = {
            't00': now,
            't0': now,
            'step0': 0
        }

    def _load_models(self,
        model_path: str,
        convert2d: bool,
        from_pt: bool,
        use_memory_efficient_attention: bool
    ) -> None:
        """Load (or convert from 2D) the pseudo-3D UNet and cast its parameters.

        Sets ``self.pretrained_model``, ``self.model`` and ``self.params`` (as a
        ``FrozenDict``).
        """
        self.log(f'Load pretrained from {model_path}')
        if convert2d:
            self.log(' Convert 2D model to Pseudo3D')
            self.log(' Initiate Pseudo3D model')
            config = UNetPseudo3DConditionModel.load_config(model_path, subfolder = 'unet')
            model = UNetPseudo3DConditionModel.from_config(
                config,
                sample_size = self.sample_size,
                dtype = self.dtype,
                param_dtype = self.param_dtype,
                use_memory_efficient_attention = use_memory_efficient_attention
            )
            params: Dict[str, Any] = model.init_weights(self.rng).unfreeze()
            self.log(' Load 2D model')
            model2d, params2d = FlaxUNet2DConditionModel.from_pretrained(
                model_path,
                subfolder = 'unet',
                dtype = self.dtype,
                from_pt = from_pt
            )
            self.log(' Map 2D -> 3D')
            params = map_2d_to_pseudo3d(params2d, params, verbose = self.verbose)
            # release the 2D copies before casting
            del params2d
            del model2d
            del config
        else:
            model, params = UNetPseudo3DConditionModel.from_pretrained(
                model_path,
                subfolder = 'unet',
                from_pt = from_pt,
                sample_size = self.sample_size,
                dtype = self.dtype,
                param_dtype = self.param_dtype,
                use_memory_efficient_attention = use_memory_efficient_attention
            )
        self.log(f'Cast parameters to {model.param_dtype}')
        if model.param_dtype == 'float32':
            params = model.to_fp32(params)
        elif model.param_dtype == 'float16':
            params = model.to_fp16(params)
        elif model.param_dtype == 'bfloat16':
            params = model.to_bf16(params)
        self.pretrained_model = model_path
        self.model: UNetPseudo3DConditionModel = model
        self.params: FrozenDict[str, Any] = FrozenDict(params)

    def _mark_parameters(self, only_temporal: bool) -> None:
        """Label every parameter 'trainable' or 'frozen' for ``optax.multi_transform``.

        With ``only_temporal`` only paths containing 'temporal' are trainable.
        Otherwise everything is.
        """
        self.log('Mark training parameters')
        if only_temporal:
            self.log('Only training temporal layers')
            param_partitions = traverse_util.path_aware_map(
                lambda path, _: 'trainable' if 'temporal' in ' '.join(path) else 'frozen', self.params
            )
        else:
            param_partitions = traverse_util.path_aware_map(
                lambda *_: 'trainable', self.params
            )
        self.only_temporal = only_temporal
        self.param_partitions: FrozenDict[str, Any] = FrozenDict(param_partitions)
        self.log(f'Total parameters: {count_params(self.params)}')
        self.log(f'Temporal parameters: {count_params(self.params, "temporal")}')

    def _load_inference_models(self) -> None:
        """Lazily load the text encoder, tokenizer, VAE and DDIM scheduler (process 0 only)."""
        assert jax.process_index() == 0, 'not main process'
        if self.text_encoder is None:
            self.log('Load text encoder')
            self.text_encoder = CLIPTextModel.from_pretrained(
                self.pretrained_model,
                subfolder = 'text_encoder'
            )
        if self.tokenizer is None:
            self.log('Load tokenizer')
            self.tokenizer = CLIPTokenizer.from_pretrained(
                self.pretrained_model,
                subfolder = 'tokenizer'
            )
        if self.vae is None:
            self.log('Load vae')
            self.vae = AutoencoderKL.from_pretrained(
                self.pretrained_model,
                subfolder = 'vae'
            )
        if self.ddim is None:
            self.log('Load ddim scheduler')
            # tuple(scheduler , scheduler state)
            self.ddim = FlaxDDIMScheduler.from_pretrained(
                self.pretrained_model,
                subfolder = 'scheduler',
                from_pt = True
            )

    def _unload_inference_models(self) -> None:
        """Drop references to the inference-only models."""
        self.text_encoder = None
        self.tokenizer = None
        self.vae = None
        self.ddim = None

    def sample(self,
        params: Union[Dict[str, Any], FrozenDict[str, Any]],
        prompt: str,
        image_path: str,
        num_frames: int,
        replicate_params: bool = True,
        neg_prompt: str = '',
        steps: int = 50,
        cfg: float = 9.0,
        unload_after_usage: bool = False
    ) -> List[Image.Image]:
        """Generate frames conditioned on a prompt and a hint image using DDIM.

        The hint image is VAE-encoded and tiled over ``num_frames``. Sampling is
        ``pmap``-ed over the local devices, and every device starts from the
        same initial noise. Classifier-free guidance with scale ``cfg`` is
        applied unless ``cfg == 1.0``. Must run on process 0.

        Args:
            params: UNet parameters. If ``replicate_params`` is False they must
                already be replicated over the local devices (e.g. ``state.params``).
            prompt: text prompt.
            image_path: path of the hint image.
            num_frames: number of frames to generate.
            replicate_params: replicate ``params`` over local devices first.
            neg_prompt: negative prompt used for guidance.
            steps: number of DDIM steps.
            cfg: guidance scale.
            unload_after_usage: drop the inference models afterwards.

        Returns:
            Decoded frames as PIL images, ``num_frames`` per local device,
            in device-major order.
        """
        assert jax.process_index() == 0, 'not main process'
        self.log('Sample')
        self._load_inference_models()
        # pmap / jax_utils.replicate / shard all work on the *local* devices, so
        # per-device inputs are sized with the local (not global) device count
        num_local_devices = jax.local_device_count()

        encoded_prompt = self._encode_prompt(prompt)
        encoded_neg_prompt = self._encode_prompt(neg_prompt)
        hint_latent = self._encode_hint_image(image_path, num_frames)
        mask = torch.zeros_like(hint_latent[:, 0:1, :, :, :])  # zero mask, e.g. skip masking for now
        init_latent = torch.randn_like(hint_latent)
        # move to devices
        hint_latent = jnp.array(hint_latent.numpy())
        mask = jnp.array(mask.numpy())
        init_latent = jnp.array(init_latent.repeat(num_local_devices, 1, 1, 1, 1).numpy())

        scheduler, scheduler_state = self.ddim
        scheduler_state = scheduler.set_timesteps(scheduler_state, steps)
        self.ddim = (scheduler, scheduler_state)
        timesteps = scheduler_state.timesteps

        if replicate_params:
            params = jax_utils.replicate(params)
        # These inputs were all created unreplicated above, so they are always
        # replicated, whatever ``replicate_params`` says (train() passes False).
        ddim_state = jax_utils.replicate(scheduler_state)
        encoded_prompt = jax_utils.replicate(encoded_prompt)
        encoded_neg_prompt = jax_utils.replicate(encoded_neg_prompt)
        hint_latent = jax_utils.replicate(hint_latent)
        mask = jax_utils.replicate(mask)

        # one denoising step, executed per device
        def sample_loop(init_latent, ddim_state, t, params, encoded_prompt, encoded_neg_prompt, hint_latent, mask):
            latent_model_input = jnp.concatenate([init_latent, mask, hint_latent], axis = 1)
            pred = self.model.apply(
                { 'params': params },
                latent_model_input,
                t,
                encoded_prompt
            ).sample
            if cfg != 1.0:
                neg_pred = self.model.apply(
                    { 'params': params },
                    latent_model_input,
                    t,
                    encoded_neg_prompt
                ).sample
                pred = neg_pred + cfg * (pred - neg_pred)
            # TODO check if noise is added at the right dimension
            init_latent, ddim_state = scheduler.step(ddim_state, pred, t, init_latent).to_tuple()
            return init_latent, ddim_state

        p_sample_loop = jax.pmap(sample_loop, 'sample', donate_argnums = ())
        pbar_sample = trange(len(timesteps), desc = 'Sample', dynamic_ncols = True, smoothing = 0.1, disable = not self.verbose)
        init_latent = shard(init_latent)
        for i in pbar_sample:
            t = shard(timesteps[i].repeat(num_local_devices))
            init_latent, ddim_state = p_sample_loop(init_latent, ddim_state, t, params, encoded_prompt, encoded_neg_prompt, hint_latent, mask)

        self.log('Decode')
        images = self._decode_latents(init_latent)
        if unload_after_usage:
            self._unload_inference_models()
        return images

    @torch.no_grad()
    def _encode_prompt(self, prompt: str) -> jax.Array:
        """Tokenize (max-length padded) and CLIP-encode one prompt; batch size 1."""
        tokens = self.tokenizer(
            [ prompt ],
            truncation = True,
            return_overflowing_tokens = False,
            padding = 'max_length',
            return_tensors = 'pt'
        ).input_ids
        hidden = self.text_encoder(input_ids = tokens).last_hidden_state
        return jnp.array(hidden.numpy())

    @torch.no_grad()
    def _encode_hint_image(self, image_path: str, num_frames: int) -> torch.Tensor:
        """VAE-encode the hint image and tile it over frames: (1, c, num_frames, h, w)."""
        # context manager closes the file; convert('RGB') guarantees 3 channels
        # for grayscale / RGBA / palette images
        with Image.open(image_path) as image:
            pixels = np.asarray(image.convert('RGB'))
        hint = torch.tensor(pixels).permute(2, 0, 1).to(torch.float32).div(255).mul(2).sub(1).unsqueeze(0)
        # latent mean rather than a sample -> deterministic
        hint = self.vae.encode(hint).latent_dist.mean * self.vae.config.scaling_factor
        return hint.unsqueeze(2).repeat_interleave(num_frames, 2)

    @torch.no_grad()
    def _decode_latents(self, latents: jax.Array) -> List[Image.Image]:
        """VAE-decode sharded latents (d, b, c, f, h, w) into PIL frames, device-major."""
        latents = torch.tensor(np.array(latents)) / self.vae.config.scaling_factor
        # d:0 b:1 c:2 f:3 h:4 w:5 -> d b f c h w
        latents = latents.permute(0, 1, 3, 2, 4, 5)
        images: List[Image.Image] = []
        for device_latents in tqdm(latents, desc = 'Decode', dynamic_ncols = True, disable = not self.verbose):
            # merge batch and frame axes into one decode batch (b*f, c, h, w);
            # unlike squeeze() this keeps the frame axis when num_frames == 1
            frames = self.vae.decode(device_latents.flatten(0, 1)).sample
            frames = frames.add(1).div(2).mul(255).round().clamp(0, 255).to(torch.uint8).permute(0, 2, 3, 1).numpy()
            images.extend(Image.fromarray(frame) for frame in frames)
        return images

    def get_params_from_state(self, state: TrainState) -> FrozenDict[str, Any]:
        """Fetch the device-0 copy of replicated ``state.params`` to host memory."""
        return FrozenDict(jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params)))

    def _make_lr_schedule(self, lr: float, total_steps: int, warmup: float, decay: float, eps: float) -> Any:
        """Build the learning-rate schedule.

        ``warmup`` and ``decay`` are fractions of ``total_steps``. With neither,
        the schedule is constant. With warmup only, it warms up and then holds
        ``lr``. With decay, it warms up and then cosine-anneals to ``eps``.
        """
        warmup_steps = math.ceil(warmup * total_steps) if warmup > 0 else 0
        # optax's decay_steps is the total schedule length, *including* warmup
        decay_steps = math.ceil(decay * total_steps) + warmup_steps if decay > 0 else warmup_steps + 1
        self.log(f'Total steps: {total_steps}')
        self.log(f'Warmup steps: {warmup_steps}')
        self.log(f'Decay steps: {decay_steps - warmup_steps}')
        if not (warmup > 0 or decay > 0):
            # no warmup or decay -> constant lr
            self.log('constant schedule')
            return optax.constant_schedule(value = lr)
        if decay > 0:
            # warmup + annealing to end lr
            self.log('Warmup + cosine annealing schedule')
            end_lr = eps
        else:
            # only warmup, keep peak lr until end
            self.log('Warmup schedule')
            end_lr = lr
        return optax.warmup_cosine_decay_schedule(
            init_value = 0.0,
            peak_value = lr,
            warmup_steps = warmup_steps,
            decay_steps = decay_steps,
            end_value = end_lr
        )

    def _make_train_step(self):
        """Return the per-device train step to be wrapped in ``jax.pmap`` (axis 'batch').

        Each batch is expected to be a dict with 'latent_model_input',
        'encoder_hidden_states', 'timesteps' and 'noise'. The loss is the MSE
        between the model prediction and 'noise'. Gradients and loss are
        averaged across devices.
        """
        def train_step(state: TrainState, batch: Dict[str, jax.Array], train_rng: jax.Array):
            def compute_loss(
                params: Dict[str, Any],
                batch: Dict[str, jax.Array],
                sample_rng: jax.Array  # unused, dataloader provides everything
            ) -> jax.Array:
                model_pred = self.model.apply(
                    { 'params': params },
                    batch['latent_model_input'],
                    batch['timesteps'],
                    batch['encoder_hidden_states']
                ).sample
                return ((batch['noise'] - model_pred) ** 2).mean()

            grad_fn = jax.value_and_grad(compute_loss)
            sample_rng, new_train_rng = jax.random.split(train_rng, 2)
            loss, grad = grad_fn(state.params, batch, sample_rng)
            grad = jax.lax.pmean(grad, axis_name = 'batch')
            new_state = state.apply_gradients(grads = grad)
            metrics: Dict[str, Any] = jax.lax.pmean({ 'loss': loss }, axis_name = 'batch')

            def l2(xs) -> jax.Array:
                return jnp.sqrt(sum([jnp.vdot(x, x) for x in jax.tree_util.tree_leaves(xs)]))

            metrics['l2_grads'] = l2(jax.tree_util.tree_leaves(grad))
            return new_state, metrics, new_train_rng

        return train_step

    def _save_checkpoint(self, state: TrainState, output_dir: str, global_step: int) -> None:
        """Save the device-0 parameters to ``<output_dir>/<global_step>/unet``."""
        output_path = os.path.join(output_dir, str(global_step), 'unet')
        self.log(f'saving checkpoint to {output_path}')
        self.model.save_pretrained(
            save_directory = output_path,
            params = self.get_params_from_state(state),
            is_main_process = True
        )

    def _log_train_metrics(self, train_metrics: List[Any], lr_schedule: Any, global_step: int, epoch: int) -> None:
        """Average the collected (replicated) step metrics and log them with the current lr."""
        metrics = jax_utils.unreplicate(train_metrics)
        metrics = jax.tree_util.tree_map(lambda *m: jnp.array(m).mean(), *metrics)
        if global_step == 0:
            self.log(f'grad dtype: {metrics["l2_grads"].dtype}')
            self.log(f'loss dtype: {metrics["loss"].dtype}')
        metrics_dict = { k: v.item() for k, v in metrics.items() }
        # float() instead of .item(): optax.constant_schedule returns the raw
        # python float, which has no .item()
        metrics_dict['lr'] = float(lr_schedule(global_step))
        self.log_metrics(metrics_dict, step = global_step, epoch = epoch)

    def train(self,
        dataloader: DataLoader,
        lr: float,
        num_frames: int,
        log_every_step: int = 10,
        save_every_epoch: int = 1,
        sample_every_epoch: int = 1,
        output_dir: str = 'output',
        warmup: float = 0,
        decay: float = 0,
        epochs: int = 10,
        weight_decay: float = 1e-2,
        sample_prompt: str = 'dancing person',
        sample_image_path: str = 'testimage.png'
    ) -> None:
        """Train the trainable partition with AdamW (global-norm clipped at 1.0).

        Args:
            dataloader: yields batch dicts (see ``_make_train_step``) that
                ``shard`` can split over the local devices.
            lr: peak learning rate.
            num_frames: frames generated by the periodic sampling (also logged).
            log_every_step: log averaged metrics every this many steps.
            save_every_epoch: checkpoint every this many epochs. The last epoch
                is always checkpointed.
            sample_every_epoch: sample every this many epochs into
                ``image_output/<epoch>/``.
            output_dir: checkpoints go to ``<output_dir>/<global_step>/unet``.
            warmup: warmup length as a fraction of the total steps.
            decay: cosine-decay length as a fraction of the total steps.
            epochs: number of passes over ``dataloader``.
            weight_decay: AdamW weight decay.
            sample_prompt: prompt used for the periodic samples.
            sample_image_path: hint image used for the periodic samples.
        """
        eps = 1e-8
        total_steps = len(dataloader) * epochs
        lr_schedule = self._make_lr_schedule(lr, total_steps, warmup, decay, eps)
        adamw = optax.adamw(
            learning_rate = lr_schedule,
            b1 = 0.9,
            b2 = 0.999,
            eps = eps,
            weight_decay = weight_decay
        )
        optim = optax.chain(
            optax.clip_by_global_norm(max_norm = 1.0),
            adamw
        )
        partition_optimizers = {
            'trainable': optim,
            'frozen': optax.set_to_zero()
        }
        tx = optax.multi_transform(partition_optimizers, self.param_partitions)
        state = TrainState.create(
            apply_fn = self.model.__call__,
            params = self.params,
            tx = tx
        )
        # first half of the split is unused (reserved for validation)
        _, train_rngs = jax.random.split(self.rng)
        train_rngs = jax.random.split(train_rngs, jax.local_device_count())
        p_train_step = jax.pmap(fun = self._make_train_step(), axis_name = 'batch', donate_argnums = (0, ))
        state = jax_utils.replicate(state)
        train_metrics = []
        global_step: int = 0
        if jax.process_index() == 0:
            self._init_tracker_meta()
            hyper_params = {
                'lr': lr,
                'lr_warmup': warmup,
                'lr_decay': decay,
                'weight_decay': weight_decay,
                'total_steps': total_steps,
                'batch_size': dataloader.batch_size // self.num_devices,
                'num_frames': num_frames,
                'sample_size': self.sample_size,
                'num_devices': self.num_devices,
                'seed': self.seed,
                'use_memory_efficient_attention': self.model.use_memory_efficient_attention,
                'only_temporal': self.only_temporal,
                'dtype': self.dtype_str,
                'param_dtype': self.param_dtype,
                'pretrained_model': self.pretrained_model,
                'model_config': self.model.config
            }
            if self._use_wandb:
                self.log('Setting up wandb')
                self._setup_wandb(hyper_params)
            self.log(hyper_params)
            self._save_checkpoint(state, output_dir, global_step)
        pbar_epoch = tqdm(
            total = epochs,
            desc = 'Epochs',
            smoothing = 1,
            position = 0,
            dynamic_ncols = True,
            leave = True,
            disable = jax.process_index() > 0
        )
        steps_per_epoch = len(dataloader)  # TODO dataloader
        for epoch in range(epochs):
            pbar_steps = tqdm(
                total = steps_per_epoch,
                desc = 'Steps',
                position = 1,
                smoothing = 0.1,
                dynamic_ncols = True,
                leave = True,
                disable = jax.process_index() > 0
            )
            for batch in dataloader:
                # keep input + gt as float32, results in fp32 loss and grad
                # otherwise uncomment the following to cast to the model dtype
                # batch = { k: (v.astype(self.dtype) if v.dtype == np.float32 else v) for k,v in batch.items() }
                batch = shard(batch)
                state, train_metric, train_rngs = p_train_step(
                    state, batch, train_rngs
                )
                train_metrics.append(train_metric)
                if global_step % log_every_step == 0:
                    if jax.process_index() == 0:
                        self._log_train_metrics(train_metrics, lr_schedule, global_step, epoch)
                    # reset on every process; otherwise non-main processes keep
                    # accumulating device arrays for the whole run
                    train_metrics = []
                pbar_steps.update(1)
                global_step += 1
            pbar_steps.close()
            is_last_epoch = epoch == epochs - 1
            if (epoch % save_every_epoch == 0 or is_last_epoch) and jax.process_index() == 0:
                self._save_checkpoint(state, output_dir, global_step)
                self.log(f'checkpoint saved ')
            if epoch % sample_every_epoch == 0 and jax.process_index() == 0:
                images = self.sample(
                    params = state.params,
                    replicate_params = False,
                    prompt = sample_prompt,
                    image_path = sample_image_path,
                    num_frames = num_frames,
                    steps = 50,
                    cfg = 9.0,
                    unload_after_usage = False
                )
                sample_dir = os.path.join('image_output', str(epoch))
                os.makedirs(sample_dir, exist_ok = True)
                for i, im in enumerate(images):
                    im.save(os.path.join(sample_dir, str(i).zfill(5) + '.png'), optimize = True)
            pbar_epoch.update(1)
        pbar_epoch.close()
|