# Copyright 2018 The TensorFlow Authors All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """TensorFlow utility functions. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function from copy import deepcopy import tensorflow as tf from tf_agents import specs from tf_agents.utils import common _tf_print_counts = dict() _tf_print_running_sums = dict() _tf_print_running_counts = dict() _tf_print_ids = 0 def get_contextual_env_base(env_base, begin_ops=None, end_ops=None): """Wrap env_base with additional tf ops.""" # pylint: disable=protected-access def init(self_, env_base): self_._env_base = env_base attribute_list = ["_render_mode", "_gym_env"] for attribute in attribute_list: if hasattr(env_base, attribute): setattr(self_, attribute, getattr(env_base, attribute)) if hasattr(env_base, "physics"): self_._physics = env_base.physics elif hasattr(env_base, "gym"): class Physics(object): def render(self, *args, **kwargs): return env_base.gym.render("rgb_array") physics = Physics() self_._physics = physics self_.physics = physics def set_sess(self_, sess): self_._sess = sess if hasattr(self_._env_base, "set_sess"): self_._env_base.set_sess(sess) def begin_episode(self_): self_._env_base.reset() if begin_ops is not None: self_._sess.run(begin_ops) def end_episode(self_): self_._env_base.reset() if end_ops is not None: self_._sess.run(end_ops) return type("ContextualEnvBase", (env_base.__class__,), dict( __init__=init, set_sess=set_sess, begin_episode=begin_episode, end_episode=end_episode, ))(env_base) # pylint: enable=protected-access def merge_specs(specs_): """Merge TensorSpecs. Args: specs_: List of TensorSpecs to be merged. Returns: a TensorSpec: a merged TensorSpec. """ shape = specs_[0].shape dtype = specs_[0].dtype name = specs_[0].name for spec in specs_[1:]: assert shape[1:] == spec.shape[1:], "incompatible shapes: %s, %s" % ( shape, spec.shape) assert dtype == spec.dtype, "incompatible dtypes: %s, %s" % ( dtype, spec.dtype) shape = merge_shapes((shape, spec.shape), axis=0) return specs.TensorSpec( shape=shape, dtype=dtype, name=name, ) def merge_shapes(shapes, axis=0): """Merge TensorShapes. Args: shapes: List of TensorShapes to be merged. axis: optional, the axis to merge shaped. Returns: a TensorShape: a merged TensorShape. """ assert len(shapes) > 1 dims = deepcopy(shapes[0].dims) for shape in shapes[1:]: assert shapes[0].ndims == shape.ndims dims[axis] += shape.dims[axis] return tf.TensorShape(dims=dims) def get_all_vars(ignore_scopes=None): """Get all tf variables in scope. Args: ignore_scopes: A list of scope names to ignore. Returns: A list of all tf variables in scope. """ all_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) all_vars = [var for var in all_vars if ignore_scopes is None or not any(var.name.startswith(scope) for scope in ignore_scopes)] return all_vars def clip(tensor, range_=None): """Return a tf op which clips tensor according to range_. Args: tensor: A Tensor to be clipped. range_: None, or a tuple representing (minval, maxval) Returns: A clipped Tensor. """ if range_ is None: return tf.identity(tensor) elif isinstance(range_, (tuple, list)): assert len(range_) == 2 return tf.clip_by_value(tensor, range_[0], range_[1]) else: raise NotImplementedError("Unacceptable range input: %r" % range_) def clip_to_bounds(value, minimum, maximum): """Clips value to be between minimum and maximum. Args: value: (tensor) value to be clipped. minimum: (numpy float array) minimum value to clip to. maximum: (numpy float array) maximum value to clip to. Returns: clipped_value: (tensor) `value` clipped to between `minimum` and `maximum`. """ value = tf.minimum(value, maximum) return tf.maximum(value, minimum) clip_to_spec = common.clip_to_spec def _clip_to_spec(value, spec): """Clips value to a given bounded tensor spec. Args: value: (tensor) value to be clipped. spec: (BoundedTensorSpec) spec containing min. and max. values for clipping. Returns: clipped_value: (tensor) `value` clipped to be compatible with `spec`. """ return clip_to_bounds(value, spec.minimum, spec.maximum) join_scope = common.join_scope def _join_scope(parent_scope, child_scope): """Joins a parent and child scope using `/`, checking for empty/none. Args: parent_scope: (string) parent/prefix scope. child_scope: (string) child/suffix scope. Returns: joined scope: (string) parent and child scopes joined by /. """ if not parent_scope: return child_scope if not child_scope: return parent_scope return '/'.join([parent_scope, child_scope]) def assign_vars(vars_, values): """Returns the update ops for assigning a list of vars. Args: vars_: A list of variables. values: A list of tensors representing new values. Returns: A list of update ops for the variables. """ return [var.assign(value) for var, value in zip(vars_, values)] def identity_vars(vars_): """Return the identity ops for a list of tensors. Args: vars_: A list of tensors. Returns: A list of identity ops. """ return [tf.identity(var) for var in vars_] def tile(var, batch_size=1): """Return tiled tensor. Args: var: A tensor representing the state. batch_size: Batch size. Returns: A tensor with shape [batch_size,] + var.shape. """ batch_var = tf.tile( tf.expand_dims(var, 0), (batch_size,) + (1,) * var.get_shape().ndims) return batch_var def batch_list(vars_list): """Batch a list of variables. Args: vars_list: A list of tensor variables. Returns: A list of tensor variables with additional first dimension. """ return [tf.expand_dims(var, 0) for var in vars_list] def tf_print(op, tensors, message="", first_n=-1, name=None, sub_messages=None, print_freq=-1, include_count=True): """tf.Print, but to stdout.""" # TODO(shanegu): `name` is deprecated. Remove from the rest of codes. global _tf_print_ids _tf_print_ids += 1 name = _tf_print_ids _tf_print_counts[name] = 0 if print_freq > 0: _tf_print_running_sums[name] = [0 for _ in tensors] _tf_print_running_counts[name] = 0 def print_message(*xs): """print message fn.""" _tf_print_counts[name] += 1 if print_freq > 0: for i, x in enumerate(xs): _tf_print_running_sums[name][i] += x _tf_print_running_counts[name] += 1 if (print_freq <= 0 or _tf_print_running_counts[name] >= print_freq) and ( first_n < 0 or _tf_print_counts[name] <= first_n): for i, x in enumerate(xs): if print_freq > 0: del x x = _tf_print_running_sums[name][i]/_tf_print_running_counts[name] if sub_messages is None: sub_message = str(i) else: sub_message = sub_messages[i] log_message = "%s, %s" % (message, sub_message) if include_count: log_message += ", count=%d" % _tf_print_counts[name] tf.logging.info("[%s]: %s" % (log_message, x)) if print_freq > 0: for i, x in enumerate(xs): _tf_print_running_sums[name][i] = 0 _tf_print_running_counts[name] = 0 return xs[0] print_op = tf.py_func(print_message, tensors, tensors[0].dtype) with tf.control_dependencies([print_op]): op = tf.identity(op) return op periodically = common.periodically def _periodically(body, period, name='periodically'): """Periodically performs a tensorflow op.""" if period is None or period == 0: return tf.no_op() if period < 0: raise ValueError("period cannot be less than 0.") if period == 1: return body() with tf.variable_scope(None, default_name=name): counter = tf.get_variable( "counter", shape=[], dtype=tf.int64, trainable=False, initializer=tf.constant_initializer(period, dtype=tf.int64)) def _wrapped_body(): with tf.control_dependencies([body()]): return counter.assign(1) update = tf.cond( tf.equal(counter, period), _wrapped_body, lambda: counter.assign_add(1)) return update soft_variables_update = common.soft_variables_update