File size: 7,516 Bytes
37aeb5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import torch
from accelerate import Accelerator
from accelerate.logging import MultiProcessAdapter
from dataclasses import dataclass, field
from typing import Optional, Union
from datasets import load_dataset
import json
import abc
from diffusers.utils import make_image_grid
import numpy as np
import wandb

from custum_3d_diffusion.trainings.utils import load_config
from custum_3d_diffusion.custum_modules.unifield_processor import ConfigurableUNet2DConditionModel, AttnConfig

class BasicTrainer(torch.nn.Module, abc.ABC):
    """Abstract base class for one trainer operating on a shared UNet.

    The base class owns the configuration loading, the train-step scheduling
    helpers, validation image logging and checkpoint saving. Subclasses supply
    the dataset/dataloader construction, the training forward step and the
    inference pipeline by implementing the abstract methods.
    """

    accelerator: Accelerator
    logger: MultiProcessAdapter
    unet: ConfigurableUNet2DConditionModel
    train_dataloader: torch.utils.data.DataLoader
    # Set to None in __init__; subclasses are expected to fill it in
    # (see init_train_dataloader).
    test_dataset: Optional[torch.utils.data.Dataset]
    attn_config: AttnConfig
    
    @dataclass
    class TrainerConfig:
        """Options of a single trainer; subclasses may extend this dataclass."""

        trainer_name: str = "basic"
        pretrained_model_name_or_path: str = ""
        
        attn_config: dict = field(default_factory=dict)  # loaded into an AttnConfig in __init__
        dataset_name: str = ""
        dataset_config_name: Optional[str] = None
        resolution: str = "1024"                        # JSON text: an int, or a [H, W] list (see get_HW)
        dataloader_num_workers: int = 4
        pair_sampler_group_size: int = 1
        num_views: int = 4
        
        max_train_steps: int = -1                       # -1 means infinity, otherwise [0, max_train_steps)
        training_step_interval: int = 1                 # train on step i*interval, stop at max_train_steps
        max_train_samples: Optional[int] = None
        seed: Optional[int] = None                      # For dataset related operations and validation stuff
        train_batch_size: int = 1
        
        validation_interval: int = 5000                 # run validation when global_step % interval == 0
        debug: bool = False                             # enables save_debug_info
    
    cfg: TrainerConfig    # only enable_xxx is used
    
    def __init__(
        self, 
        accelerator: Accelerator, 
        logger: MultiProcessAdapter,
        unet: ConfigurableUNet2DConditionModel,
        config: Union[dict, str],
        weight_dtype: torch.dtype,
        index: int,
    ):
        """Store the shared objects, load the configs and run the setup hooks.

        Args:
            accelerator: Accelerator used for process checks and trackers.
            logger: Logger used for progress messages.
            unet: The UNet whose attention config this trainer updates.
            config: Trainer configuration, passed to ``load_config`` together
                with ``self.TrainerConfig``.
            weight_dtype: Stored as ``self.weight_dtype`` for subclasses.
            index: Index of this trainer among all trainers.
        """
        super().__init__()
        self.index = index              # index in all trainers
        self.accelerator = accelerator
        self.logger = logger
        self.unet = unet
        self.weight_dtype = weight_dtype
        self.ext_logs = {}              # extra values flushed to wandb by log_validation
        self.cfg = load_config(self.TrainerConfig, config)
        self.attn_config = load_config(AttnConfig, self.cfg.attn_config)
        self.test_dataset = None
        # Hooks run last so that overrides can rely on every attribute above.
        self.validate_trainer_config()
        self.configure()
    
    def get_HW(self):
        """Return the configured resolution as an ``(H, W)`` pair.

        ``cfg.resolution`` may be a JSON string (``"1024"`` or ``"[512, 768]"``)
        or an already-parsed int / two-element list or tuple.

        Raises:
            ValueError: If the resolution is not an int or a two-element
                sequence.
        """
        resolution = self.cfg.resolution
        if isinstance(resolution, str):
            resolution = json.loads(resolution)
        if isinstance(resolution, int):
            return resolution, resolution
        if isinstance(resolution, (list, tuple)):
            if len(resolution) != 2:
                raise ValueError(
                    f"resolution must have exactly two entries [H, W], got {resolution!r}"
                )
            H, W = resolution
            return H, W
        raise ValueError(
            f"resolution must be an int or a [H, W] list, got {self.cfg.resolution!r}"
        )
    
    def unet_update(self):
        """Apply this trainer's attention config to the shared UNet."""
        self.unet.update_config(self.attn_config)
    
    def validate_trainer_config(self):
        """Hook called from ``__init__`` to check ``self.cfg``; no-op by default."""
        pass
    
    def is_train_finished(self, current_step):
        """Return True once ``current_step`` has reached ``cfg.max_train_steps``.

        A ``max_train_steps`` of -1 means training never finishes.
        """
        assert isinstance(self.cfg.max_train_steps, int)
        return self.cfg.max_train_steps != -1 and current_step >= self.cfg.max_train_steps
    
    def next_train_step(self, current_step):
        """Return the next step to train on, or None if training is finished."""
        if self.is_train_finished(current_step):
            return None
        return current_step + self.cfg.training_step_interval

    @classmethod
    def make_image_into_grid(cls, all_imgs, rows=2, columns=2):
        """Tile images into ``rows x columns`` grids and join the grids in one row.

        Images are consumed in chunks of ``rows * columns``; each chunk becomes
        one grid and all grids are then placed side by side.

        NOTE(review): diffusers' ``make_image_grid`` is expected to require
        exactly ``rows * cols`` images, so ``len(all_imgs)`` should be a
        multiple of ``rows * columns`` -- confirm against the installed version.
        """
        per_grid = rows * columns
        catted = [
            make_image_grid(all_imgs[i:i + per_grid], rows=rows, cols=columns)
            for i in range(0, len(all_imgs), per_grid)
        ]
        return make_image_grid(catted, rows=1, cols=len(catted))

    def configure(self) -> None:
        """Hook called at the end of ``__init__``; no-op by default."""
        pass
    
    @abc.abstractmethod
    def init_shared_modules(self, shared_modules: dict) -> dict:
        """Set up modules shared between trainers and return the dict."""
        pass
    
    def load_dataset(self):
        """Load the dataset named by ``cfg.dataset_name`` / ``cfg.dataset_config_name``.

        Calls ``datasets.load_dataset`` with ``trust_remote_code=True``, which
        allows the dataset's own loading script to execute.
        """
        dataset = load_dataset(
            self.cfg.dataset_name,
            self.cfg.dataset_config_name,
            trust_remote_code=True
        )
        return dataset

    @abc.abstractmethod
    def init_train_dataloader(self, shared_modules: dict) -> torch.utils.data.DataLoader:
        """Both init train_dataloader and test_dataset, but returns train_dataloader only"""
        pass
    
    @abc.abstractmethod
    def forward_step(
        self, 
        *args, 
        **kwargs
    ) -> torch.Tensor:
        """
        input a batch
        return a loss

        The base implementation only applies this trainer's attention config
        to the UNet; overrides can call it through ``super()``.
        """
        self.unet_update()
    
    @abc.abstractmethod
    def construct_pipeline(self, shared_modules, unet):
        """Build the inference pipeline from the shared modules and ``unet``."""
        pass
    
    @abc.abstractmethod
    def pipeline_forward(self, pipeline, **pipeline_call_kwargs) -> tuple:
        """
            For inference time forward.
        """
        pass

    @abc.abstractmethod
    def batched_validation_forward(self, pipeline, **pipeline_call_kwargs) -> tuple:
        """Run validation inference; ``do_validation`` unpacks the result as ``(titles, images)``."""
        pass

    def do_validation(
        self,
        shared_modules,
        unet,
        global_step,
    ):
        """Run validation inference and log the images to the active trackers.

        Args:
            shared_modules: Passed through to ``construct_pipeline``.
            unet: UNet used to build the validation pipeline.
            global_step: Step under which TensorBoard images are recorded.

        Returns:
            The validation images. When a wandb tracker is present, images
            whose title lacks ``'noresize'`` have been shrunk in place.
        """
        self.unet_update()
        self.logger.info("Running validation... ")
        pipeline = self.construct_pipeline(shared_modules, unet)
        pipeline.set_progress_bar_config(disable=True)
        titles, images = self.batched_validation_forward(pipeline, guidance_scale=[1., 3.])
        for tracker in self.accelerator.trackers:
            if tracker.name == "tensorboard":
                np_images = np.stack([np.asarray(img) for img in images])
                tracker.writer.add_images("validation", np_images, global_step, dataformats="NHWC")
            elif tracker.name == "wandb":
                # thumbnail() resizes in place, so the returned images shrink too.
                for image, title in zip(images, titles):
                    if 'noresize' not in title:
                        image.thumbnail((512, 512))
                tracker.log({"validation": [
                    wandb.Image(image, caption=f"{i}: {titles[i]}", file_type="jpg")
                    for i, image in enumerate(images)]})
            else:
                # Logger.warn is a deprecated alias of warning.
                self.logger.warning(f"image logging not implemented for {tracker.name}")
        # Drop the pipeline reference before asking CUDA to release cached memory.
        del pipeline
        torch.cuda.empty_cache()
        return images

    
    @torch.no_grad()
    def log_validation(
        self,
        shared_modules,
        unet,
        global_step,
        force=False
    ):
        """Flush ``ext_logs`` to wandb and run validation when it is due.

        Validation runs on the main process when ``global_step`` is a multiple
        of ``cfg.validation_interval`` and training is not finished, or
        whenever ``force`` is true.

        Args:
            shared_modules: Passed through to ``do_validation``.
            unet: Possibly wrapped UNet; unwrapped via the accelerator.
            global_step: Current global training step.
            force: Run validation regardless of the interval.
        """
        if self.accelerator.is_main_process:
            for tracker in self.accelerator.trackers:
                if tracker.name == "wandb":
                    tracker.log(self.ext_logs)
        # Reset on every process, not only the one that logged.
        self.ext_logs = {}
        if (global_step % self.cfg.validation_interval == 0 and not self.is_train_finished(global_step)) or force:
            self.unet_update()
            if self.accelerator.is_main_process:
                self.do_validation(shared_modules, self.accelerator.unwrap_model(unet), global_step)

    def save_model(self, unwrap_unet, shared_modules, save_dir):
        """Save the pipeline built from ``unwrap_unet`` to ``save_dir`` (main process only)."""
        if self.accelerator.is_main_process:
            pipeline = self.construct_pipeline(shared_modules, unwrap_unet)
            pipeline.save_pretrained(save_dir)
            self.logger.info(f"{self.cfg.trainer_name} Model saved at {save_dir}")

    def save_debug_info(self, save_name="debug", **kwargs):
        """Pickle ``kwargs`` to ``{save_name}.pkl`` when ``cfg.debug`` is set.

        Tensors are detached and moved to CPU first. If the target file
        already exists, the first free ``{save_name}_v{i}.pkl`` (i = 0, 1, ...)
        is used instead, so an earlier dump is never overwritten.
        """
        if not self.cfg.debug:
            return
        import pickle
        import os
        to_saves = {
            key: value.detach().cpu() if isinstance(value, torch.Tensor) else value
            for key, value in kwargs.items()
        }
        target = f"{save_name}.pkl"
        version = 0
        # Unbounded search: a fixed cap would fall through to overwriting the base file.
        while os.path.exists(target):
            target = f"{save_name}_v{version}.pkl"
            version += 1
        with open(target, "wb") as f:
            pickle.dump(to_saves, f)