import contextlib
import os
import tempfile
from pathlib import Path

import torch
from composer.core import Callback, State
# NOTE: the FSDP state-dict helpers are imported from composer.core.state here,
# matching the Composer releases this callback targets; newer releases may
# relocate them.
from composer.core.state import (fsdp_get_optim_state_dict,
                                 fsdp_state_dict_type_context)
from composer.loggers import Logger, RemoteUploaderDownloader
from composer.utils import (dist, format_name_with_dist_and_time, parse_uri,
                            reproducibility)

class MonolithicCheckpointSaver(Callback):
    """Save a monolithic checkpoint every N batches.

    Args:
        save_folder (str): Folder to save checkpoints to (can be a URI)
        batch_interval (int): Number of batches between checkpoints.
        filename (str): Filename to save checkpoints to.
        overwrite (bool): Whether to overwrite previous checkpoints.
        keep_optimizers (bool): Whether to save the optimizer state in the monolithic checkpoint.
    """

    def __init__(self,
                 save_folder: str,
                 batch_interval: int,
                 filename: str = 'ep{epoch}-ba{batch}.pt',
                 overwrite: bool = False,
                 keep_optimizers: bool = False):
        self.backend, self.bucket_name, self.save_dir_format_str = parse_uri(save_folder)
        self.filename_format_str = filename
        self.batch_interval = batch_interval
        self.upload_to_object_store = self.backend != ''
        self.overwrite = overwrite
        self.keep_optimizers = keep_optimizers
        if self.upload_to_object_store:
            self.remote_ud = RemoteUploaderDownloader(bucket_uri=f'{self.backend}://{self.bucket_name}')
        else:
            self.remote_ud = None

    def init(self, state: State, logger: Logger) -> None:
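        """Initialize the remote uploader and register it with the trainer's callbacks."""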
        if self.upload_to_object_store and self.remote_ud is not None:
            self.remote_ud.init(state, logger)
            state.callbacks.append(self.remote_ud)

    def batch_checkpoint(self, state: State, logger: Logger) -> None:
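        """Save a checkpoint whenever the batch count reaches a multiple of ``batch_interval``."""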
        if state.timestamp.batch.value % self.batch_interval == 0:
            self._save_checkpoint(state, logger)

    def fit_end(self, state: State, logger: Logger) -> None:
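        """Save a final checkpoint at the end of training, unless one was just written."""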
        if state.timestamp.batch.value % self.batch_interval != 0:
            self._save_checkpoint(state, logger)

    def _save_checkpoint(self, state: State, logger: Logger) -> None:
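        """Gather a full (unsharded) state dict and write it from global rank 0."""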
        del logger
        filename = format_name_with_dist_and_time(self.filename_format_str, state.run_name, state.timestamp)
        save_dir = format_name_with_dist_and_time(self.save_dir_format_str, state.run_name, state.timestamp)
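        # Stage the checkpoint in a temporary directory when uploading to an
        # object store; otherwise write directly into the local save folder.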
        dir_context_mgr = tempfile.TemporaryDirectory() if self.upload_to_object_store else contextlib.nullcontext(enter_result=save_dir)
        with dir_context_mgr as temp_save_dir:
            assert isinstance(temp_save_dir, str)
            save_path = str(Path(temp_save_dir) / Path(filename))
            dirname = os.path.dirname(save_path)
            if dirname:
                os.makedirs(dirname, exist_ok=True)
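            # Drop the (possibly sharded) model and optimizer entries, then
            # re-insert a full model state dict gathered under the FSDP
            # 'full' state-dict context.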
            state_dict = {'state': state.state_dict(), 'rng': reproducibility.get_rng_state()}
            state_dict['state'].pop('optimizers')
            state_dict['state'].pop('model')
            with fsdp_state_dict_type_context(state.model, state_dict_type='full'):
                state_dict['state']['model'] = state.model.state_dict()
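            # Optionally store a full (gathered) optimizer state dict as well.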
            if self.keep_optimizers:
                optimizer = state.optimizers[0]
                state_dict['state']['optimizers'] = {type(optimizer).__qualname__: fsdp_get_optim_state_dict(state.model, optimizer, state_dict_type='full')}
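            # Only global rank 0 writes the checkpoint file (and uploads it).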
            if dist.get_global_rank() == 0:
                torch.save(state_dict, save_path)
            if self.upload_to_object_store and self.remote_ud is not None and (dist.get_global_rank() == 0):
                remote_file_name = str(Path(save_dir) / Path(filename))
                self.remote_ud.upload_file(state=state, remote_file_name=remote_file_name, file_path=Path(save_path), overwrite=self.overwrite)
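
# A minimal usage sketch (not part of the original file): attaching the callback
# to a Composer ``Trainer``. ``model`` and ``train_dataloader`` are assumed to be
# defined elsewhere, and the bucket URI is illustrative.
#
#   from composer import Trainer
#
#   ckpt_saver = MonolithicCheckpointSaver(
#       save_folder='s3://my-bucket/checkpoints/{run_name}',
#       batch_interval=1000,
#       overwrite=True,
#       keep_optimizers=True,
#   )
#   trainer = Trainer(
#       model=model,
#       train_dataloader=train_dataloader,
#       max_duration='1ep',
#       callbacks=[ckpt_saver],
#   )
#   trainer.fit()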