File size: 1,942 Bytes
5231633
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
import torch
import json
from pathlib import Path
import safetensors
import wandb


def create_folder_if_necessary(path):
    """Create the directory that will contain ``path`` if it is missing.

    Everything up to the last path separator is treated as the directory;
    the final component is assumed to be a file name and is never created.
    A path that ends in a separator (``"runs/exp1/"``) therefore creates the
    whole directory, and a bare file name (``"model.pt"``) resolves to the
    current working directory, which makes the call a no-op.

    Args:
        path: Target file path, as a ``str`` or any ``os.PathLike`` object.
    """
    # os.path.dirname understands the platform's separators and, unlike
    # Path.parent, keeps the "trailing separator means directory" rule.
    parent = os.path.dirname(os.fspath(path))
    Path(parent).mkdir(parents=True, exist_ok=True)


def safe_save(ckpt, path):
    """Write ``ckpt`` to ``path``, keeping the previous file as ``<path>.bak``.

    The serializer is chosen from the file extension:

    * ``.pt`` / ``.ckpt`` -- ``torch.save``
    * ``.json`` -- ``json.dump`` with 4-space indentation, UTF-8 encoded
    * ``.safetensors`` -- ``safetensors.torch.save_file``

    Args:
        ckpt: Object to serialize; it must be acceptable to the serializer
            selected by the extension of ``path``.
        path: Destination file, as a ``str`` or any ``os.PathLike`` object.

    Raises:
        ValueError: If the extension is not supported. This is checked
            before any file is touched, so an existing file and its backup
            are left exactly as they were.
    """
    path = os.fspath(path)
    if not path.endswith((".pt", ".ckpt", ".json", ".safetensors")):
        raise ValueError(f"File extension not supported: {path}")

    # Rotate the current file into the backup slot. os.replace overwrites an
    # older backup in one step, and when there is no current file the
    # existing backup is kept rather than deleted. Rotation is best-effort:
    # a failure here (typically "no previous file") must not block the save.
    try:
        os.replace(path, f"{path}.bak")
    except OSError:
        pass

    if path.endswith((".pt", ".ckpt")):
        torch.save(ckpt, path)
    elif path.endswith(".json"):
        with open(path, "w", encoding="utf-8") as f:
            json.dump(ckpt, f, indent=4)
    else:
        # A bare `import safetensors` does not load the `torch` submodule,
        # so `safetensors.torch` may not exist as an attribute; import the
        # function explicitly where it is needed.
        from safetensors.torch import save_file

        save_file(ckpt, path)


def _load_checkpoint(path):
    """Read ``path`` with the loader for its extension; ``None`` if absent."""
    accepted_extensions = (".pt", ".ckpt", ".json", ".safetensors")
    if not path.endswith(accepted_extensions):
        # Raised explicitly instead of via `assert` so the check survives
        # `python -O`; AssertionError is kept so callers see the same type.
        raise AssertionError(
            f"Automatic loading not supported for this extension: {path}"
        )
    if not os.path.exists(path):
        return None
    if path.endswith((".pt", ".ckpt")):
        # NOTE: torch.load may unpickle arbitrary objects; only point this at
        # checkpoints from a trusted source.
        return torch.load(path, map_location="cpu")
    if path.endswith(".json"):
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    # Only ".safetensors" is left: read every tensor onto the CPU.
    with safetensors.safe_open(path, framework="pt", device="cpu") as f:
        return {key: f.get_tensor(key) for key in f.keys()}


def load_or_fail(path, wandb_run_id=None):
    """Load a checkpoint from ``path``, alerting through wandb on failure.

    The loader is chosen from the file extension: ``.pt`` / ``.ckpt`` use
    ``torch.load`` mapped to the CPU, ``.json`` uses ``json.load`` and
    ``.safetensors`` is read tensor by tensor into a dict.

    Args:
        path: Checkpoint file, as a ``str`` or any ``os.PathLike`` object.
        wandb_run_id: When not ``None``, a ``wandb.alert`` naming this run
            and ``path`` is sent before a loading error is re-raised.

    Returns:
        The loaded object, or ``None`` when ``path`` has a supported
        extension but does not exist.

    Raises:
        AssertionError: If the extension of ``path`` is not supported.
        Exception: Any error raised by the underlying loader is re-raised
            unchanged after the optional alert.
    """
    try:
        return _load_checkpoint(os.fspath(path))
    except Exception:
        if wandb_run_id is not None:
            wandb.alert(
                title=f"Corrupt checkpoint for run {wandb_run_id}",
                text=f"Training {wandb_run_id} tried to load checkpoint {path} and failed",
            )
        # Bare raise keeps the original exception and traceback intact.
        raise