import os
import hashlib
from datetime import timedelta

import torch as T
from tqdm import tqdm
from huggingface_hub import hf_hub_download

def rank0():
    """True on the global rank-0 process (or when RANK is unset, e.g. single-process runs)."""
    rank = os.environ.get('RANK')
    return rank is None or rank == '0'

def local0():
    """True on the node-local rank-0 process (or when LOCAL_RANK is unset)."""
    local_rank = os.environ.get('LOCAL_RANK')
    return local_rank is None or local_rank == '0'

class tqdm0(tqdm):
    """tqdm progress bar that only renders on rank 0 and redraws at most ~20 times per bar."""
    def __init__(self, *args, **kwargs):
        total = kwargs.get('total', None)
        if total is None and len(args) > 0:
            try:
                total = len(args[0])
            except TypeError:
                pass
        if total is not None:
            # Throttle redraws to roughly 20 updates over the whole bar.
            kwargs['miniters'] = max(1, total // 20)
        # setdefault lets callers override these if they need to.
        kwargs.setdefault('disable', not rank0())
        kwargs.setdefault('bar_format', '{bar}| {n_fmt}/{total_fmt} [{rate_fmt}{postfix}]')
        super().__init__(*args, **kwargs)
    
def print0(*args, **kwargs):
    """print() that is a no-op on every rank except global rank 0."""
    if rank0():
        print(*args, **kwargs)

_PRINTED_IDS = set()

def printonce(*args, id=None, **kwargs):
    """Print a message at most once per process, keyed by `id` (defaults to the message text)."""
    if id is None:
        id = ' '.join(map(str, args))
    if id not in _PRINTED_IDS:
        print(*args, **kwargs)
        _PRINTED_IDS.add(id)

def print0once(*args, **kwargs):
    """Print a message at most once per process, and only on global rank 0."""
    if rank0():
        printonce(*args, **kwargs)

def init_dist():
    """Initialize torch.distributed (NCCL) from torchrun env vars, falling back to a
    single-process setup when they are absent. Returns (rank, local_rank, world_size)."""
    if T.distributed.is_initialized():
        print0('Distributed already initialized')
        rank = T.distributed.get_rank()
        local_rank = int(os.environ.get('LOCAL_RANK', 0))
        world_size = T.distributed.get_world_size()
    else:
        try:
            rank = int(os.environ['RANK'])
            local_rank = int(os.environ['LOCAL_RANK'])
            world_size = int(os.environ['WORLD_SIZE'])
            device = f'cuda:{local_rank}'
            T.cuda.set_device(device)
            T.distributed.init_process_group(backend='nccl', timeout=timedelta(minutes=30),
                                             rank=rank, world_size=world_size, device_id=T.device(device))
            print(f'Rank {rank} of {world_size}.')
        except Exception as e:
            # RANK/LOCAL_RANK/WORLD_SIZE are unset when not launched via torchrun; run single-process.
            print0once(f'Not initializing distributed env: {e}')
            rank = 0
            local_rank = 0
            world_size = 1
    return rank, local_rank, world_size

def load_ckpt(load_from_location, expected_hash=None):
    """Download a hertz-dev checkpoint from the Hugging Face Hub (on local rank 0 only),
    optionally verify its md5, and load it onto CPU on every rank."""
    os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'  # Disable this when debugging download errors from the hub
    save_path = None  # Set on local rank 0; received via broadcast on other ranks
    if local0():
        repo_id = "si-pbc/hertz-dev"
        print0(f'Loading checkpoint from repo_id {repo_id} and filename {load_from_location}.pt. This may take a while...')
        save_path = hf_hub_download(repo_id=repo_id, filename=f"{load_from_location}.pt")
        print0(f'Downloaded checkpoint to {save_path}')
        if expected_hash is not None:
            with open(save_path, 'rb') as f:
                file_hash = hashlib.md5(f.read()).hexdigest()
            if file_hash != expected_hash:
                print(f'Hash mismatch for {save_path}. Expected {expected_hash} but got {file_hash}. Deleting checkpoint and trying again.')
                os.remove(save_path)
                return load_ckpt(load_from_location, expected_hash)
    if T.distributed.is_initialized():
        # Share the downloaded path from rank 0 so every rank loads the same file.
        save_path = [save_path]
        T.distributed.broadcast_object_list(save_path, src=0)
        save_path = save_path[0]
    loaded = T.load(save_path, weights_only=False, map_location='cpu')
    print0(f'Loaded checkpoint from {save_path}')
    return loaded
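
# Minimal usage sketch (not part of the original module): how these helpers typically fit together
# in a launch script. 'some_checkpoint' is an illustrative placeholder, not a name taken from the repo.
if __name__ == '__main__':
    rank, local_rank, world_size = init_dist()       # falls back to rank 0 / world_size 1 without torchrun
    print0(f'Running with world_size={world_size}')  # printed by rank 0 only
    # ckpt = load_ckpt('some_checkpoint')            # would download some_checkpoint.pt from si-pbc/hertz-dev
    for _ in tqdm0(range(1_000)):                    # progress bar rendered on rank 0 only
        pass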