# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper routines for recording various training metrics."""

from typing import Any

import jax.numpy as jnp

Array = Any


def compute_accuracy_sum(logits, targets, valid_loss_mask=None):
  """Count the positions where the argmax of the logits equals the target.

  Args:
    logits: [batch, length, num_classes] float array.
    targets: categorical targets [batch, length] int array.
    valid_loss_mask: None or array of shape bool[batch, length].  When given,
      only positions where the mask is True are counted.

  Returns:
    The number of correct tokens in the output.

  Raises:
    ValueError: if the leading dimensions of logits do not match the shape of
      targets, or if the mask shape does not match the shape of targets.
  """
  # NOTE: shapes are tuples, so they must be wrapped in an outer tuple when
  # used as %-format arguments; otherwise the shape itself is unpacked as the
  # argument list and formatting fails with a TypeError.
  if logits.shape[:-1] != targets.shape:
    raise ValueError("Incorrect shapes. Got shape %s logits and %s targets" %
                     (logits.shape, targets.shape))
  if valid_loss_mask is not None and valid_loss_mask.shape != targets.shape:
    raise ValueError("Incorrect shapes. Got shape %s targets and %s mask" %
                     (targets.shape, valid_loss_mask.shape))

  accuracy = jnp.equal(jnp.argmax(logits, axis=-1), targets)
  if valid_loss_mask is not None:
    # Ignore any position that is masked out.
    accuracy = jnp.logical_and(accuracy, valid_loss_mask)
  return jnp.sum(accuracy)  # Sum of the number of True values.


def reshape_image(image):
  """Reshape image to something that tensorboard recognizes.

  Args:
    image: Array of shape [xsize, ysize] or [num_images, xsize, ysize]

  Returns:
    Float32 array of shape [num_images, xsize, ysize, 1], or None if the
    input is neither 2- nor 3-dimensional (i.e. not an image).
  """
  # Reshape to [num_images, xdim, ydim, rgb] for tensorboard.
  sh = image.shape
  if image.ndim == 2:
    # A single image: add both a leading batch and a trailing channel axis.
    return jnp.reshape(image, [1, sh[0], sh[1], 1]).astype(jnp.float32)
  elif image.ndim == 3:
    # A batch of images: only the trailing channel axis is missing.
    return jnp.reshape(image, [sh[0], sh[1], sh[2], 1]).astype(jnp.float32)
  else:
    return None  # Not an image.


def normalize_image(images: Array, as_group: bool = False) -> Array:
  """Rescale the values in images to between 0.0 and 1.0.

  Args:
    images: Array of size [batch_size, xsize, ysize]
    as_group: Scale all images in the batch by the same amount if True.

  Returns:
    A rescaled float32 image of the same shape.  Non-finite results are
    replaced by 0.0.
  """
  images = images.astype(jnp.float32)  # Return images as float32.
  if as_group:
    # Normalize the batch of images as a group.
    min_img = jnp.min(images)
    max_img = jnp.max(images)
  else:
    # Normalize each image in the batch individually.
    min_img = jnp.min(images, axis=(-2, -1), keepdims=True)
    max_img = jnp.max(images, axis=(-2, -1), keepdims=True)
  # The small epsilon keeps the denominator non-zero for constant images.
  norm_image = (images - min_img) / (max_img - min_img + 1e-6)
  # Inputs containing NaN or inf produce non-finite outputs; zero them out.
  return jnp.where(jnp.isfinite(norm_image), norm_image, 0.0)


def overlay_images(image1: Array, image2: Array) -> Array:
  """Place image1 on top of image2, broadcasting image2 if necessary.

  Args:
    image1: array of shape [num_images, xsize, ysize]
    image2: 3-dimensional array that is broadcastable to the shape of image1,
      e.g. [num_images, xsize, ysize] or [1, xsize, ysize].

  Returns:
    A combined image, formed by concatenating image1 and the broadcast image2
    along axis 1.
  """
  assert image1.ndim == 3  # (num_images, xsize, ysize)
  assert image2.ndim == 3
  image2 = jnp.broadcast_to(image2, image1.shape)
  return jnp.concatenate([image1, image2], axis=1)


def make_histograms(viz_dicts):
  """Generate image histograms.

  Args:
    viz_dicts: A sequence of dictionaries which map string names to values.

  Returns:
    A single flat dictionary in which the entry named k of the i-th input
    dictionary is stored under the key "h_<k>_<i>".
  """
  return {
      "h_" + k + "_" + str(i): images
      for (i, viz_dict) in enumerate(viz_dicts)
      for (k, images) in viz_dict.items()
  }