#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@File    :   summary.py
@Time    :   2022/10/15 23:38:13
@Author  :   BQH 
@Version :   1.0
@Contact :   raogx.vip@hotmail.com
@License :   (C)Copyright 2017-2018, Liugroup-NLPR-CASIA
@Desc    :   Runtime logging utilities
'''

# imports

import os
import sys
import torch
import logging
from datetime import datetime

# fall back to a fake SummaryWriter if tensorboardX is not installed

try:
    from tensorboardX import SummaryWriter
except ImportError:
    class SummaryWriter:
        def __init__(self, log_dir=None, comment='', **kwargs):
            print('\nunable to import tensorboardX, scalars will be saved with torch.save instead!\n')
            self.log_dir = log_dir if log_dir is not None else './logs'
            # make sure the directory that close() writes into exists
            os.makedirs(self.log_dir, exist_ok=True)
            self.logs = {'comment': comment}

        def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):
            self.logs.setdefault(tag, []).append((scalar_value, global_step, walltime))

        def close(self):
            # dump all recorded scalars into a timestamped file under log_dir
            timestamp = str(datetime.now()).replace(' ', '_').replace(':', '_')
            torch.save(self.logs, os.path.join(self.log_dir, 'log_%s.pickle' % timestamp))


# no-op writer returned on non-master ranks so callers can use add_scalar()/close() unconditionally
class EmptySummaryWriter:
    def __init__(self, **kwargs):
        pass

    def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):
        pass

    def close(self):
        pass


def create_summary(distributed_rank=0, **kwargs):
    if distributed_rank > 0:
        return EmptySummaryWriter(**kwargs)
    else:
        return SummaryWriter(**kwargs)
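
# Minimal usage sketch for the writer factory (the tag, value and step below are
# hypothetical examples, not values used elsewhere in this project):
#   writer = create_summary(distributed_rank=0)   # rank 0 gets a real (or fallback) writer
#   writer.add_scalar('train/loss', 0.5, global_step=100)
#   writer.close()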


def create_logger(distributed_rank=0, save_dir=None):
    logger = logging.getLogger('logger')
    logger.setLevel(logging.DEBUG)

    filename = "log_%s.txt" % (datetime.now().strftime("%Y_%m_%d_%H_%M_%S"))

    # don't log results for the non-master process
    if distributed_rank > 0:
        return logger
    ch = logging.StreamHandler(stream=sys.stdout)
    ch.setLevel(logging.DEBUG)
    # formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
    formatter = logging.Formatter("%(message)s [%(asctime)s]")
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    if save_dir is not None:
        fh = logging.FileHandler(os.path.join(save_dir, filename))
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(formatter)
        logger.addHandler(fh)

    return logger


class Saver:
    def __init__(self, distributed_rank, save_dir):
        self.distributed_rank = distributed_rank
        self.save_dir = save_dir
        os.makedirs(self.save_dir, exist_ok=True)
        return

    def save(self, obj, save_name):
        # only the master process writes checkpoints; other ranks return an empty message
        if self.distributed_rank == 0:
            save_path = os.path.join(self.save_dir, save_name + '.t7')
            torch.save(obj, save_path)
            return 'checkpoint saved in %s !' % save_path
        else:
            return ''


def create_saver(distributed_rank, save_dir):
    return Saver(distributed_rank, save_dir)


# context manager that silences print() output on non-master processes
class DisablePrint:
    def __init__(self, local_rank=0):
        self.local_rank = local_rank

    def __enter__(self):
        if self.local_rank != 0:
            self._original_stdout = sys.stdout
            sys.stdout = open(os.devnull, 'w')

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.local_rank != 0:
            sys.stdout.close()
            sys.stdout = self._original_stdout
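

# Minimal end-to-end sketch of how these helpers are typically wired together in a
# single-process run; the directories, tag and checkpoint name are illustrative assumptions.
if __name__ == '__main__':
    rank = 0  # in a distributed job this would come from the launcher

    os.makedirs('./logs', exist_ok=True)  # create_logger does not create save_dir itself
    logger = create_logger(distributed_rank=rank, save_dir='./logs')
    summary = create_summary(distributed_rank=rank)
    saver = create_saver(distributed_rank=rank, save_dir='./checkpoints')

    logger.info('dummy step, loss = 0.5')           # console + ./logs/log_*.txt
    summary.add_scalar('loss', 0.5, global_step=1)  # tensorboardX or the torch.save fallback
    logger.info(saver.save({'step': 1}, 'ckpt_1'))  # writes ./checkpoints/ckpt_1.t7
    summary.close()

    with DisablePrint(local_rank=rank):
        print('only the master process (local_rank 0) sees this')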