Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import torch | |
import json | |
from pathlib import Path | |
import safetensors | |
import wandb | |
def create_folder_if_necessary(path):
    """Create the directory that will contain ``path`` if it is missing.

    Only the parent directory is created; nothing is created at ``path``
    itself. Calling this when the directory already exists is a no-op.

    Args:
        path: File path (``str`` or ``os.PathLike``) whose parent directory
            should exist. A path ending in a separator is treated as a
            directory with an empty file name, so that directory is created.
    """
    # os.path.dirname accepts path-like objects and understands the
    # platform's separators, unlike a manual split on "/".
    parent = os.path.dirname(path)
    # An empty parent (bare file name) resolves to the current directory,
    # which already exists, so exist_ok=True makes this a no-op.
    Path(parent).mkdir(parents=True, exist_ok=True)
def safe_save(ckpt, path):
    """Save ``ckpt`` to ``path``, keeping the previous file as ``<path>.bak``.

    The serializer is selected from the file extension:
    ``.pt`` / ``.ckpt`` use ``torch.save``, ``.json`` uses ``json.dump``
    (UTF-8, 4-space indent) and ``.safetensors`` uses
    ``safetensors.torch.save_file``.

    Args:
        ckpt: Object to serialize; it must be accepted by the serializer
            chosen for the extension.
        path: Destination file path as a string.

    Raises:
        ValueError: If the extension of ``path`` is not supported. This is
            checked before anything on disk is touched, so an existing file
            at ``path`` is left in place.
    """
    supported_extensions = (".pt", ".ckpt", ".json", ".safetensors")
    # Validate first: rotating the backup and only then rejecting the path
    # would move an existing file away without writing a replacement.
    if not path.endswith(supported_extensions):
        raise ValueError(f"File extension not supported: {path}")

    backup_path = f"{path}.bak"
    # Best-effort backup rotation: a missing backup or missing current file
    # is expected on the first save, so OS errors are deliberately ignored.
    try:
        os.remove(backup_path)
    except OSError:
        pass
    try:
        os.rename(path, backup_path)
    except OSError:
        pass

    if path.endswith((".pt", ".ckpt")):
        torch.save(ckpt, path)
    elif path.endswith(".json"):
        with open(path, "w", encoding="utf-8") as f:
            json.dump(ckpt, f, indent=4)
    else:
        # `import safetensors` alone does not load the `torch` submodule,
        # so import it explicitly where it is needed.
        from safetensors.torch import save_file

        save_file(ckpt, path)
def load_or_fail(path, wandb_run_id=None):
    """Load a checkpoint from ``path``, sending a wandb alert on failure.

    The loader is selected from the file extension:
    ``.pt`` / ``.ckpt`` use ``torch.load`` mapped to the CPU, ``.json`` uses
    ``json.load`` and ``.safetensors`` reads every tensor onto the CPU into
    a ``dict`` keyed by tensor name.

    Args:
        path: Checkpoint file path as a string.
        wandb_run_id: Optional run identifier. When given, ``wandb.alert``
            is called with this id before a loading error is re-raised.

    Returns:
        The loaded object, or ``None`` when ``path`` does not exist.

    Raises:
        AssertionError: If the extension of ``path`` is not supported.
        Exception: Any error raised by the underlying loader is re-raised
            unchanged after the optional alert.
    """
    accepted_extensions = (".pt", ".ckpt", ".json", ".safetensors")
    try:
        if not path.endswith(accepted_extensions):
            # Raised explicitly instead of via `assert` so the check is not
            # stripped under `python -O`; the exception type is kept for
            # callers that already handle AssertionError.
            raise AssertionError(
                f"Automatic loading not supported for this extension: {path}"
            )
        if not os.path.exists(path):
            # A missing checkpoint is not an error: callers get None.
            return None
        if path.endswith((".pt", ".ckpt")):
            # NOTE(security): torch.load can unpickle arbitrary objects
            # depending on the PyTorch version/defaults; only load
            # checkpoints from trusted sources.
            return torch.load(path, map_location="cpu")
        if path.endswith(".json"):
            with open(path, "r", encoding="utf-8") as f:
                return json.load(f)
        # Remaining accepted extension: .safetensors
        with safetensors.safe_open(path, framework="pt", device="cpu") as f:
            return {key: f.get_tensor(key) for key in f.keys()}
    except Exception:
        if wandb_run_id is not None:
            wandb.alert(
                title=f"Corrupt checkpoint for run {wandb_run_id}",
                text=f"Training {wandb_run_id} tried to load checkpoint {path} and failed",
            )
        # Bare raise re-raises the active exception with its traceback.
        raise