Wuvin's picture
init
37aeb5b
raw
history blame
No virus
7.52 kB
import torch
from accelerate import Accelerator
from accelerate.logging import MultiProcessAdapter
from dataclasses import dataclass, field
from typing import Optional, Union
from datasets import load_dataset
import json
import abc
from diffusers.utils import make_image_grid
import numpy as np
import wandb
from custum_3d_diffusion.trainings.utils import load_config
from custum_3d_diffusion.custum_modules.unifield_processor import ConfigurableUNet2DConditionModel, AttnConfig
class BasicTrainer(torch.nn.Module, abc.ABC):
accelerator: Accelerator
logger: MultiProcessAdapter
unet: ConfigurableUNet2DConditionModel
train_dataloader: torch.utils.data.DataLoader
test_dataset: torch.utils.data.Dataset
attn_config: AttnConfig
@dataclass
class TrainerConfig:
trainer_name: str = "basic"
pretrained_model_name_or_path: str = ""
attn_config: dict = field(default_factory=dict)
dataset_name: str = ""
dataset_config_name: Optional[str] = None
resolution: str = "1024"
dataloader_num_workers: int = 4
pair_sampler_group_size: int = 1
num_views: int = 4
max_train_steps: int = -1 # -1 means infinity, otherwise [0, max_train_steps)
training_step_interval: int = 1 # train on step i*interval, stop at max_train_steps
max_train_samples: Optional[int] = None
seed: Optional[int] = None # For dataset related operations and validation stuff
train_batch_size: int = 1
validation_interval: int = 5000
debug: bool = False
cfg: TrainerConfig # only enable_xxx is used
def __init__(
self,
accelerator: Accelerator,
logger: MultiProcessAdapter,
unet: ConfigurableUNet2DConditionModel,
config: Union[dict, str],
weight_dtype: torch.dtype,
index: int,
):
super().__init__()
self.index = index # index in all trainers
self.accelerator = accelerator
self.logger = logger
self.unet = unet
self.weight_dtype = weight_dtype
self.ext_logs = {}
self.cfg = load_config(self.TrainerConfig, config)
self.attn_config = load_config(AttnConfig, self.cfg.attn_config)
self.test_dataset = None
self.validate_trainer_config()
self.configure()
def get_HW(self):
resolution = json.loads(self.cfg.resolution)
if isinstance(resolution, int):
H = W = resolution
elif isinstance(resolution, list):
H, W = resolution
return H, W
def unet_update(self):
self.unet.update_config(self.attn_config)
def validate_trainer_config(self):
pass
def is_train_finished(self, current_step):
assert isinstance(self.cfg.max_train_steps, int)
return self.cfg.max_train_steps != -1 and current_step >= self.cfg.max_train_steps
def next_train_step(self, current_step):
if self.is_train_finished(current_step):
return None
return current_step + self.cfg.training_step_interval
@classmethod
def make_image_into_grid(cls, all_imgs, rows=2, columns=2):
catted = [make_image_grid(all_imgs[i:i+rows * columns], rows=rows, cols=columns) for i in range(0, len(all_imgs), rows * columns)]
return make_image_grid(catted, rows=1, cols=len(catted))
def configure(self) -> None:
pass
@abc.abstractmethod
def init_shared_modules(self, shared_modules: dict) -> dict:
pass
def load_dataset(self):
dataset = load_dataset(
self.cfg.dataset_name,
self.cfg.dataset_config_name,
trust_remote_code=True
)
return dataset
@abc.abstractmethod
def init_train_dataloader(self, shared_modules: dict) -> torch.utils.data.DataLoader:
"""Both init train_dataloader and test_dataset, but returns train_dataloader only"""
pass
@abc.abstractmethod
def forward_step(
self,
*args,
**kwargs
) -> torch.Tensor:
"""
input a batch
return a loss
"""
self.unet_update()
pass
@abc.abstractmethod
def construct_pipeline(self, shared_modules, unet):
pass
@abc.abstractmethod
def pipeline_forward(self, pipeline, **pipeline_call_kwargs) -> tuple:
"""
For inference time forward.
"""
pass
@abc.abstractmethod
def batched_validation_forward(self, pipeline, **pipeline_call_kwargs) -> tuple:
pass
def do_validation(
self,
shared_modules,
unet,
global_step,
):
self.unet_update()
self.logger.info("Running validation... ")
pipeline = self.construct_pipeline(shared_modules, unet)
pipeline.set_progress_bar_config(disable=True)
titles, images = self.batched_validation_forward(pipeline, guidance_scale=[1., 3.])
for tracker in self.accelerator.trackers:
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images("validation", np_images, global_step, dataformats="NHWC")
elif tracker.name == "wandb":
[image.thumbnail((512, 512)) for image, title in zip(images, titles) if 'noresize' not in title] # inplace operation
tracker.log({"validation": [
wandb.Image(image, caption=f"{i}: {titles[i]}", file_type="jpg")
for i, image in enumerate(images)]})
else:
self.logger.warn(f"image logging not implemented for {tracker.name}")
del pipeline
torch.cuda.empty_cache()
return images
@torch.no_grad()
def log_validation(
self,
shared_modules,
unet,
global_step,
force=False
):
if self.accelerator.is_main_process:
for tracker in self.accelerator.trackers:
if tracker.name == "wandb":
tracker.log(self.ext_logs)
self.ext_logs = {}
if (global_step % self.cfg.validation_interval == 0 and not self.is_train_finished(global_step)) or force:
self.unet_update()
if self.accelerator.is_main_process:
self.do_validation(shared_modules, self.accelerator.unwrap_model(unet), global_step)
def save_model(self, unwrap_unet, shared_modules, save_dir):
if self.accelerator.is_main_process:
pipeline = self.construct_pipeline(shared_modules, unwrap_unet)
pipeline.save_pretrained(save_dir)
self.logger.info(f"{self.cfg.trainer_name} Model saved at {save_dir}")
def save_debug_info(self, save_name="debug", **kwargs):
if self.cfg.debug:
to_saves = {key: value.detach().cpu() if isinstance(value, torch.Tensor) else value for key, value in kwargs.items()}
import pickle
import os
if os.path.exists(f"{save_name}.pkl"):
for i in range(100):
if not os.path.exists(f"{save_name}_v{i}.pkl"):
save_name = f"{save_name}_v{i}"
break
with open(f"{save_name}.pkl", "wb") as f:
pickle.dump(to_saves, f)