File size: 3,569 Bytes
be4d0c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import seaborn as sns
import tensorflow as tf
import matplotlib.pyplot as plt

def plot_loss(history, axis=None):
    """
    Plot the training loss per epoch, plus the validation loss if recorded.

    Parameters
    ----------
        history : 'tf.keras.callbacks.History' object
            Its ``history`` dict must contain a ``"loss"`` entry. The
            ``"val_loss"`` curve is drawn only when that key is present.
        axis : 'matplotlib.axes.Axes' object, optional
            Axes to draw on. When None, the current pyplot axes
            (``plt.gca()``) are used.
    """
    # plt.plot / plt.title / plt.legend all operate on plt.gca(), so resolving
    # the target axes once lets both cases share a single code path.
    ax = axis if axis is not None else plt.gca()
    ax.plot(history.epoch, history.history["loss"],
            label="Train loss", color="#191970")
    # "val_loss" is typically only recorded when fit() was given validation
    # data; skip the curve instead of raising KeyError.
    if "val_loss" in history.history:
        ax.plot(history.epoch, history.history["val_loss"],
                label="Val loss", color="#00CC33")
    ax.set_title("Loss")
    ax.legend()

    
def plot_accuracy(history, axis=None):
    """
    Plot the training accuracy per epoch, plus the validation accuracy if recorded.

    The y-axis is fixed to the range [0, 1.1].

    Parameters
    ----------
        history : 'tf.keras.callbacks.History' object
            Its ``history`` dict must contain an ``"accuracy"`` entry. The
            ``"val_accuracy"`` curve is drawn only when that key is present.
        axis : 'matplotlib.axes.Axes' object, optional
            Axes to draw on. When None, the current pyplot axes
            (``plt.gca()``) are used.
    """
    # plt.plot / plt.title / plt.ylim / plt.legend all operate on plt.gca(),
    # so resolving the target axes once lets both cases share one code path.
    ax = axis if axis is not None else plt.gca()
    ax.plot(history.epoch, history.history["accuracy"],
            label="Train accuracy", color="#191970")
    # "val_accuracy" is typically only recorded when fit() was given
    # validation data; skip the curve instead of raising KeyError.
    if "val_accuracy" in history.history:
        ax.plot(history.epoch, history.history["val_accuracy"],
                label="Val accuracy", color="#00CC33")
    ax.set_ylim(0, 1.1)
    ax.set_title("Accuracy")
    ax.legend()
    
    
def _keras_activation_memory_in_bytes(model, *, batch_size: int):
    """Return the estimated bytes of layer outputs of `model`, recursing into nested models.

    Weights are deliberately excluded; the public entry point counts them once.
    """
    default_dtype = tf.keras.backend.floatx()
    per_sample_bytes = 0
    nested_bytes = 0
    for layer in model.layers:
        if isinstance(layer, tf.keras.Model):
            # Only the sub-model's internal activations are added here; its
            # weights already appear in the outer model's weight lists.
            nested_bytes += _keras_activation_memory_in_bytes(
                layer, batch_size=batch_size)
        itemsize = tf.as_dtype(layer.dtype or default_dtype).size
        out_shape = layer.output_shape
        # Multi-output layers (and InputLayer) report a list of shape tuples;
        # count every output, not just the first.
        shapes = out_shape if isinstance(out_shape, list) else [out_shape]
        for shape in shapes:
            output_bytes = itemsize
            for dim in shape:
                # Unknown (None) dims, normally the batch axis, are covered by
                # the batch_size multiplier below.
                if dim is not None:
                    output_bytes *= dim
            per_sample_bytes += output_bytes
    return batch_size * per_sample_bytes + nested_bytes


def keras_model_memory_usage_in_bytes(model, *, batch_size: int):
    """
    Return the estimated memory usage of a given Keras model in bytes.
    This includes the layer outputs and the model weights, but excludes the
    dataset.

    Layer output sizes are multiplied by the batch size; the weights are not.

    Parameters
    ----------
        model: A Keras model.
        batch_size: The batch size you intend to run the model with. If you
            have already specified the batch size in the model itself, then
            pass `1` as the argument here.

    Returns
    -------
        An estimate of the Keras model's memory usage in bytes.

    """
    activation_bytes = _keras_activation_memory_in_bytes(
        model, batch_size=batch_size)

    # The weight lists of the outer model already include the weights of any
    # nested sub-models, so they are counted exactly once here. Each parameter
    # count is scaled by its dtype's byte size so the result is in bytes.
    trainable_bytes = sum(
        tf.keras.backend.count_params(w) * tf.as_dtype(w.dtype).size
        for w in model.trainable_weights)
    non_trainable_bytes = sum(
        tf.keras.backend.count_params(w) * tf.as_dtype(w.dtype).size
        for w in model.non_trainable_weights)

    return activation_bytes + trainable_bytes + non_trainable_bytes