HugoVoxx's picture
Upload 20 files
15bcbe6 verified
raw
history blame
3.74 kB
# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper routines for recording various training metrics."""
from typing import Any
import jax.numpy as jnp
Array = Any
def compute_accuracy_sum(logits, targets, valid_loss_mask=None):
  """Count the positions where the argmax prediction matches the target.

  Args:
    logits: [batch, length, num_classes] float array.
    targets: categorical targets [batch, length] int array.
    valid_loss_mask: None or array of shape bool[batch, length].  When given,
      only positions where the mask is True can count as correct.

  Returns:
    The number of correct tokens in the output (a sum, not a mean).

  Raises:
    ValueError: if the leading dimensions of logits do not match the shape of
      targets, or if the mask shape does not match the shape of targets.
  """
  # The format arguments must be wrapped in a single tuple: without the
  # parentheses, `%` binds only to the first shape (itself a tuple), which
  # either raises TypeError or produces a garbled message.
  if logits.shape[:-1] != targets.shape:
    raise ValueError("Incorrect shapes. Got shape %s logits and %s targets" %
                     (logits.shape, targets.shape))
  if valid_loss_mask is not None and valid_loss_mask.shape != targets.shape:
    raise ValueError("Incorrect shapes. Got shape %s targets and %s mask" %
                     (targets.shape, valid_loss_mask.shape))

  # Boolean array: True wherever the highest-scoring class equals the target.
  accuracy = jnp.equal(jnp.argmax(logits, axis=-1), targets)
  if valid_loss_mask is not None:
    # Masked-out positions never count as correct.
    accuracy = jnp.logical_and(accuracy, valid_loss_mask)
  return jnp.sum(accuracy)  # Sum of the number of True values.
def reshape_image(image):
  """Add the extra axes tensorboard expects around a grayscale image.

  Args:
    image: Array of shape [xsize, ysize] or [num_images, xsize, ysize].

  Returns:
    A float32 array of shape [num_images, xsize, ysize, 1] (num_images is 1
    for a 2-D input), or None if the input has any other rank.
  """
  dims = tuple(image.shape)
  rank = image.ndim
  if rank not in (2, 3):
    return None  # Not an image.

  # A single 2-D image gets a leading batch axis of size 1; every input gets
  # a trailing channel axis of size 1.
  batch_prefix = (1,) if rank == 2 else ()
  target_shape = batch_prefix + dims + (1,)
  return jnp.reshape(image, target_shape).astype(jnp.float32)
def normalize_image(images: Array, as_group: bool = False) -> Array:
  """Linearly rescale image values into the range 0.0 to 1.0.

  Args:
    images: Array of size [batch_size, xsize, ysize].
    as_group: If True, one min/max pair is computed over the whole batch so
      every image is scaled by the same amount.  If False, each image is
      scaled by its own min/max, taken over the last two axes.

  Returns:
    A float32 array of the same shape, with non-finite results replaced by 0.
  """
  as_float = images.astype(jnp.float32)  # Output is always float32.

  # Global reduction for the group case; per-image reduction (kept
  # broadcastable against the input) otherwise.
  reduce_args = {} if as_group else {"axis": (-2, -1), "keepdims": True}
  lowest = jnp.min(as_float, **reduce_args)
  highest = jnp.max(as_float, **reduce_args)

  # The small epsilon keeps the denominator non-zero for constant images.
  scaled = (as_float - lowest) / (highest - lowest + 1e-6)
  return jnp.where(jnp.isfinite(scaled), scaled, 0.0)
def overlay_images(image1: Array, image2: Array) -> Array:
  """Stack image1 and image2 along the xsize axis, image1 first.

  Args:
    image1: array of shape [num_images, xsize, ysize].
    image2: rank-3 array that is broadcastable to the shape of image1.

  Returns:
    A combined array of shape [num_images, 2 * xsize, ysize].
  """
  for candidate in (image1, image2):
    assert candidate.ndim == 3  # (num_images, xsize, ysize)

  # Expand image2 to the full shape of image1 so the two can be joined.
  expanded = jnp.broadcast_to(image2, image1.shape)
  return jnp.concatenate((image1, expanded), axis=1)
def make_histograms(viz_dicts):
  """Flatten a sequence of visualization dicts into one histogram dict.

  Args:
    viz_dicts: iterable of dicts mapping string names to image arrays.

  Returns:
    A single dict in which the entry `name` from the i-th input dict is
    stored under the key "h_<name>_<i>".
  """
  return {
      "_".join(("h", name, str(index))): images
      for index, viz_dict in enumerate(viz_dicts)
      for name, images in viz_dict.items()
  }