diff --git "a/symbolic_shape_infer.py" "b/symbolic_shape_infer.py" new file mode 100644--- /dev/null +++ "b/symbolic_shape_infer.py" @@ -0,0 +1,2991 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +# -*- coding: UTF-8 -*- +import argparse +import logging + +import numpy as np +import onnx +import sympy +from onnx import helper, numpy_helper, shape_inference +from packaging import version + +assert version.parse(onnx.__version__) >= version.parse("1.8.0") + +logger = logging.getLogger(__name__) + + +def get_attribute(node, attr_name, default_value=None): + found = [attr for attr in node.attribute if attr.name == attr_name] + if found: + return helper.get_attribute_value(found[0]) + return default_value + + +def get_dim_from_proto(dim): + return getattr(dim, dim.WhichOneof("value")) if type(dim.WhichOneof("value")) is str else None # noqa: E721 + + +def is_sequence(type_proto): + cls_type = type_proto.WhichOneof("value") + assert cls_type in ["tensor_type", "sequence_type"] + return cls_type == "sequence_type" + + +def get_shape_from_type_proto(type_proto): + assert not is_sequence(type_proto) + if type_proto.tensor_type.HasField("shape"): + return [get_dim_from_proto(d) for d in type_proto.tensor_type.shape.dim] + else: + return None # note no shape is different from shape without dim (scalar) + + +def get_elem_type_from_type_proto(type_proto): + if is_sequence(type_proto): + return type_proto.sequence_type.elem_type.tensor_type.elem_type + else: + return type_proto.tensor_type.elem_type + + +def get_shape_from_value_info(vi): + cls_type = vi.type.WhichOneof("value") + if cls_type is None: + return None + if is_sequence(vi.type): + if vi.type.sequence_type.elem_type.WhichOneof("value") == "tensor_type": + return get_shape_from_type_proto(vi.type.sequence_type.elem_type) + else: + return None + else: + return get_shape_from_type_proto(vi.type) + + +def make_named_value_info(name): + vi = onnx.ValueInfoProto() + vi.name = name + return vi + + +def get_shape_from_sympy_shape(sympy_shape): + return [None if i is None else (int(i) if is_literal(i) else str(i)) for i in sympy_shape] + + +def is_literal(dim): + return type(dim) in [int, np.int64, np.int32, sympy.Integer] or (hasattr(dim, "is_number") and dim.is_number) + + +def handle_negative_axis(axis, rank): + assert axis < rank and axis >= -rank + return axis if axis >= 0 else rank + axis + + +def get_opset(mp, domain=None): + domain = domain or ["", "onnx", "ai.onnx"] + if type(domain) != list: # noqa: E721 + domain = [domain] + for opset in mp.opset_import: + if opset.domain in domain: + return opset.version + + return None + + +def as_scalar(x): + if type(x) == list: # noqa: E721 + assert len(x) == 1 + return x[0] + elif type(x) == np.ndarray: + return x.item() + else: + return x + + +def as_list(x, keep_none): + if type(x) == list: # noqa: E721 + return x + elif type(x) == np.ndarray: + return list(x) + elif keep_none and x is None: + return None + else: + return [x] + + +def sympy_reduce_product(x): + if type(x) == list: # noqa: E721 + value = sympy.Integer(1) + for v in x: + value = value * v + else: + value = x + return value + + +class SymbolicShapeInference: + def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): + self.dispatcher_ = { + "Add": self._infer_symbolic_compute_ops, + "ArrayFeatureExtractor": self._infer_ArrayFeatureExtractor, + "AveragePool": self._infer_Pool, + "BatchNormalization": self._infer_BatchNormalization, + "Cast": self._infer_Cast, + 
"CategoryMapper": self._infer_CategoryMapper, + "Compress": self._infer_Compress, + "Concat": self._infer_Concat, + "ConcatFromSequence": self._infer_ConcatFromSequence, + "Constant": self._infer_Constant, + "ConstantOfShape": self._infer_ConstantOfShape, + "Conv": self._infer_Conv, + "CumSum": self._pass_on_shape_and_type, + "Div": self._infer_symbolic_compute_ops, + "Einsum": self._infer_Einsum, + "Expand": self._infer_Expand, + "Equal": self._infer_symbolic_compute_ops, + "Floor": self._infer_symbolic_compute_ops, + "Gather": self._infer_Gather, + "GatherElements": self._infer_GatherElements, + "GatherND": self._infer_GatherND, + "Identity": self._pass_on_shape_and_type, + "AllReduce": self._pass_on_shape_and_type, + "If": self._infer_If, + "Loop": self._infer_Loop, + "MatMul": self._infer_MatMul, + "MatMulInteger16": self._infer_MatMulInteger, + "MaxPool": self._infer_Pool, + "Max": self._infer_symbolic_compute_ops, + "MemcpyFromHost": self._pass_on_shape_and_type, + "MemcpyToHost": self._pass_on_shape_and_type, + "Min": self._infer_symbolic_compute_ops, + "MoE": self._pass_on_shape_and_type, + "Mul": self._infer_symbolic_compute_ops, + "NonMaxSuppression": self._infer_NonMaxSuppression, + "NonZero": self._infer_NonZero, + "OneHot": self._infer_OneHot, + "Pad": self._infer_Pad, + "Range": self._infer_Range, + "Reciprocal": self._pass_on_shape_and_type, + "ReduceSum": self._infer_ReduceSum, + "ReduceProd": self._infer_ReduceProd, + "Reshape": self._infer_Reshape, + "Resize": self._infer_Resize, + "Round": self._pass_on_shape_and_type, + "Scan": self._infer_Scan, + "ScatterElements": self._infer_ScatterElements, + "SequenceAt": self._infer_SequenceAt, + "SequenceInsert": self._infer_SequenceInsert, + "Shape": self._infer_Shape, + "Size": self._infer_Size, + "Slice": self._infer_Slice, + "SoftmaxCrossEntropyLoss": self._infer_SoftmaxCrossEntropyLoss, + "SoftmaxCrossEntropyLossInternal": self._infer_SoftmaxCrossEntropyLoss, + "NegativeLogLikelihoodLossInternal": self._infer_SoftmaxCrossEntropyLoss, + "Split": self._infer_Split, + "SplitToSequence": self._infer_SplitToSequence, + "Squeeze": self._infer_Squeeze, + "Sub": self._infer_symbolic_compute_ops, + "Tile": self._infer_Tile, + "TopK": self._infer_TopK, + "Transpose": self._infer_Transpose, + "Unsqueeze": self._infer_Unsqueeze, + "Where": self._infer_symbolic_compute_ops, + "ZipMap": self._infer_ZipMap, + "Neg": self._infer_symbolic_compute_ops, + # contrib ops: + "Attention": self._infer_Attention, + "BiasAdd": self._infer_BiasAdd, + "BiasGelu": self._infer_BiasGelu, + "BiasSplitGelu": self._infer_BiasSplitGelu, + "DecoderMaskedMultiHeadAttention": self._infer_DecoderMaskedMultiHeadAttention, + "DequantizeLinear": self._infer_DequantizeLinear, + "EmbedLayerNormalization": self._infer_EmbedLayerNormalization, + "FastGelu": self._infer_FastGelu, + "GatedRelativePositionBias": self._infer_GatedRelativePositionBias, + "Gelu": self._infer_Gelu, + "GemmFastGelu": self._infer_GemmFastGelu, + "GemmFloat8": self._infer_GemmFloat8, + "GroupNorm": self._infer_GroupNorm, + "GroupQueryAttention": self._infer_GroupQueryAttention, + "SkipGroupNorm": self._infer_SkipGroupNorm, + "LayerNormalization": self._infer_LayerNormalization, + "LongformerAttention": self._infer_LongformerAttention, + "MultiHeadAttention": self._infer_MultiHeadAttention, + "NhwcConv": self._infer_NhwcConv, + "PackedAttention": self._infer_PackedAttention, + "PackedMultiHeadAttention": self._infer_PackedMultiHeadAttention, + "PagedAttention": self._infer_PagedAttention, + 
"PythonOp": self._infer_PythonOp, + "QuantizeLinear": self._infer_QuantizeLinear, + "QuickGelu": self._infer_FastGelu, + "RelativePositionBias": self._infer_RelativePositionBias, + "RemovePadding": self._infer_RemovePadding, + "RestorePadding": self._infer_RestorePadding, + "RotaryEmbedding": self._infer_RotaryEmbedding, + "SimplifiedLayerNormalization": self._infer_LayerNormalization, + "SkipLayerNormalization": self._infer_SkipLayerNormalization, + "SkipSimplifiedLayerNormalization": self._infer_SkipLayerNormalization, + } + self.aten_op_dispatcher_ = { + "embedding": self._infer_Gather, + "bitwise_or": self._infer_aten_bitwise_or, + "diagonal": self._infer_aten_diagonal, + "max_pool2d_with_indices": self._infer_aten_pool2d, + "max": self._infer_aten_minmax, + "min": self._infer_aten_minmax, + "multinomial": self._infer_aten_multinomial, + "unfold": self._infer_aten_unfold, + "argmax": self._infer_aten_argmax, + "avg_pool2d": self._infer_aten_pool2d, + "_adaptive_avg_pool2d": self._infer_aten_pool2d, + "numpy_T": self._infer_Transpose, + "native_group_norm": self._infer_aten_group_norm, + "upsample_nearest1d": self._infer_aten_upsample, + "upsample_nearest2d": self._infer_aten_upsample, + "upsample_nearest3d": self._infer_aten_upsample, + "upsample_bicubic2d": self._infer_aten_upsample, + } + self.run_ = True + self.suggested_merge_ = {} + self.symbolic_dims_ = {} + self.input_symbols_ = {} + self.auto_merge_ = auto_merge + self.guess_output_rank_ = guess_output_rank + self.verbose_ = verbose + self.int_max_ = int_max + self.subgraph_id_ = 0 + self.prefix_ = prefix + + def _add_suggested_merge(self, symbols, apply=False): + assert all([(type(s) == str and s in self.symbolic_dims_) or is_literal(s) for s in symbols]) # noqa: E721 + symbols = set(symbols) + for k, v in self.suggested_merge_.items(): + if k in symbols: + symbols.remove(k) + symbols.add(v) + map_to = None + # if there is literal, map to it first + for s in symbols: + if is_literal(s): + map_to = s + break + # when no literals, map to input symbolic dims, then existing symbolic dims + if map_to is None: + for s in symbols: + if s in self.input_symbols_: + map_to = s + break + if map_to is None: + for s in symbols: + if type(self.symbolic_dims_[s]) == sympy.Symbol: + map_to = s + break + # when nothing to map to, use the shorter one + if map_to is None: + if self.verbose_ > 0: + logger.warning("Potential unsafe merge between symbolic expressions: ({})".format(",".join(symbols))) + symbols_list = list(symbols) + lens = [len(s) for s in symbols_list] + map_to = symbols_list[lens.index(min(lens))] + symbols.remove(map_to) + + for s in symbols: + if s == map_to: + continue + if is_literal(map_to) and is_literal(s): + assert int(map_to) == int(s) + self.suggested_merge_[s] = int(map_to) if is_literal(map_to) else map_to + for k, v in self.suggested_merge_.items(): + if v == s: + self.suggested_merge_[k] = map_to + if apply and self.auto_merge_: + self._apply_suggested_merge() + + def _apply_suggested_merge(self, graph_input_only=False): + if not self.suggested_merge_: + return + for i in list(self.out_mp_.graph.input) + ([] if graph_input_only else list(self.out_mp_.graph.value_info)): + for d in i.type.tensor_type.shape.dim: + if d.dim_param in self.suggested_merge_: + v = self.suggested_merge_[d.dim_param] + if is_literal(v): + d.dim_value = int(v) + else: + d.dim_param = v + + def _preprocess(self, in_mp): + self.out_mp_ = onnx.ModelProto() + self.out_mp_.CopyFrom(in_mp) + self.graph_inputs_ = {i.name: i for i in 
+
+    def _preprocess(self, in_mp):
+        self.out_mp_ = onnx.ModelProto()
+        self.out_mp_.CopyFrom(in_mp)
+        self.graph_inputs_ = {i.name: i for i in list(self.out_mp_.graph.input)}
+        self.initializers_ = {i.name: i for i in self.out_mp_.graph.initializer}
+        self.known_vi_ = {i.name: i for i in list(self.out_mp_.graph.input)}
+        self.known_vi_.update(
+            {
+                i.name: helper.make_tensor_value_info(i.name, i.data_type, list(i.dims))
+                for i in self.out_mp_.graph.initializer
+            }
+        )
+
+    def _merge_symbols(self, dims):
+        if not all([type(d) == str for d in dims]):  # noqa: E721
+            if self.auto_merge_:
+                unique_dims = list(set(dims))
+                is_int = [is_literal(d) for d in unique_dims]
+                assert sum(is_int) <= 1  # if there is more than one unique int, something is wrong
+                if sum(is_int) == 1:
+                    int_dim = is_int.index(1)
+                    if self.verbose_ > 0:
+                        logger.debug(
+                            "dim {} has been merged with value {}".format(
+                                unique_dims[:int_dim] + unique_dims[int_dim + 1 :],
+                                unique_dims[int_dim],
+                            )
+                        )
+                    self._check_merged_dims(unique_dims, allow_broadcast=False)
+                    return unique_dims[int_dim]
+                else:
+                    if self.verbose_ > 0:
+                        logger.debug(f"dim {unique_dims[1:]} has been merged with dim {unique_dims[0]}")
+                    return dims[0]
+            else:
+                return None
+        if all([d == dims[0] for d in dims]):
+            return dims[0]
+        merged = [self.suggested_merge_.get(d, d) for d in dims]
+        if all([d == merged[0] for d in merged]):
+            assert merged[0] in self.symbolic_dims_
+            return merged[0]
+        else:
+            return None
+
+    # broadcast from right to left, and merge symbolic dims if needed
+    def _broadcast_shapes(self, shape1, shape2):
+        new_shape = []
+        rank1 = len(shape1)
+        rank2 = len(shape2)
+        new_rank = max(rank1, rank2)
+        for i in range(new_rank):
+            dim1 = shape1[rank1 - 1 - i] if i < rank1 else 1
+            dim2 = shape2[rank2 - 1 - i] if i < rank2 else 1
+            if dim1 == 1 or dim1 == dim2:
+                new_dim = dim2
+            elif dim2 == 1:
+                new_dim = dim1
+            else:
+                new_dim = self._merge_symbols([dim1, dim2])
+                if not new_dim:
+                    # warn about unsupported broadcast when not auto-merging
+                    # note that auto merge has the risk of incorrectly merging symbols while one of them is 1
+                    # for example, 'a' = 1, 'b' = 5 at runtime is valid broadcasting, but with auto merge 'a' == 'b'
+                    if self.auto_merge_:
+                        self._add_suggested_merge([dim1, dim2], apply=True)
+                    else:
+                        logger.warning("unsupported broadcast between " + str(dim1) + " " + str(dim2))
+            new_shape = [new_dim, *new_shape]
+        return new_shape
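+
+    # Hand-worked examples for _broadcast_shapes (numpy-style, aligned from the right):
+    #   ["b", 1, 3] vs [5, 3]  -> ["b", 5, 3]
+    #   ["a"] vs ["b"]         -> the merged symbol, or a suggested merge when auto_merge_ is set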
+
+    def _get_shape(self, node, idx):
+        name = node.input[idx]
+        if name in self.known_vi_:
+            vi = self.known_vi_[name]
+            return get_shape_from_value_info(vi)
+        else:
+            assert name in self.initializers_
+            return list(self.initializers_[name].dims)
+
+    def _try_get_shape(self, node, idx):
+        if idx > len(node.input) - 1:
+            return None
+        name = node.input[idx]
+        if name in self.known_vi_:
+            vi = self.known_vi_[name]
+            return get_shape_from_value_info(vi)
+        if name in self.initializers_:
+            return list(self.initializers_[name].dims)
+        return None
+
+    def _get_shape_rank(self, node, idx):
+        return len(self._get_shape(node, idx))
+
+    def _get_sympy_shape(self, node, idx):
+        sympy_shape = []
+        for d in self._get_shape(node, idx):
+            if type(d) == str:  # noqa: E721
+                sympy_shape.append(
+                    self.symbolic_dims_[d]
+                    if d in self.symbolic_dims_
+                    else sympy.Symbol(d, integer=True, nonnegative=True)
+                )
+            else:
+                assert None is not d
+                sympy_shape.append(d)
+        return sympy_shape
+
+    def _get_value(self, node, idx):
+        name = node.input[idx]
+        assert name in self.sympy_data_ or name in self.initializers_
+        return self.sympy_data_[name] if name in self.sympy_data_ else numpy_helper.to_array(self.initializers_[name])
+
+    def _try_get_value(self, node, idx):
+        if idx >= len(node.input):
+            return None
+        name = node.input[idx]
+        if name in self.sympy_data_ or name in self.initializers_:
+            return self._get_value(node, idx)
+        return None
+
+    def _update_computed_dims(self, new_sympy_shape):
+        for i, new_dim in enumerate(new_sympy_shape):
+            if not is_literal(new_dim) and type(new_dim) != str:  # noqa: E721
+                str_dim = str(new_dim)
+                if str_dim in self.suggested_merge_:
+                    if is_literal(self.suggested_merge_[str_dim]):
+                        continue  # no need to create a dim for literals
+                    new_sympy_shape[i] = self.symbolic_dims_[self.suggested_merge_[str_dim]]
+                else:
+                    # add new_dim if it's a computational expression
+                    if str(new_dim) not in self.symbolic_dims_:
+                        self.symbolic_dims_[str(new_dim)] = new_dim
+
+    def _onnx_infer_single_node(self, node):
+        # skip onnx shape inference for some ops, as they are handled in _infer_*
+        skip_infer = node.op_type in [
+            "If",
+            "Loop",
+            "Scan",
+            "SplitToSequence",
+            "ZipMap",  # contrib ops
+            "Attention",
+            "BiasGelu",
+            "EmbedLayerNormalization",
+            "FastGelu",
+            "Gelu",
+            "GemmFastGelu",
+            "LayerNormalization",
+            "LongformerAttention",
+            "DequantizeLinear",
+            "QuantizeLinear",
+            "RelativePositionBias",
+            "RemovePadding",
+            "RestorePadding",
+            "SimplifiedLayerNormalization",
+            "SkipLayerNormalization",
+            "SkipSimplifiedLayerNormalization",
+            "PackedAttention",
+            "PagedAttention",
+            "PythonOp",
+            "MultiHeadAttention",
+            "GroupNorm",
+            "GroupQueryAttention",
+            "SkipGroupNorm",
+            "BiasSplitGelu",
+            "BiasAdd",
+            "NhwcConv",
+            "QuickGelu",
+            "RotaryEmbedding",
+        ]
+
+        if not skip_infer:
+            # Only pass initializers that satisfy the following conditions:
+            # (1) The operator needs the value of some input for shape inference.
+            #     For example, Unsqueeze in opset 13 uses the axes input to calculate the output shape.
+            # (2) The opset version is >= 9. In older versions, the ONNX spec requires initializers to also
+            #     appear in graph input.
+            # (3) The initializer is not in graph input. This means the node input is "constant" at inference time.
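+            # A hand-worked sketch of (1): for Unsqueeze, data of shape [3, 4] with an axes
+            # initializer of [0] lets onnx shape inference produce output shape [1, 3, 4].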
+            initializers = []
+            if (get_opset(self.out_mp_) >= 9) and node.op_type in ["Unsqueeze"]:
+                initializers = [
+                    self.initializers_[name]
+                    for name in node.input
+                    if (name in self.initializers_ and name not in self.graph_inputs_)
+                ]
+
+            # run single node inference with self.known_vi_ shapes
+            tmp_graph = helper.make_graph(
+                [node],
+                "tmp",
+                [self.known_vi_[i] for i in node.input if i],
+                [make_named_value_info(i) for i in node.output],
+                initializers,
+            )
+
+            self.tmp_mp_.graph.CopyFrom(tmp_graph)
+
+            self.tmp_mp_ = shape_inference.infer_shapes(self.tmp_mp_)
+
+        for i_o in range(len(node.output)):
+            o = node.output[i_o]
+            if o:  # skip optional output
+                vi = self.out_mp_.graph.value_info.add()
+                if not skip_infer:
+                    vi.CopyFrom(self.tmp_mp_.graph.output[i_o])
+                else:
+                    vi.name = o
+                self.known_vi_[o] = vi
+
+    def _onnx_infer_subgraph(self, node, subgraph, use_node_input=True, inc_subgraph_id=True):
+        if self.verbose_ > 2:
+            logger.debug(f"Inferencing subgraph of node {node.name} with output({node.output[0]}...): {node.op_type}")
+        # node inputs are not passed directly to the subgraph
+        # it's up to the node dispatcher to prepare subgraph input
+        # for example, with Scan/Loop, subgraph input shape would be trimmed from node input shape
+        # besides, inputs in subgraph could shadow implicit inputs
+        subgraph_inputs = {i.name for i in list(subgraph.initializer) + list(subgraph.input)}
+        subgraph_implicit_input = {name for name in self.known_vi_ if name not in subgraph_inputs}
+        tmp_graph = helper.make_graph(
+            list(subgraph.node),
+            "tmp",
+            list(subgraph.input) + [self.known_vi_[i] for i in subgraph_implicit_input],
+            [make_named_value_info(i.name) for i in subgraph.output],
+        )
+        tmp_graph.initializer.extend([i for i in self.out_mp_.graph.initializer if i.name in subgraph_implicit_input])
+        tmp_graph.initializer.extend(subgraph.initializer)
+        self.tmp_mp_.graph.CopyFrom(tmp_graph)
+
+        symbolic_shape_inference = SymbolicShapeInference(
+            self.int_max_,
+            self.auto_merge_,
+            self.guess_output_rank_,
+            self.verbose_,
+            prefix=self.prefix_ + "_" + str(self.subgraph_id_),
+        )
+        if inc_subgraph_id:
+            self.subgraph_id_ += 1
+
+        symbolic_shape_inference._preprocess(self.tmp_mp_)
+        symbolic_shape_inference.suggested_merge_ = self.suggested_merge_.copy()
+        while symbolic_shape_inference.run_:
+            symbolic_shape_inference._infer_impl(self.sympy_data_.copy())
+        symbolic_shape_inference._update_output_from_vi()
+        if use_node_input:
+            # if subgraph uses node input, it needs to update to merged dims
+            subgraph.ClearField("input")
+            subgraph.input.extend(symbolic_shape_inference.out_mp_.graph.input[: len(node.input)])
+        subgraph.ClearField("output")
+        subgraph.output.extend(symbolic_shape_inference.out_mp_.graph.output)
+        subgraph.ClearField("value_info")
+        subgraph.value_info.extend(symbolic_shape_inference.out_mp_.graph.value_info)
+        subgraph.ClearField("node")
+        subgraph.node.extend(symbolic_shape_inference.out_mp_.graph.node)
+        # for new symbolic dims from subgraph output, add to main graph symbolic dims
+        subgraph_shapes = [get_shape_from_value_info(o) for o in symbolic_shape_inference.out_mp_.graph.output]
+        subgraph_new_symbolic_dims = {
+            d for s in subgraph_shapes if s for d in s if type(d) == str and d not in self.symbolic_dims_  # noqa: E721
+        }
+        new_dims = {}
+        for d in subgraph_new_symbolic_dims:
+            assert d in symbolic_shape_inference.symbolic_dims_
+            new_dims[d] = symbolic_shape_inference.symbolic_dims_[d]
+        self.symbolic_dims_.update(new_dims)
+        return symbolic_shape_inference
+
+    def _get_int_or_float_values(self, node, broadcast=False, allow_float_values=False):
+        def int_or_float(value, allow_float_values):
+            # If casting into int has precision loss: keep float output
+            if allow_float_values and value % 1 != 0:
+                return value
+            return int(value)
+
+        values = [self._try_get_value(node, i) for i in range(len(node.input))]
+        if all([v is not None for v in values]):
+            # some shape compute is in floating point, cast to int for sympy
+            for i, v in enumerate(values):
+                if type(v) != np.ndarray:
+                    continue
+                if len(v.shape) > 1:
+                    new_v = None  # ignore value for rank > 1
+                elif len(v.shape) == 0:
+                    new_v = int_or_float(v.item(), allow_float_values)
+                else:
+                    assert len(v.shape) == 1
+                    new_v = [int_or_float(vv, allow_float_values) for vv in v]
+                values[i] = new_v
+        values_len = [len(v) if isinstance(v, list) else 0 for v in values]
+        max_len = max(values_len)
+        if max_len >= 1 and broadcast:
+            # broadcast
+            for i, v in enumerate(values):
+                if v is None:
+                    continue  # don't broadcast if value is unknown
+                if isinstance(v, list):
+                    if len(v) < max_len:
+                        values[i] = v * max_len
+                    else:
+                        assert len(v) == max_len
+                else:
+                    values[i] = [v] * max_len
+        return values
+
+    def _compute_on_sympy_data(self, node, op_func):
+        assert len(node.output) == 1
+
+        # Before Mul & Div operations:
+        # casting inputs to integer might drop the decimal part and reduce precision,
+        # so keep them as float, finish the operation, then cast the result to integer.
+        if node.op_type in ["Mul", "Div"]:
+            values = self._get_int_or_float_values(node, broadcast=True, allow_float_values=True)
+        else:
+            values = self._get_int_or_float_values(node, broadcast=True)
+
+        if all([v is not None for v in values]):
+            is_list = [isinstance(v, list) for v in values]
+            as_list = any(is_list)
+            if as_list:
+                self.sympy_data_[node.output[0]] = [op_func(vs) for vs in zip(*values)]
+            else:
+                self.sympy_data_[node.output[0]] = op_func(values)
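+
+    # Hand-worked example of the float-preserving path above: Mul with scalar inputs
+    # 5 and 0.5 keeps 0.5 as float (0.5 % 1 != 0), computes 2.5, and the "Mul" lambda
+    # in _infer_symbolic_compute_ops then casts the float result to int(2.5) == 2.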
+
+    def _pass_on_sympy_data(self, node):
+        assert len(node.input) == 1 or node.op_type in [
+            "Reshape",
+            "Unsqueeze",
+            "Squeeze",
+        ]
+        self._compute_on_sympy_data(node, lambda x: x[0])
+
+    def _pass_on_shape_and_type(self, node):
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                get_elem_type_from_type_proto(self.known_vi_[node.input[0]].type),
+                self._get_shape(node, 0),
+            )
+        )
+
+    def _new_symbolic_dim(self, prefix, dim):
+        new_dim = f"{prefix}_d{dim}"
+        if new_dim in self.suggested_merge_:
+            v = self.suggested_merge_[new_dim]
+            new_symbolic_dim = sympy.Integer(int(v)) if is_literal(v) else v
+        else:
+            new_symbolic_dim = sympy.Symbol(new_dim, integer=True, nonnegative=True)
+            self.symbolic_dims_[new_dim] = new_symbolic_dim
+        return new_symbolic_dim
+
+    def _new_symbolic_dim_from_output(self, node, out_idx=0, dim=0):
+        return self._new_symbolic_dim(
+            "{}{}_{}_o{}_".format(
+                node.op_type,
+                self.prefix_,
+                list(self.out_mp_.graph.node).index(node),
+                out_idx,
+            ),
+            dim,
+        )
+
+    def _new_symbolic_shape(self, rank, node, out_idx=0):
+        return [self._new_symbolic_dim_from_output(node, out_idx, i) for i in range(rank)]
+
+    def _compute_conv_pool_shape(self, node, channels_last=False):
+        sympy_shape = self._get_sympy_shape(node, 0)
+        if len(node.input) > 1:
+            W_shape = self._get_sympy_shape(node, 1)  # noqa: N806
+            rank = len(W_shape) - 2  # number of spatial axes
+            kernel_shape = W_shape[-rank - 1 : -1] if channels_last else W_shape[-rank:]
+            sympy_shape[3 if channels_last else 1] = W_shape[0]
+        else:
+            W_shape = None  # noqa: N806
+            kernel_shape = get_attribute(node, "kernel_shape")
+            rank = len(kernel_shape)
+
+        assert len(sympy_shape) == rank + 2
+
+        # only need to run symbolic shape inference if the input has symbolic dims in spatial axes
+        spatial_shape = sympy_shape[-rank - 1 : -1] if channels_last else sympy_shape[-rank:]
+        is_symbolic_dims = [not is_literal(i) for i in spatial_shape]
+
+        if not any(is_symbolic_dims):
+            shape = get_shape_from_value_info(self.known_vi_[node.output[0]])
+            if len(shape) > 0:
+                assert len(sympy_shape) == len(shape)
+                if channels_last:
+                    sympy_shape[-rank - 1 : -1] = [sympy.Integer(d) for d in shape[-rank - 1 : -1]]
+                else:
+                    sympy_shape[-rank:] = [sympy.Integer(d) for d in shape[-rank:]]
+            return sympy_shape
+
+        dilations = get_attribute(node, "dilations", [1] * rank)
+        strides = get_attribute(node, "strides", [1] * rank)
+        effective_kernel_shape = [(k - 1) * d + 1 for k, d in zip(kernel_shape, dilations)]
+        pads = get_attribute(node, "pads")
+        if pads is None:
+            pads = [0] * (2 * rank)
+            auto_pad = get_attribute(node, "auto_pad", b"NOTSET").decode("utf-8")
+            if auto_pad != "VALID" and auto_pad != "NOTSET":
+                try:
+                    residual = [sympy.Mod(d, s) for d, s in zip(sympy_shape[-rank:], strides)]
+                    total_pads = [
+                        max(0, (k - s) if r == 0 else (k - r))
+                        for k, s, r in zip(effective_kernel_shape, strides, residual)
+                    ]
+                except TypeError:  # sympy may throw TypeError: cannot determine truth value of Relational
+                    total_pads = [
+                        max(0, (k - s)) for k, s in zip(effective_kernel_shape, strides)
+                    ]  # assuming no residual if sympy throws error
+            elif auto_pad == "VALID":
+                total_pads = []
+            else:
+                total_pads = [0] * rank
+        else:
+            assert len(pads) == 2 * rank
+            total_pads = [p1 + p2 for p1, p2 in zip(pads[:rank], pads[rank:])]
+
+        ceil_mode = get_attribute(node, "ceil_mode", 0)
+        for i in range(rank):
+            effective_input_size = sympy_shape[-rank + i + (-1 if channels_last else 0)]
+            if len(total_pads) > 0:
+                effective_input_size = effective_input_size + total_pads[i]
+            if ceil_mode:
+                strided_kernel_positions = sympy.ceiling(
+                    (effective_input_size - effective_kernel_shape[i]) / strides[i]
+                )
+            else:
+                strided_kernel_positions = (effective_input_size - effective_kernel_shape[i]) // strides[i]
+            sympy_shape[-rank + i + (-1 if channels_last else 0)] = strided_kernel_positions + 1
+        return sympy_shape
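+
+    # The spatial-size formula implemented above, for reference (floor mode, explicit pads):
+    #   effective_kernel = (kernel - 1) * dilation + 1
+    #   out = (in + pad_begin + pad_end - effective_kernel) // stride + 1
+    # e.g. in=13, kernel=3, dilation=1, stride=2, pads=0 -> (13 - 3) // 2 + 1 = 6.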
+
+    def _check_merged_dims(self, dims, allow_broadcast=True):
+        if allow_broadcast:
+            dims = [d for d in dims if not (is_literal(d) and int(d) <= 1)]
+        if not all([d == dims[0] for d in dims]):
+            self._add_suggested_merge(dims, apply=True)
+
+    def _compute_matmul_shape(self, node, output_dtype=None):
+        lhs_shape = self._get_shape(node, 0)
+        rhs_shape = self._get_shape(node, 1)
+        lhs_rank = len(lhs_shape)
+        rhs_rank = len(rhs_shape)
+        lhs_reduce_dim = 0
+        rhs_reduce_dim = 0
+        assert lhs_rank > 0 and rhs_rank > 0
+        if lhs_rank == 1 and rhs_rank == 1:
+            new_shape = []
+        elif lhs_rank == 1:
+            rhs_reduce_dim = -2
+            new_shape = rhs_shape[:rhs_reduce_dim] + [rhs_shape[-1]]
+        elif rhs_rank == 1:
+            lhs_reduce_dim = -1
+            new_shape = lhs_shape[:lhs_reduce_dim]
+        else:
+            lhs_reduce_dim = -1
+            rhs_reduce_dim = -2
+            new_shape = [*self._broadcast_shapes(lhs_shape[:-2], rhs_shape[:-2]), lhs_shape[-2], rhs_shape[-1]]
+        # merge reduce dim
+        self._check_merged_dims(
+            [lhs_shape[lhs_reduce_dim], rhs_shape[rhs_reduce_dim]],
+            allow_broadcast=False,
+        )
+        if output_dtype is None:
+            # infer output_dtype from input type when not specified
+            output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape))
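+
+    # Hand-worked examples for the MatMul rules above:
+    #   [2, 3] @ [3, 4]        -> [2, 4]
+    #   ["b", 5, 8] @ [8, 16]  -> ["b", 5, 16]  (batch dims broadcast)
+    #   [3] @ [3]              -> []            (scalar output)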
+
+    def _fuse_tensor_type(self, node, out_idx, dst_type, src_type):
+        """
+        update dst_tensor_type to be compatible with src_tensor_type when dimensions mismatch
+        """
+        dst_tensor_type = (
+            dst_type.sequence_type.elem_type.tensor_type if is_sequence(dst_type) else dst_type.tensor_type
+        )
+        src_tensor_type = (
+            src_type.sequence_type.elem_type.tensor_type if is_sequence(src_type) else src_type.tensor_type
+        )
+        if dst_tensor_type.elem_type != src_tensor_type.elem_type:
+            node_id = node.name if node.name else node.op_type
+            raise ValueError(
+                f"For node {node_id}, dst_tensor_type.elem_type != src_tensor_type.elem_type: "
+                f"{onnx.onnx_pb.TensorProto.DataType.Name(dst_tensor_type.elem_type)} vs "
+                f"{onnx.onnx_pb.TensorProto.DataType.Name(src_tensor_type.elem_type)}"
+            )
+        if dst_tensor_type.HasField("shape"):
+            for di, ds in enumerate(zip(dst_tensor_type.shape.dim, src_tensor_type.shape.dim)):
+                if ds[0] != ds[1]:
+                    # create a new symbolic dimension for node/out_idx/mismatched dim id in dst_tensor_type for tensor_type
+                    # for sequence_type, clear the dimension
+                    new_dim = onnx.TensorShapeProto.Dimension()
+                    if not is_sequence(dst_type):
+                        new_dim.dim_param = str(self._new_symbolic_dim_from_output(node, out_idx, di))
+                    dst_tensor_type.shape.dim[di].CopyFrom(new_dim)
+        else:
+            dst_tensor_type.CopyFrom(src_tensor_type)
+
+    def _infer_ArrayFeatureExtractor(self, node):  # noqa: N802
+        data_shape = self._get_shape(node, 0)
+        indices_shape = self._get_shape(node, 1)
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                self.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                data_shape[:-1] + indices_shape,
+            )
+        )
+
+    def _infer_symbolic_compute_ops(self, node):
+        funcs = {
+            "Add": lambda l: l[0] + l[1],  # noqa: E741
+            "Div": lambda l: (  # noqa: E741
+                int(l[0] // l[1]) if isinstance(l[0] // l[1], float) else l[0] // l[1]
+            ),  # integer div in sympy
+            "Equal": lambda l: l[0] == l[1],  # noqa: E741
+            "Floor": lambda l: sympy.floor(l[0]),  # noqa: E741
+            "Max": lambda l: (  # noqa: E741
+                l[1]
+                if is_literal(l[0]) and int(l[0]) < -self.int_max_
+                else (l[0] if is_literal(l[1]) and int(l[1]) < -self.int_max_ else sympy.Max(l[0], l[1]))
+            ),
+            "Min": lambda l: (  # noqa: E741
+                l[1]
+                if is_literal(l[0]) and int(l[0]) > self.int_max_
+                else (l[0] if is_literal(l[1]) and int(l[1]) > self.int_max_ else sympy.Min(l[0], l[1]))
+            ),
+            "Mul": lambda l: int(l[0] * l[1]) if isinstance(l[0] * l[1], float) else l[0] * l[1],  # noqa: E741
+            "Sub": lambda l: l[0] - l[1],  # noqa: E741
+            "Where": lambda l: l[1] if l[0] else l[2],  # noqa: E741
+            "Neg": lambda l: -l[0],  # noqa: E741
+        }
+        assert node.op_type in funcs
+        self._compute_on_sympy_data(node, funcs[node.op_type])
+
+    def _infer_Cast(self, node):  # noqa: N802
+        self._pass_on_sympy_data(node)
+
+    def _infer_CategoryMapper(self, node):  # noqa: N802
+        input_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+        if input_type == onnx.TensorProto.STRING:
+            output_type = onnx.TensorProto.INT64
+        else:
+            output_type = onnx.TensorProto.STRING
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_type, self._get_shape(node, 0)))
+
+    def _infer_Compress(self, node):  # noqa: N802
+        input_shape = self._get_shape(node, 0)
+        # create a new symbolic dimension for Compress output
+        compress_len = str(self._new_symbolic_dim_from_output(node))
+        axis = get_attribute(node, "axis")
+        if axis is None:
+            # when axis is not specified, input is flattened before compress, so output is 1D
+            output_shape = [compress_len]
+        else:
+            output_shape = input_shape
+            output_shape[handle_negative_axis(axis, len(input_shape))] = compress_len
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                self.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                output_shape,
+            )
+        )
+
+    def _infer_Concat(self, node):  # noqa: N802
+        if any([i in self.sympy_data_ or i in self.initializers_ for i in node.input]):
+            values = self._get_int_or_float_values(node)
+            if all([v is not None for v in values]):
+                assert get_attribute(node, "axis") == 0
+                self.sympy_data_[node.output[0]] = []
+                for i in range(len(node.input)):
+                    value = values[i]
+                    if isinstance(value, list):
+                        self.sympy_data_[node.output[0]].extend(value)
+                    else:
+                        self.sympy_data_[node.output[0]].append(value)
+
+        sympy_shape = self._get_sympy_shape(node, 0)
+        axis = handle_negative_axis(get_attribute(node, "axis"), len(sympy_shape))
+        for i_idx in range(1, len(node.input)):
+            input_shape = self._get_sympy_shape(node, i_idx)
+            if input_shape:
+                sympy_shape[axis] = sympy_shape[axis] + input_shape[axis]
+        self._update_computed_dims(sympy_shape)
+        # merge symbolic dims for non-concat axes
+        for d in range(len(sympy_shape)):
+            if d == axis:
+                continue
+            dims = [self._get_shape(node, i_idx)[d] for i_idx in range(len(node.input)) if self._get_shape(node, i_idx)]
+            if all([d == dims[0] for d in dims]):
+                continue
+            merged = self._merge_symbols(dims)
+            if type(merged) == str:  # noqa: E721
+                sympy_shape[d] = self.symbolic_dims_[merged] if merged else None
+            else:
+                sympy_shape[d] = merged
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                self.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                get_shape_from_sympy_shape(sympy_shape),
+            )
+        )
+
+    def _infer_ConcatFromSequence(self, node):  # noqa: N802
+        seq_shape = self._get_shape(node, 0)
+        new_axis = 1 if get_attribute(node, "new_axis") else 0
+        axis = handle_negative_axis(get_attribute(node, "axis"), len(seq_shape) + new_axis)
+        concat_dim = str(self._new_symbolic_dim_from_output(node, 0, axis))
+        new_shape = seq_shape
+        if new_axis:
+            new_shape = seq_shape[:axis] + [concat_dim] + seq_shape[axis:]
+        else:
+            new_shape[axis] = concat_dim
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                self.known_vi_[node.input[0]].type.sequence_type.elem_type.tensor_type.elem_type,
+                new_shape,
+            )
+        )
+
+    def _infer_Constant(self, node):  # noqa: N802
+        t = get_attribute(node, "value")
+        self.sympy_data_[node.output[0]] = numpy_helper.to_array(t)
+
+    def _infer_ConstantOfShape(self, node):  # noqa: N802
+        sympy_shape = self._get_int_or_float_values(node)[0]
+        vi = self.known_vi_[node.output[0]]
+        if sympy_shape is not None:
+            if type(sympy_shape) != list:  # noqa: E721
+                sympy_shape = [sympy_shape]
+            self._update_computed_dims(sympy_shape)
+            # update sympy data if output type is int, and shape is known
+            if vi.type.tensor_type.elem_type == onnx.TensorProto.INT64 and all([is_literal(x) for x in sympy_shape]):
+                self.sympy_data_[node.output[0]] = np.ones(
+                    [int(x) for x in sympy_shape], dtype=np.int64
+                ) * numpy_helper.to_array(get_attribute(node, "value", 0))
+        else:
+            # create a new dynamic shape
+            # note input0 is a 1D shape vector, so the new symbolic shape has the rank of the shape vector's length
+            sympy_shape = self._new_symbolic_shape(self._get_shape(node, 0)[0], node)
+
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                vi.type.tensor_type.elem_type,
+                get_shape_from_sympy_shape(sympy_shape),
+            )
+        )
+
+    def _infer_Conv(self, node):  # noqa: N802
+        sympy_shape = self._compute_conv_pool_shape(node)
+        self._update_computed_dims(sympy_shape)
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                vi.type.tensor_type.elem_type,
+                get_shape_from_sympy_shape(sympy_shape),
+            )
+        )
+
+    def _infer_NhwcConv(self, node):  # noqa: N802
+        sympy_shape = self._compute_conv_pool_shape(node, channels_last=True)
+        self._update_computed_dims(sympy_shape)
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                self.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                get_shape_from_sympy_shape(sympy_shape),
+            )
+        )
+
+    def _infer_DequantizeLinear(self, node):  # noqa: N802
+        # Get the output data type from the scale input (index 1, required).
+        output_dtype = self.known_vi_[node.input[1]].type.tensor_type.elem_type
+
+        # Get the output shape from the first input.
+        output_shape = self._get_shape(node, 0)
+
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+
+    def _infer_QuantizeLinear(self, node):  # noqa: N802
+        # Get the output data type from the zero-point input (index 2, optional).
+        # Otherwise, default to uint8.
+        output_dtype = onnx.TensorProto.UINT8
+        if len(node.input) > 2 and node.input[2]:
+            output_dtype = self.known_vi_[node.input[2]].type.tensor_type.elem_type
+
+        # Get the output shape from the first input.
+        output_shape = self._get_shape(node, 0)
+
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+
+    def _infer_Einsum(self, node):  # noqa: N802
+        # ref: https://github.com/onnx/onnx/blob/623dfaa0151b2e4ce49779c3ec31cbd78c592b80/onnx/defs/math/defs.cc#L3275
+        equation = get_attribute(node, "equation")
+        equation = equation.replace(b" ", b"")
+        mid_index = equation.find(b"->")
+        left_equation = equation[:mid_index] if mid_index != -1 else equation
+
+        num_operands = 0
+        num_ellipsis = 0
+        num_ellipsis_indices = 0
+
+        letter_to_dim = {}
+
+        terms = left_equation.split(b",")
+        for term in terms:
+            ellipsis_index = term.find(b"...")
+            shape = self._get_shape(node, num_operands)
+            rank = len(shape)
+            if ellipsis_index != -1:
+                if num_ellipsis == 0:
+                    num_ellipsis_indices = rank - len(term) + 3
+                num_ellipsis = num_ellipsis + 1
+            for i in range(1, rank + 1):
+                letter = term[-i]
+                if letter != 46:  # letter != b'.'
+                    dim = shape[-i]
+                    if letter not in letter_to_dim:
+                        letter_to_dim[letter] = dim
+                    elif type(dim) != sympy.Symbol:
+                        letter_to_dim[letter] = dim
+            num_operands = num_operands + 1
+
+        new_sympy_shape = []
+        from collections import OrderedDict
+
+        num_letter_occurrences = OrderedDict()
+        if mid_index != -1:
+            right_equation = equation[mid_index + 2 :]
+            right_ellipsis_index = right_equation.find(b"...")
+            if right_ellipsis_index != -1:
+                for i in range(num_ellipsis_indices):
+                    new_sympy_shape.append(shape[i])
+            for c in right_equation:
+                if c != 46:  # c != b'.'
+                    new_sympy_shape.append(letter_to_dim[c])
+        else:
+            for i in range(num_ellipsis_indices):
+                new_sympy_shape.append(shape[i])
+            for c in left_equation:
+                if c != 44 and c != 46:  # c != b',' and c != b'.'
+                    if c in num_letter_occurrences:
+                        num_letter_occurrences[c] = num_letter_occurrences[c] + 1
+                    else:
+                        num_letter_occurrences[c] = 1
+            for key, value in num_letter_occurrences.items():
+                if value == 1:
+                    new_sympy_shape.append(letter_to_dim[key])
+
+        output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_sympy_shape))
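+
+    # Hand-worked Einsum example: equation b"bij,bjk->bik" with input shapes
+    # ["b", "i", 8] and ["b", 8, "k"] fills letter_to_dim and yields output shape
+    # ["b", "i", "k"]; letters that appear only on the left (j here) are reduced away.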
+
+    def _infer_Expand(self, node):  # noqa: N802
+        expand_to_shape = as_list(self._try_get_value(node, 1), keep_none=True)
+        if expand_to_shape is not None:
+            # new_shape's dim can come from shape value
+            self._update_computed_dims(expand_to_shape)
+            shape = self._get_shape(node, 0)
+            new_shape = self._broadcast_shapes(shape, get_shape_from_sympy_shape(expand_to_shape))
+            vi = self.known_vi_[node.output[0]]
+            vi.CopyFrom(
+                helper.make_tensor_value_info(
+                    node.output[0],
+                    self.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                    new_shape,
+                )
+            )
+
+    def _infer_Gather(self, node):  # noqa: N802
+        data_shape = self._get_shape(node, 0)
+        axis = handle_negative_axis(get_attribute(node, "axis", 0), len(data_shape))
+        indices_shape = self._get_shape(node, 1)
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                self.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                data_shape[:axis] + indices_shape + data_shape[axis + 1 :],
+            )
+        )
+        # for 1D input, do some sympy compute
+        if node.input[0] in self.sympy_data_ and len(data_shape) == 1 and get_attribute(node, "axis", 0) == 0:
+            idx = self._try_get_value(node, 1)
+            if idx is not None:
+                data = self.sympy_data_[node.input[0]]
+                if type(data) == list:  # noqa: E721
+                    if type(idx) == np.ndarray and len(idx.shape) == 1:
+                        self.sympy_data_[node.output[0]] = [data[int(i)] for i in idx]
+                    else:
+                        self.sympy_data_[node.output[0]] = data[int(idx)]
+                else:
+                    assert idx == 0 or idx == -1
+                    self.sympy_data_[node.output[0]] = data
+
+    def _infer_GatherElements(self, node):  # noqa: N802
+        indices_shape = self._get_shape(node, 1)
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                self.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                indices_shape,
+            )
+        )
+
+    def _infer_GatherND(self, node):  # noqa: N802
+        data_shape = self._get_shape(node, 0)
+        data_rank = len(data_shape)
+        indices_shape = self._get_shape(node, 1)
+        last_index_dimension = indices_shape[-1]
+        assert is_literal(last_index_dimension) and last_index_dimension <= data_rank
+        new_shape = indices_shape[:-1] + data_shape[last_index_dimension:]
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                self.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                new_shape,
+            )
+        )
+
+    def _infer_If(self, node):  # noqa: N802
+        # special case for constant condition, in case there are mismatching shapes from the non-executed branch
+        subgraphs = [
+            get_attribute(node, "then_branch"),
+            get_attribute(node, "else_branch"),
+        ]
+        cond = self._try_get_value(node, 0)
+        if cond is not None:
+            if as_scalar(cond) > 0:
+                subgraphs[1].CopyFrom(subgraphs[0])
+            else:
+                subgraphs[0].CopyFrom(subgraphs[1])
+
+        for i_sub, subgraph in enumerate(subgraphs):
+            subgraph_infer = self._onnx_infer_subgraph(node, subgraph, use_node_input=False)
+            for i_out in range(len(node.output)):
+                vi = self.known_vi_[node.output[i_out]]
+                if i_sub == 0:
+                    vi.CopyFrom(subgraph.output[i_out])
+                    vi.name = node.output[i_out]
+                else:
+                    self._fuse_tensor_type(node, i_out, vi.type, subgraph.output[i_out].type)
+
+                # pass on sympy data from subgraph, if cond is constant
+                if cond is not None and i_sub == (0 if as_scalar(cond) > 0 else 1):
+                    if subgraph.output[i_out].name in subgraph_infer.sympy_data_:
+                        self.sympy_data_[vi.name] = subgraph_infer.sympy_data_[subgraph.output[i_out].name]
+
+    def _infer_Loop(self, node):  # noqa: N802
+        subgraph = get_attribute(node, "body")
+        assert len(subgraph.input) == len(node.input)
+        num_loop_carried = len(node.input) - 2  # minus the length and initial loop condition
+        # when sequence_type is used as a loop carried input,
+        # the subgraph infer needs to run twice if the tensor shape in the sequence contains None
+        for i, si in enumerate(subgraph.input):
+            si_name = si.name
+            si.CopyFrom(self.known_vi_[node.input[i]])
+            si.name = si_name
+
+        self._onnx_infer_subgraph(node, subgraph)
+
+        # check subgraph input/output for shape changes in loop carried variables
+        # for tensor_type, create a new symbolic dim when changing, i.e., output = Concat(input, a)
+        # for sequence_type, propagate from output to input
+        need_second_infer = False
+        for i_out in range(1, num_loop_carried + 1):
+            so = subgraph.output[i_out]
+            so_shape = get_shape_from_value_info(so)
+            if is_sequence(so.type):
+                if so_shape and None in so_shape:
+                    # copy shape from output to input
+                    # note that loop input is [loop_len, cond, input_0, input_1, ...]
+                    # while loop output is [cond, output_0, output_1, ...]
+                    subgraph.input[i_out + 1].type.sequence_type.elem_type.CopyFrom(so.type.sequence_type.elem_type)
+                    need_second_infer = True
+            else:
+                si = subgraph.input[i_out + 1]
+                si_shape = get_shape_from_value_info(si)
+                for di, dims in enumerate(zip(si_shape, so_shape)):
+                    if dims[0] != dims[1]:
+                        new_dim = onnx.TensorShapeProto.Dimension()
+                        new_dim.dim_param = str(self._new_symbolic_dim_from_output(node, i_out, di))
+                        si.type.tensor_type.shape.dim[di].CopyFrom(new_dim)
+                        so.type.tensor_type.shape.dim[di].CopyFrom(new_dim)
+                        need_second_infer = True
+
+        if need_second_infer:
+            if self.verbose_ > 2:
+                logger.debug(
+                    "Rerun Loop: {}({}...), because of sequence in loop carried variables".format(
+                        node.name, node.output[0]
+                    )
+                )
+            self._onnx_infer_subgraph(node, subgraph, inc_subgraph_id=False)
+
+        # create a new symbolic dimension for the iteration-dependent dimension
+        loop_iter_dim = str(self._new_symbolic_dim_from_output(node))
+        for i in range(len(node.output)):
+            vi = self.known_vi_[node.output[i]]
+            vi.CopyFrom(subgraph.output[i + 1])  # first subgraph output is condition, not in node output
+            if i >= num_loop_carried:
+                assert not is_sequence(vi.type)  # TODO: handle loop accumulation in sequence_type
+                subgraph_vi_dim = subgraph.output[i + 1].type.tensor_type.shape.dim
+                vi.type.tensor_type.shape.ClearField("dim")
+                vi_dim = vi.type.tensor_type.shape.dim
+                vi_dim.add().dim_param = loop_iter_dim
+                vi_dim.extend(list(subgraph_vi_dim))
+            vi.name = node.output[i]
+
+    def _infer_MatMul(self, node):  # noqa: N802
+        self._compute_matmul_shape(node)
+
+    def _infer_MatMulInteger(self, node):  # noqa: N802
+        self._compute_matmul_shape(node, onnx.TensorProto.INT32)
+
+    def _infer_NonMaxSuppression(self, node):  # noqa: N802
+        selected = str(self._new_symbolic_dim_from_output(node))
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, [selected, 3]))
+
+    def _infer_NonZero(self, node):  # noqa: N802
+        input_rank = self._get_shape_rank(node, 0)
+        # create a new symbolic dimension for NonZero output
+        nz_len = str(self._new_symbolic_dim_from_output(node, 0, 1))
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], vi.type.tensor_type.elem_type, [input_rank, nz_len]))
+
+    def _infer_OneHot(self, node):  # noqa: N802
+        sympy_shape = self._get_sympy_shape(node, 0)
+        depth = self._try_get_value(node, 1)
+        axis = get_attribute(node, "axis", -1)
+        axis = handle_negative_axis(axis, len(sympy_shape) + 1)
+        new_shape = get_shape_from_sympy_shape(
+            sympy_shape[:axis]
+            + [self._new_symbolic_dim_from_output(node) if not is_literal(depth) else depth]
+            + sympy_shape[axis:]
+        )
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                self.known_vi_[node.input[2]].type.tensor_type.elem_type,
+                new_shape,
+            )
+        )
+
+    def _infer_Pad(self, node):  # noqa: N802
+        if get_opset(self.out_mp_) <= 10:
+            pads = get_attribute(node, "pads")
+        else:
+            pads = self._try_get_value(node, 1)
+
+        sympy_shape = self._get_sympy_shape(node, 0)
+        rank = len(sympy_shape)
+
+        if pads is not None:
+            assert len(pads) == 2 * rank
+            new_sympy_shape = [
+                d + pad_up + pad_down for d, pad_up, pad_down in zip(sympy_shape, pads[:rank], pads[rank:])
+            ]
+            self._update_computed_dims(new_sympy_shape)
+        else:
+            # dynamic pads, create new symbolic dimensions
+            new_sympy_shape = self._new_symbolic_shape(rank, node)
+        output_tp = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(node.output[0], output_tp, get_shape_from_sympy_shape(new_sympy_shape))
+        )
+
+    def _infer_Pool(self, node):  # noqa: N802
+        sympy_shape = self._compute_conv_pool_shape(node)
+        self._update_computed_dims(sympy_shape)
+        for o in node.output:
+            if not o:
+                continue
+            vi = self.known_vi_[o]
+            vi.CopyFrom(
+                helper.make_tensor_value_info(
+                    o,
+                    vi.type.tensor_type.elem_type,
+                    get_shape_from_sympy_shape(sympy_shape),
+                )
+            )
+
+    def _infer_aten_bitwise_or(self, node):
+        shape0 = self._get_shape(node, 0)
+        shape1 = self._get_shape(node, 1)
+        new_shape = self._broadcast_shapes(shape0, shape1)
+        t0 = self.known_vi_[node.input[0]]
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], t0.type.tensor_type.elem_type, new_shape))
+
+    def _infer_aten_diagonal(self, node):
+        sympy_shape = self._get_sympy_shape(node, 0)
+        rank = len(sympy_shape)
+        offset = self._try_get_value(node, 1)
+        dim1 = self._try_get_value(node, 2)
+        dim2 = self._try_get_value(node, 3)
+
+        assert offset is not None and dim1 is not None and dim2 is not None
+        dim1 = handle_negative_axis(dim1, rank)
+        dim2 = handle_negative_axis(dim2, rank)
+
+        new_shape = []
+        for dim, val in enumerate(sympy_shape):
+            if dim not in [dim1, dim2]:
+                new_shape.append(val)
+
+        shape1 = sympy_shape[dim1]
+        shape2 = sympy_shape[dim2]
+        if offset >= 0:
+            diag_shape = sympy.Max(0, sympy.Min(shape1, shape2 - offset))
+        else:
+            diag_shape = sympy.Max(0, sympy.Min(shape1 + offset, shape2))
+        new_shape.append(diag_shape)
+
+        if node.output[0]:
+            vi = self.known_vi_[node.output[0]]
+            vi.CopyFrom(
+                helper.make_tensor_value_info(
+                    node.output[0],
+                    self.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                    get_shape_from_sympy_shape(new_shape),
+                )
+            )
+
+    def _infer_aten_multinomial(self, node):
+        sympy_shape = self._get_sympy_shape(node, 0)
+        rank = len(sympy_shape)
+        assert rank in [1, 2]
+        num_samples = self._try_get_value(node, 1)
+        di = rank - 1
+        last_dim = num_samples if num_samples else str(self._new_symbolic_dim_from_output(node, 0, di))
+        output_shape = sympy_shape[:-1] + [last_dim]
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                onnx.TensorProto.INT64,
+                get_shape_from_sympy_shape(output_shape),
+            )
+        )
+
+    def _infer_aten_pool2d(self, node):
+        sympy_shape = self._get_sympy_shape(node, 0)
+        assert len(sympy_shape) == 4
+        sympy_shape[-2:] = [self._new_symbolic_dim_from_output(node, 0, i) for i in [2, 3]]
+        self._update_computed_dims(sympy_shape)
+        for i, o in enumerate(node.output):
+            if not o:
+                continue
+            vi = self.known_vi_[o]
+            elem_type = onnx.TensorProto.INT64 if i == 1 else self.known_vi_[node.input[0]].type.tensor_type.elem_type
+            vi.CopyFrom(helper.make_tensor_value_info(o, elem_type, get_shape_from_sympy_shape(sympy_shape)))
+
+    def _infer_aten_minmax(self, node):
+        vi = self.known_vi_[node.output[0]]
+        if len(node.input) == 1:
+            vi.CopyFrom(
+                helper.make_tensor_value_info(
+                    node.output[0], self.known_vi_[node.input[0]].type.tensor_type.elem_type, []
+                )
+            )
+        else:
+            assert len(node.input) == 3
+            keepdim = self._try_get_value(node, 2)
+            assert keepdim is not None  # can only handle the known keepdim case
+            dim = self._try_get_value(node, 1)
+            if dim is None:
+                rank = self._get_shape_rank(node, 0)
+                output_shape = self._new_symbolic_shape(rank if keepdim else rank - 1, node)
+            else:
+                shape = self._get_sympy_shape(node, 0)
+                dim = handle_negative_axis(dim, len(shape))
+                output_shape = shape[:dim]
+                if keepdim:
+                    output_shape += [1]
+                output_shape += shape[dim + 1 :]
+
+            output_shape = get_shape_from_sympy_shape(output_shape)
+            vi.CopyFrom(
+                helper.make_tensor_value_info(
+                    node.output[0], self.known_vi_[node.input[0]].type.tensor_type.elem_type, output_shape
+                )
+            )
+            vi1 = self.known_vi_[node.output[1]]
+            vi1.CopyFrom(helper.make_tensor_value_info(node.output[1], onnx.TensorProto.INT64, output_shape))
+
+    def _infer_aten_unfold(self, node):
+        sympy_shape = self._get_sympy_shape(node, 0)
+        dimension = self._try_get_value(node, 1)
+        size = self._try_get_value(node, 2)
+        step = self._try_get_value(node, 3)
+        if dimension is not None and size is not None and step is not None:
+            assert dimension < len(sympy_shape)
+            sympy_shape[dimension] = (sympy_shape[dimension] - size) // step + 1
+            sympy_shape.append(size)
+        else:
+            rank = len(sympy_shape)
+            sympy_shape = self._new_symbolic_shape(rank + 1, node)
+        self._update_computed_dims(sympy_shape)
+        if node.output[0]:
+            vi = self.known_vi_[node.output[0]]
+            vi.CopyFrom(
+                helper.make_tensor_value_info(
+                    node.output[0],
+                    self.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                    get_shape_from_sympy_shape(sympy_shape),
+                )
+            )
+
+    def _infer_aten_argmax(self, node):
+        new_shape = None
+        if not node.input[1]:
+            # The argmax of the flattened input is returned.
+            new_shape = []
+        else:
+            dim = self._try_get_value(node, 1)
+            keepdim = self._try_get_value(node, 2)
+            if keepdim is not None:
+                sympy_shape = self._get_sympy_shape(node, 0)
+                if dim is not None:
+                    dim = handle_negative_axis(dim, len(sympy_shape))
+                    if keepdim:
+                        sympy_shape[dim] = 1
+                    else:
+                        del sympy_shape[dim]
+                else:
+                    rank = len(sympy_shape)
+                    sympy_shape = self._new_symbolic_shape(rank if keepdim else rank - 1, node)
+                self._update_computed_dims(sympy_shape)
+                new_shape = get_shape_from_sympy_shape(sympy_shape)
+        if node.output[0] and new_shape is not None:
+            vi = self.known_vi_[node.output[0]]
+            vi.CopyFrom(helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, new_shape))
+
+    def _infer_aten_group_norm(self, node):
+        self._propagate_shape_and_type(node)
+        input_shape = self._get_shape(node, 0)
+        N = input_shape[0] if input_shape is not None and len(input_shape) != 0 else None  # noqa: N806
+        group = self._try_get_value(node, 6)
+        output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+        for i in [1, 2]:
+            if node.output[i]:
+                vi = self.known_vi_[node.output[i]]
+                vi.CopyFrom(
+                    helper.make_tensor_value_info(
+                        node.output[i],
+                        output_dtype,
+                        [
+                            N if N is not None else str(self._new_symbolic_dim_from_output(node, i, 0)),
+                            (
+                                as_scalar(group)
+                                if group is not None
+                                else str(self._new_symbolic_dim_from_output(node, i, 1))
+                            ),
+                        ],
+                    )
+                )
+
+    def _infer_aten_upsample(self, node):
+        new_shape = None
+        input_shape = self._get_shape(node, 0)
+        if input_shape is not None:
+            new_shape = input_shape[:2]
+            output_size = self._try_get_value(node, 1)
+            if output_size is not None:
+                new_shape += [dim_size.item() if type(dim_size) == np.int64 else dim_size for dim_size in output_size]
+            else:
+                rank = len(input_shape)
+                new_shape += [str(self._new_symbolic_dim_from_output(node, 0, i)) for i in range(2, rank)]
+        if node.output[0] and new_shape is not None:
+            output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+            vi = self.known_vi_[node.output[0]]
+            vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape))
+
+    def _infer_BatchNormalization(self, node):  # noqa: N802
+        self._propagate_shape_and_type(node)
+
+        # this works for opsets < 14 and 14 since we check i < len(node.output) in the loop
+        for i in [1, 2, 3, 4]:
+            if i < len(node.output) and node.output[i]:
+                # all of these parameters have the same shape as the 1st input
+                self._propagate_shape_and_type(node, input_index=1, output_index=i)
+
+    def _infer_Range(self, node):  # noqa: N802
+        vi = self.known_vi_[node.output[0]]
+        input_data = self._get_int_or_float_values(node)
+        if all([i is not None for i in input_data]):
+            start = as_scalar(input_data[0])
+            limit = as_scalar(input_data[1])
+            delta = as_scalar(input_data[2])
+            new_sympy_shape = [sympy.Max(sympy.ceiling((limit - start) / delta), 0)]
+        else:
+            new_sympy_shape = [self._new_symbolic_dim_from_output(node)]
+        self._update_computed_dims(new_sympy_shape)
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                self.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                get_shape_from_sympy_shape(new_sympy_shape),
+            )
+        )
+
+    def _infer_ReduceSum(self, node):  # noqa: N802
+        keep_dims = get_attribute(node, "keepdims", 1)
+        if get_opset(self.out_mp_) >= 13 and len(node.input) > 1:
+            # ReduceSum changes axes to input[1] in opset 13
+            axes = self._try_get_value(node, 1)
+            vi = self.known_vi_[node.output[0]]
+            if axes is None:
+                assert keep_dims  # can only handle keep_dims==True when axes is unknown, by generating new ranks
+                vi.CopyFrom(
+                    helper.make_tensor_value_info(
+                        node.output[0],
+                        self.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                        get_shape_from_sympy_shape(self._new_symbolic_shape(self._get_shape_rank(node, 0), node)),
+                    )
+                )
+            else:
+                shape = self._get_shape(node, 0)
+                output_shape = []
+                axes = [handle_negative_axis(a, len(shape)) for a in axes]
+                for i, d in enumerate(shape):
+                    if i in axes:
+                        if keep_dims:
+                            output_shape.append(1)
+                    else:
+                        output_shape.append(d)
+                vi.CopyFrom(
+                    helper.make_tensor_value_info(
+                        node.output[0],
+                        self.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                        output_shape,
+                    )
+                )
+
+    def _infer_ReduceProd(self, node):  # noqa: N802
+        axes = get_attribute(node, "axes")
+        keep_dims = get_attribute(node, "keepdims", 1)
+        if keep_dims == 0 and axes == [0]:
+            data = self._get_int_or_float_values(node)[0]
+            if data is not None:
+                self.sympy_data_[node.output[0]] = sympy_reduce_product(data)
+
+    def _infer_RelativePositionBias(self, node):  # noqa: N802
+        seq_len = self._try_get_value(node, 1)
+        real_seq_len = self._try_get_value(node, 2)
+        if seq_len is None or real_seq_len is None:
+            return
+        num_heads = self._get_sympy_shape(node, 0)[1]
+
+        new_shape = [1, num_heads, str(seq_len), str(real_seq_len)]
+
+        output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape))
+
+    def _infer_Reshape(self, node):  # noqa: N802
+        shape_value = self._try_get_value(node, 1)
+        vi = self.known_vi_[node.output[0]]
+        if shape_value is None:
+            shape_shape = self._get_shape(node, 1)
+            assert len(shape_shape) == 1
+            shape_rank = shape_shape[0]
+            assert is_literal(shape_rank)
+            vi.CopyFrom(
+                helper.make_tensor_value_info(
+                    node.output[0],
+                    vi.type.tensor_type.elem_type,
+                    get_shape_from_sympy_shape(self._new_symbolic_shape(shape_rank, node)),
+                )
+            )
+        else:
+            input_sympy_shape = self._get_sympy_shape(node, 0)
+            total = 1
+            for d in input_sympy_shape:
+                total = total * d
+            new_sympy_shape = []
+            deferred_dim_idx = -1
+            non_deferred_size = 1
+            for i, d in enumerate(shape_value):
+                if type(d) == sympy.Symbol:
+                    new_sympy_shape.append(d)
+                elif d == 0:
+                    new_sympy_shape.append(input_sympy_shape[i])
+                    non_deferred_size = non_deferred_size * input_sympy_shape[i]
+                else:
+                    new_sympy_shape.append(d)
+                    if d == -1:
+                        deferred_dim_idx = i
+                    elif d != 0:
+                        non_deferred_size = non_deferred_size * d
+
+            assert new_sympy_shape.count(-1) < 2
+            if -1 in new_sympy_shape:
+                new_dim = total // non_deferred_size
+                new_sympy_shape[deferred_dim_idx] = new_dim
+
+            self._update_computed_dims(new_sympy_shape)
+            vi.CopyFrom(
+                helper.make_tensor_value_info(
+                    node.output[0],
+                    vi.type.tensor_type.elem_type,
+                    get_shape_from_sympy_shape(new_sympy_shape),
+                )
+            )
+
+        self._pass_on_sympy_data(node)
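+
+    # Hand-worked Reshape example for the deferred-dim handling above: input shape
+    # ["N", 4, 6] with shape value [0, -1] keeps dim 0 via the 0 placeholder and
+    # resolves -1 to total // non_deferred_size == (24*"N") // "N" == 24 -> ["N", 24].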
self._try_get_value(node, 3) + if sizes is not None: + new_sympy_shape = [sympy.simplify(sympy.floor(s)) for s in sizes] + self._update_computed_dims(new_sympy_shape) + elif scales is not None: + rank = len(scales) + if get_attribute(node, "coordinate_transformation_mode") == "tf_crop_and_resize": + assert len(roi) == 2 * rank + roi_start = list(roi)[:rank] + roi_end = list(roi)[rank:] + else: + roi_start = [0] * rank + roi_end = [1] * rank + scales = list(scales) + new_sympy_shape = [ + sympy.simplify(sympy.floor(d * (end - start) * scale)) + for d, start, end, scale in zip(input_sympy_shape, roi_start, roi_end, scales) + ] + self._update_computed_dims(new_sympy_shape) + else: + new_sympy_shape = self._new_symbolic_shape(self._get_shape_rank(node, 0), node) + + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + get_shape_from_sympy_shape(new_sympy_shape), + ) + ) + + def _infer_Scan(self, node): # noqa: N802 + subgraph = get_attribute(node, "body") + num_scan_inputs = get_attribute(node, "num_scan_inputs") + scan_input_axes = get_attribute(node, "scan_input_axes", [0] * num_scan_inputs) + num_scan_states = len(node.input) - num_scan_inputs + scan_input_axes = [ + handle_negative_axis(ax, self._get_shape_rank(node, i + num_scan_states)) + for i, ax in enumerate(scan_input_axes) + ] + # We may have cases where the subgraph has optional inputs that appear in both subgraph's input and initializer, + # but not in the node's input. In such cases, the input model might be invalid, but let's skip those optional inputs. + assert len(subgraph.input) >= len(node.input) + subgraph_inputs = subgraph.input[: len(node.input)] + for i, si in enumerate(subgraph_inputs): + subgraph_name = si.name + si.CopyFrom(self.known_vi_[node.input[i]]) + if i >= num_scan_states: + scan_input_dim = si.type.tensor_type.shape.dim + scan_input_dim.remove(scan_input_dim[scan_input_axes[i - num_scan_states]]) + si.name = subgraph_name + self._onnx_infer_subgraph(node, subgraph) + num_scan_outputs = len(node.output) - num_scan_states + scan_output_axes = get_attribute(node, "scan_output_axes", [0] * num_scan_outputs) + scan_input_dim = get_shape_from_type_proto(self.known_vi_[node.input[-1]].type)[scan_input_axes[-1]] + for i, o in enumerate(node.output): + vi = self.known_vi_[o] + if i >= num_scan_states: + shape = get_shape_from_type_proto(subgraph.output[i].type) + new_dim = handle_negative_axis(scan_output_axes[i - num_scan_states], len(shape) + 1) + shape = shape[:new_dim] + [scan_input_dim] + shape[new_dim:] + vi.CopyFrom(helper.make_tensor_value_info(o, subgraph.output[i].type.tensor_type.elem_type, shape)) + else: + vi.CopyFrom(subgraph.output[i]) + vi.name = o + + def _infer_ScatterElements(self, node): # noqa: N802 + data_shape = self._get_shape(node, 0) + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + data_shape, + ) + ) + + def _infer_SequenceAt(self, node): # noqa: N802 + # need to create new symbolic dimension if sequence shape has None: + seq_shape = self._get_shape(node, 0) + vi = self.known_vi_[node.output[0]] + if seq_shape is not None: + for di, d in enumerate(seq_shape): + if d is not None: + continue + new_dim = onnx.TensorShapeProto.Dimension() + new_dim.dim_param = str(self._new_symbolic_dim_from_output(node, 0, di)) + vi.type.tensor_type.shape.dim[di].CopyFrom(new_dim) + + def _infer_SequenceInsert(self, node): # 
noqa: N802 + # workaround bug in onnx's shape inference + vi_seq = self.known_vi_[node.input[0]] + vi_tensor = self.known_vi_[node.input[1]] + vi_out_seq = self.known_vi_[node.output[0]] + vi_out_seq.CopyFrom(vi_seq) + vi_out_seq.name = node.output[0] + self._fuse_tensor_type(node, 0, vi_out_seq.type, vi_tensor.type) + + def _infer_Shape(self, node): # noqa: N802 + self.sympy_data_[node.output[0]] = self._get_sympy_shape(node, 0) + + def _infer_Size(self, node): # noqa: N802 + sympy_shape = self._get_sympy_shape(node, 0) + self.sympy_data_[node.output[0]] = sympy_reduce_product(sympy_shape) + self.known_vi_[node.output[0]].CopyFrom( + helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, []) + ) + + def _infer_Slice(self, node): # noqa: N802 + # SymPy fails to prove that `x_0 + ... + x_n >= 0` if one of `x_i` is a `sympy.Min(a, b)`, + # even when the relation holds for both `a` and `b`. + # + # When given `expr` of form `min(a, b) + ...`, this function returns `[a + ..., b + ...]`, + # so that we can prove inequalities for both expressions separately. + # + # If the number of `min(...)` subexpressions is not exactly one, this function just returns `[expr]`. + def flatten_min(expr): + assert isinstance(expr, sympy.Add), f"Expected a sum of two arguments, got {expr}" + min_positions = [idx for idx in range(len(expr.args)) if isinstance(expr.args[idx], sympy.Min)] + if len(min_positions) == 1: + min_pos = min_positions[0] + + def replace_min_with_arg(arg_idx): + replaced = list(expr.args) + assert isinstance( + replaced[min_pos], sympy.Min + ), f"Expected a sympy.Min() at position {min_pos}, got {replaced[min_pos]}" + assert ( + len(replaced[min_pos].args) == 2 + ), f"Expected a sympy.Min() with exactly 2 arguments, got {replaced[min_pos]}" + replaced[min_pos] = replaced[min_pos].args[arg_idx] + return sympy.Add(*replaced) + + return [ + replace_min_with_arg(0), + replace_min_with_arg(1), + ] + return [expr] + + def less_equal(x, y): + try: + return bool(x <= y) + except TypeError: + pass + try: + return bool(y >= x) + except TypeError: + pass + try: + return bool(-x >= -y) + except TypeError: + pass + try: + return bool(-y <= -x) + except TypeError: + pass + try: + return bool(y - x >= 0) + except TypeError: + # the last attempt; this may raise TypeError + return all(bool(d >= 0) for d in flatten_min(y - x)) + + def handle_negative_index(index, bound): + """normalizes a negative index to be in [0, bound)""" + try: + if not less_equal(0, index): + if is_literal(index) and index <= -self.int_max_: + # this case is handled separately + return index + return bound + index + except TypeError: + logger.warning(f"Cannot determine if {index} < 0") + return index + + if get_opset(self.out_mp_) <= 9: + axes = get_attribute(node, "axes") + starts = get_attribute(node, "starts") + ends = get_attribute(node, "ends") + if not axes: + axes = list(range(len(starts))) + steps = [1] * len(axes) + else: + starts = as_list(self._try_get_value(node, 1), keep_none=True) + ends = as_list(self._try_get_value(node, 2), keep_none=True) + axes = self._try_get_value(node, 3) + steps = self._try_get_value(node, 4) + if axes is None and not (starts is None and ends is None): + axes = list(range(0, len(starts if starts is not None else ends))) + if steps is None and not (starts is None and ends is None): + steps = [1] * len(starts if starts is not None else ends) + axes = as_list(axes, keep_none=True) + steps = as_list(steps, keep_none=True) + + new_sympy_shape = self._get_sympy_shape(node, 0) + if starts 
is None or ends is None:
+            if axes is None:
+                for i in range(len(new_sympy_shape)):
+                    new_sympy_shape[i] = self._new_symbolic_dim_from_output(node, 0, i)
+            else:
+                new_sympy_shape = get_shape_from_sympy_shape(new_sympy_shape)
+                for i in axes:
+                    new_sympy_shape[i] = self._new_symbolic_dim_from_output(node, 0, i)
+        else:
+            for i, s, e, t in zip(axes, starts, ends, steps):
+                e = handle_negative_index(e, new_sympy_shape[i])  # noqa: PLW2901
+                if is_literal(e):
+                    if e >= self.int_max_:
+                        e = new_sympy_shape[i]  # noqa: PLW2901
+                    elif e <= -self.int_max_:
+                        e = 0 if s > 0 else -1  # noqa: PLW2901
+                    elif is_literal(new_sympy_shape[i]):
+                        if e < 0:
+                            e = max(0, e + new_sympy_shape[i])  # noqa: PLW2901
+                        e = min(e, new_sympy_shape[i])  # noqa: PLW2901
+                    else:
+                        if e > 0:
+                            e = (  # noqa: PLW2901
+                                sympy.Min(e, new_sympy_shape[i]) if e > 1 else e
+                            )  # special case for slicing first to make computation easier
+                else:
+                    if is_literal(new_sympy_shape[i]):
+                        e = sympy.Min(e, new_sympy_shape[i])  # noqa: PLW2901
+                    else:
+                        try:
+                            if not less_equal(e, new_sympy_shape[i]):
+                                e = new_sympy_shape[i]  # noqa: PLW2901
+                        except Exception:
+                            logger.warning(f"Unable to determine if {e} <= {new_sympy_shape[i]}, treating them as equal")
+                            e = new_sympy_shape[i]  # noqa: PLW2901
+
+                s = handle_negative_index(s, new_sympy_shape[i])  # noqa: PLW2901
+                if is_literal(new_sympy_shape[i]) and is_literal(s):
+                    s = max(0, min(s, new_sympy_shape[i]))  # noqa: PLW2901
+
+                # sliced dim length is ceil((e - s) / t), expressed with floor division for either sign of t
+                new_sympy_shape[i] = sympy.simplify((e - s + t + (-1 if t > 0 else 1)) // t)
+
+            self._update_computed_dims(new_sympy_shape)
+
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                vi.type.tensor_type.elem_type,
+                get_shape_from_sympy_shape(new_sympy_shape),
+            )
+        )
+
+        # handle sympy_data if needed, for slice in shape computation
+        if (
+            node.input[0] in self.sympy_data_
+            and [0] == axes
+            and starts is not None
+            and len(starts) == 1
+            and ends is not None
+            and len(ends) == 1
+            and steps is not None
+            and len(steps) == 1
+        ):
+            input_sympy_data = self.sympy_data_[node.input[0]]
+            if type(input_sympy_data) == list or (  # noqa: E721
+                isinstance(input_sympy_data, np.ndarray) and len(input_sympy_data.shape) == 1
+            ):
+                self.sympy_data_[node.output[0]] = input_sympy_data[starts[0] : ends[0] : steps[0]]
+
+    def _infer_SoftmaxCrossEntropyLoss(self, node):  # noqa: N802
+        vi = self.known_vi_[node.output[0]]
+        elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+
+        # If the output type is explicitly specified in the attribute, use it as the output tensor type.
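+        # For example (illustrative): a com.microsoft.SoftmaxCrossEntropyLossInternal node exported
+        # for mixed-precision training may set output_type = onnx.TensorProto.FLOAT so the loss is
+        # accumulated in fp32 even when the logits input is fp16.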
+        specified_output_type = get_attribute(node, "output_type", None)
+        if specified_output_type is not None:
+            elem_type = specified_output_type
+
+        vi.type.tensor_type.elem_type = elem_type
+        vi.type.tensor_type.shape.CopyFrom(onnx.TensorShapeProto())
+
+        if len(node.output) > 1:
+            data_shape = self._get_shape(node, 0)
+            vi = self.known_vi_[node.output[1]]
+            vi.CopyFrom(helper.make_tensor_value_info(vi.name, elem_type, data_shape))
+
+    def _infer_Split_Common(self, node, make_value_info_func):  # noqa: N802
+        input_sympy_shape = self._get_sympy_shape(node, 0)
+        axis = handle_negative_axis(get_attribute(node, "axis", 0), len(input_sympy_shape))
+        op_set = get_opset(self.out_mp_)
+
+        # Depending on the opset version, 'split' is provided as an attribute or via the 2nd input
+        if op_set < 13:
+            split = get_attribute(node, "split")
+            assert self._try_get_value(node, 1) is None
+        else:
+            split = self._try_get_value(node, 1)
+            assert get_attribute(node, "split") is None
+
+        if split is None:
+            num_outputs = len(node.output)
+            split = [input_sympy_shape[axis] / sympy.Integer(num_outputs)] * num_outputs
+            self._update_computed_dims(split)
+        else:
+            split = [sympy.Integer(s) for s in split]
+
+        for i_o in range(len(split)):
+            vi = self.known_vi_[node.output[i_o]]
+            vi.CopyFrom(
+                make_value_info_func(
+                    node.output[i_o],
+                    self.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                    get_shape_from_sympy_shape(input_sympy_shape[:axis] + [split[i_o]] + input_sympy_shape[axis + 1 :]),
+                )
+            )
+            self.known_vi_[vi.name] = vi
+
+    def _infer_Split(self, node):  # noqa: N802
+        self._infer_Split_Common(node, helper.make_tensor_value_info)
+
+    def _infer_SplitToSequence(self, node):  # noqa: N802
+        self._infer_Split_Common(node, helper.make_sequence_value_info)
+
+    def _infer_Squeeze(self, node):  # noqa: N802
+        input_shape = self._get_shape(node, 0)
+        op_set = get_opset(self.out_mp_)
+
+        # Depending on the opset version, 'axes' is provided as an attribute or via the 2nd input
+        if op_set < 13:
+            axes = get_attribute(node, "axes")
+            assert self._try_get_value(node, 1) is None
+        else:
+            axes = self._try_get_value(node, 1)
+            assert get_attribute(node, "axes") is None
+
+        if axes is None:
+            # No axes have been provided (neither via attribute nor via input).
+            # In this case the Squeeze op should remove all axes with dimension 1.
+            # For symbolic dimensions we assume they are != 1.
+            output_shape = [s for s in input_shape if s != 1]
+            if self.verbose_ > 0:
+                symbolic_dimensions = [s for s in input_shape if type(s) != int]  # noqa: E721
+                if len(symbolic_dimensions) > 0:
+                    logger.debug(
+                        f"Symbolic dimensions in input shape of op: '{node.op_type}' node: '{node.name}'. "
+                        f"Assuming the following dimensions are never equal to 1: {symbolic_dimensions}"
+                    )
+        else:
+            axes = [handle_negative_axis(a, len(input_shape)) for a in axes]
+            output_shape = []
+            for i in range(len(input_shape)):
+                if i not in axes:
+                    output_shape.append(input_shape[i])
+                else:
+                    assert input_shape[i] == 1 or type(input_shape[i]) != int  # noqa: E721
+                    if self.verbose_ > 0 and type(input_shape[i]) != int:  # noqa: E721
+                        logger.debug(
+                            f"Symbolic dimensions in input shape of op: '{node.op_type}' node: '{node.name}'. "
+                            f"Assuming the dimension '{input_shape[i]}' at index {i} of the input is equal to 1."
+                        )
+
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                self.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                output_shape,
+            )
+        )
+        self._pass_on_sympy_data(node)
+
+    def _infer_Tile(self, node):  # noqa: N802
+        repeats_value = self._try_get_value(node, 1)
+        new_sympy_shape = []
+        if repeats_value is not None:
+            input_sympy_shape = self._get_sympy_shape(node, 0)
+            for i, d in enumerate(input_sympy_shape):
+                new_dim = d * repeats_value[i]
+                new_sympy_shape.append(new_dim)
+            self._update_computed_dims(new_sympy_shape)
+        else:
+            new_sympy_shape = self._new_symbolic_shape(self._get_shape_rank(node, 0), node)
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                vi.type.tensor_type.elem_type,
+                get_shape_from_sympy_shape(new_sympy_shape),
+            )
+        )
+
+    def _infer_TopK(self, node):  # noqa: N802
+        rank = self._get_shape_rank(node, 0)
+        axis = handle_negative_axis(get_attribute(node, "axis", -1), rank)
+        new_shape = self._get_shape(node, 0)
+
+        if get_opset(self.out_mp_) <= 9:
+            k = get_attribute(node, "k")
+        else:
+            k = self._get_int_or_float_values(node)[1]
+
+        if k is None:
+            k = self._new_symbolic_dim_from_output(node)
+        else:
+            k = as_scalar(k)
+
+        if type(k) in [int, str]:
+            new_shape[axis] = k
+        else:
+            new_sympy_shape = self._get_sympy_shape(node, 0)
+            new_sympy_shape[axis] = k
+            self._update_computed_dims(
+                new_sympy_shape
+            )  # note that the TopK dim could be computed in sympy_data, so update computed_dims when it enters the shape
+            new_shape = get_shape_from_sympy_shape(new_sympy_shape)
+
+        for i_o in range(len(node.output)):
+            vi = self.known_vi_[node.output[i_o]]
+            vi.CopyFrom(helper.make_tensor_value_info(node.output[i_o], vi.type.tensor_type.elem_type, new_shape))
+
+    def _infer_Transpose(self, node):  # noqa: N802
+        if node.input[0] in self.sympy_data_:
+            data_shape = self._get_shape(node, 0)
+            perm = get_attribute(node, "perm", reversed(list(range(len(data_shape)))))
+            input_data = self.sympy_data_[node.input[0]]
+            self.sympy_data_[node.output[0]] = (
+                np.transpose(np.array(input_data).reshape(*data_shape), axes=tuple(perm)).flatten().tolist()
+            )
+
+    def _infer_Unsqueeze(self, node):  # noqa: N802
+        input_shape = self._get_shape(node, 0)
+        op_set = get_opset(self.out_mp_)
+
+        # Depending on the opset version, 'axes' is provided as an attribute or via the 2nd input
+        if op_set < 13:
+            axes = get_attribute(node, "axes")
+            assert self._try_get_value(node, 1) is None
+        else:
+            axes = self._try_get_value(node, 1)
+            assert get_attribute(node, "axes") is None
+
+        output_rank = len(input_shape) + len(axes)
+        axes = [handle_negative_axis(a, output_rank) for a in axes]
+
+        input_axis = 0
+        output_shape = []
+        for i in range(output_rank):
+            if i in axes:
+                output_shape.append(1)
+            else:
+                output_shape.append(input_shape[input_axis])
+                input_axis += 1
+
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                self.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                output_shape,
+            )
+        )
+
+        self._pass_on_sympy_data(node)
+
+    def _infer_ZipMap(self, node):  # noqa: N802
+        map_key_type = None
+        if get_attribute(node, "classlabels_int64s") is not None:
+            map_key_type = onnx.TensorProto.INT64
+        elif get_attribute(node, "classlabels_strings") is not None:
+            map_key_type = onnx.TensorProto.STRING
+
+        assert map_key_type is not None
+        new_vi = onnx.ValueInfoProto()
+        new_vi.name = node.output[0]
+
new_vi.type.sequence_type.elem_type.map_type.value_type.tensor_type.elem_type = onnx.TensorProto.FLOAT + new_vi.type.sequence_type.elem_type.map_type.key_type = map_key_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(new_vi) + + def _infer_Attention(self, node): # noqa: N802 + shape = self._get_shape(node, 0) + shape_weights = self._get_shape(node, 1) + shape_bias = self._try_get_shape(node, 2) + if shape_bias is not None: + assert len(shape_bias) == 1 + tripled_hidden_size = shape_bias[0] if shape_bias is not None else shape_weights[1] + if shape and len(shape) == 3: + qkv_hidden_sizes_attr = get_attribute(node, "qkv_hidden_sizes") + if qkv_hidden_sizes_attr is not None: + assert len(qkv_hidden_sizes_attr) == 3 + shape[2] = int(qkv_hidden_sizes_attr[2]) + elif isinstance(tripled_hidden_size, int): + shape[2] = int(tripled_hidden_size / 3) + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, shape)) + + if len(node.output) > 1: + # input shape: (batch_size, sequence_length, hidden_size) + # past shape: (2, batch_size, num_heads, past_sequence_length, head_size) + # mask shape: (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length) or (batch_size, 1, max_seq_len, max_seq_len) + # present shape: (2, batch_size, num_heads, total_sequence_length, head_size), where total_sequence_length=sequence_length+past_sequence_length + input_shape = self._get_shape(node, 0) + past_shape = self._get_shape(node, 4) if len(node.input) > 4 and node.input[4] else [] + mask_shape = self._get_shape(node, 3) if len(node.input) > 3 and node.input[3] else [] + + if past_shape and len(past_shape) == 5: + if mask_shape and len(mask_shape) in [2, 3]: + past_shape[3] = mask_shape[-1] + elif input_shape and len(input_shape) == 3: + if isinstance(input_shape[1], int) and isinstance(past_shape[3], int): + past_shape[3] = input_shape[1] + past_shape[3] + else: + past_shape[3] = f"{past_shape[3]}+{input_shape[1]}" + vi = self.known_vi_[node.output[1]] + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape)) + # No past input but present output still exists + else: + num_heads = get_attribute(node, "num_heads") + head_size = input_shape[2] // num_heads + present_shape = [2, input_shape[0], num_heads, input_shape[1], head_size] + vi = self.known_vi_[node.output[1]] + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape)) + + def _infer_GatedRelativePositionBias(self, node): # noqa: N802 + # When padding is removed: + # query_layer: (token_count, num_heads x head_size) + # token_offset: (batch_size, seq_len) + # Otherwise: + # query_layer: (batch_size, seq_len, num_heads x head_size) + # token_offset: None + # Output shape: (batch_size, num_heads, seq_len, seq_len) + num_heads = get_attribute(node, "num_heads") + + token_offset_shape = self._try_get_shape(node, 6) + if token_offset_shape is not None: + output_shape = [token_offset_shape[0], num_heads, token_offset_shape[1], token_offset_shape[1]] + else: + query_layer_shape = self._get_shape(node, 0) + assert query_layer_shape is not None and len(query_layer_shape) == 3 + output_shape = [query_layer_shape[0], num_heads, query_layer_shape[1], query_layer_shape[1]] + + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, 
output_shape)) + + def _infer_PackedAttention(self, node): # noqa: N802 + shape = self._get_shape(node, 0) + shape_weights = self._get_shape(node, 1) + shape_bias = self._try_get_shape(node, 2) + if shape_bias is not None: + assert len(shape_bias) == 1 + tripled_hidden_size = shape_bias[0] if shape_bias is not None else shape_weights[1] + if shape and len(shape) == 2: + qkv_hidden_sizes_attr = get_attribute(node, "qkv_hidden_sizes") + if qkv_hidden_sizes_attr is not None: + assert len(qkv_hidden_sizes_attr) == 3 + shape[1] = int(qkv_hidden_sizes_attr[2]) + elif isinstance(tripled_hidden_size, int): + shape[1] = int(tripled_hidden_size / 3) + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, shape)) + + def _infer_PackedMultiHeadAttention(self, node): # noqa: N802 + shape_value = self._try_get_shape(node, 2) + if shape_value is not None and len(shape_value) == 2: + output_shape = shape_value + else: + shape_query = self._get_shape(node, 0) + assert shape_query is not None and len(shape_query) == 4 + output_shape = [shape_query[0], shape_query[1] * shape_query[3]] + + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) + + def _infer_RemovePadding(self, node): # noqa: N802 + shape = self._get_shape(node, 0) + if shape and len(shape) == 3: + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, ["token_count", shape[2]])) + + vi_token_offset = self.known_vi_[node.output[1]] + vi_token_offset.CopyFrom( + helper.make_tensor_value_info(node.output[1], onnx.TensorProto.INT32, [shape[0], shape[1]]) + ) + + vi_cumulated_seq_len = self.known_vi_[node.output[2]] + vi_cumulated_seq_len.CopyFrom( + helper.make_tensor_value_info(node.output[2], onnx.TensorProto.INT32, ["batch_size + 1"]) + ) + + vi_max_seq_len = self.known_vi_[node.output[3]] + vi_max_seq_len.CopyFrom(helper.make_tensor_value_info(node.output[3], onnx.TensorProto.INT32, [1])) + + def _infer_RestorePadding(self, node): # noqa: N802 + shape_input = self._get_shape(node, 0) + shape_token_offset = self._get_shape(node, 1) + if shape_input and len(shape_input) == 2 and shape_token_offset and len(shape_token_offset) == 2: + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + + output_shape = [shape_token_offset[0], shape_token_offset[1], shape_input[1]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) + + def _infer_BiasGelu(self, node): # noqa: N802 + self._propagate_shape_and_type(node) + + def _infer_MultiHeadAttention(self, node): # noqa: N802 + # Output 0 has shape (batch_size, sequence_length, v_hidden_size) + # Q, K and V without packing: + # Input 0 (query) has shape (batch_size, sequence_length, hidden_size) + # Input 1 (key) has shape (batch_size, kv_sequence_length, hidden_size) or (batch_size, num_heads, kv_sequence_length, head_size) + # Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size) or (batch_size, num_heads, kv_sequence_length, head_size) + # Packed KV: + # Input 0 (query) has shape (batch_size, sequence_length, hidden_size) + # Input 1 (batch_size, kv_sequence_length, num_heads, 2, head_size) + # 
Input 2 nullptr
+        # Packed QKV:
+        #   Input 0 (batch_size, sequence_length, num_heads, 3, head_size)
+        #   Input 1 nullptr
+        #   Input 2 nullptr
+
+        query_shape = self._get_shape(node, 0)
+        total_sequence_length = None
+        output_dtype = None
+        if query_shape is not None:
+            if len(query_shape) == 3:
+                key_shape = self._try_get_shape(node, 1)
+                # By default, the hidden size is the same for Q/K/V. We only need to check v_hidden_size when value is provided.
+                output_shape = query_shape
+                if key_shape is not None and len(key_shape) == 3:
+                    value_shape = self._try_get_shape(node, 2)
+                    if value_shape is not None and len(value_shape) == 3:
+                        output_shape[2] = value_shape[2]
+                    total_sequence_length = key_shape[1]
+
+                output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+                vi = self.known_vi_[node.output[0]]
+                vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+
+            elif len(query_shape) == 5:
+                if isinstance(query_shape[2], int) and isinstance(query_shape[4], int):
+                    output_shape = [query_shape[0], query_shape[1], query_shape[2] * query_shape[4]]
+                else:
+                    output_shape = [query_shape[0], query_shape[1], f"{query_shape[2]}*{query_shape[4]}"]
+
+                total_sequence_length = query_shape[1]
+
+                output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+                vi = self.known_vi_[node.output[0]]
+                vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+
+            if len(node.output) > 1:
+                batch_size = query_shape[0]
+                num_heads = get_attribute(node, "num_heads")
+
+                head_size = None
+                if len(query_shape) == 3:
+                    head_size = (
+                        int(query_shape[2] / num_heads)
+                        if isinstance(query_shape[2], int)
+                        else f"{query_shape[2]}/{num_heads}"
+                    )
+                else:
+                    head_size = query_shape[4]
+
+                past_shape = self._try_get_shape(node, 6)
+
+                if past_shape is not None:
+                    if isinstance(past_shape[2], int) and isinstance(total_sequence_length, int):
+                        total_sequence_length = past_shape[2] + total_sequence_length
+                    else:
+                        total_sequence_length = f"{past_shape[2]}+{total_sequence_length}"
+
+                present_shape = [batch_size, num_heads, total_sequence_length, head_size]
+
+                assert output_dtype is not None
+                if len(node.output) > 2 and node.output[1] and node.output[2]:
+                    vi = self.known_vi_[node.output[1]]
+                    vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape))
+                    vi = self.known_vi_[node.output[2]]
+                    vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape))
+
+    def _infer_DecoderMaskedMultiHeadAttention(self, node):  # noqa: N802
+        # Output 0 has shape (batch_size, 1, v_hidden_size)
+        # Q, K and V without packing:
+        #   Input 0 (query) has shape (batch_size, 1, hidden_size)
+        #   Input 5 (past_key), if present, has shape (batch_size, num_heads, max_sequence_length, head_size)
+
+        query_shape = self._get_shape(node, 0)
+        if query_shape is not None:
+            output_shape = query_shape
+            output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+            assert output_dtype is not None
+            vi = self.known_vi_[node.output[0]]
+            vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+
+            if len(node.output) > 2 and node.output[1] and node.output[2]:
+                past_shape = self._try_get_shape(node, 5)
+                if past_shape is not None:
+                    vi = self.known_vi_[node.output[1]]
+                    vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape))
+                    vi = self.known_vi_[node.output[2]]
+                    vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape))
+
+    def _infer_FastGelu(self, node):  # noqa: N802
+
self._propagate_shape_and_type(node) + + def _infer_Gelu(self, node): # noqa: N802 + self._propagate_shape_and_type(node) + + def _infer_QuickGelu(self, node): # noqa: N802 + self._propagate_shape_and_type(node) + + def _infer_GemmFastGelu(self, node): # noqa: N802 + self._compute_matmul_shape(node) + + def _infer_GemmFloat8(self, node): # noqa: N802 + self._compute_matmul_shape(node) + + def _infer_LayerNormalization(self, node): # noqa: N802 + self._propagate_shape_and_type(node) + if len(node.output) > 1: + axis = get_attribute(node, "axis") + if axis is None: + axis = -1 + x_shape = self._get_shape(node, 0) + if x_shape is not None: + rank = len(x_shape) + axis = handle_negative_axis(axis, rank) + mean_shape = x_shape[:axis] + [1 for _ in range(rank - axis)] + mean_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + if mean_dtype == onnx.TensorProto.FLOAT16 or mean_dtype == onnx.TensorProto.BFLOAT16: + mean_dtype = onnx.TensorProto.FLOAT + vi = self.known_vi_[node.output[1]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[1], mean_dtype, mean_shape)) + if len(node.output) > 2: + vi = self.known_vi_[node.output[2]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[2], mean_dtype, mean_shape)) + + def _infer_LongformerAttention(self, node): # noqa: N802 + self._propagate_shape_and_type(node) + + def _infer_EmbedLayerNormalization(self, node): # noqa: N802 + input_ids_shape = self._get_shape(node, 0) + word_embedding_shape = self._get_shape(node, 2) + assert len(input_ids_shape) == 2 and len(word_embedding_shape) == 2 + output_shape = [*input_ids_shape, word_embedding_shape[1]] + + word_embedding_dtype = self.known_vi_[node.input[2]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], word_embedding_dtype, output_shape)) + + if len(node.output) > 1 and node.output[1]: + mask_index_shape = [input_ids_shape[0]] + vi = self.known_vi_[node.output[1]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[1], onnx.TensorProto.INT32, mask_index_shape)) + + if len(node.output) > 2: + # Optional output of add before layer normalization is done + # shape is same as the output + vi = self.known_vi_[node.output[2]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[2], word_embedding_dtype, output_shape)) + + def _infer_SkipLayerNormalization(self, node): # noqa: N802 + self._propagate_shape_and_type(node) + + # If the SkipLayerNormalization node contains the optional + # output for inference, infer the shape and type for it too + if len(node.output) > 3: + self._propagate_shape_and_type(node, 0, 3) + + def _infer_GroupNorm(self, node): # noqa: N802 + self._propagate_shape_and_type(node) + + def _infer_PagedAttention(self, node): # noqa: N802 + self._propagate_shape_and_type(node) + + def _infer_GroupQueryAttention(self, node): # noqa: N802 + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + + past_shape = self._try_get_shape(node, 3) + if past_shape is not None: + vi = self.known_vi_[node.output[1]] + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape)) + vi = self.known_vi_[node.output[2]] + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape)) + + if node.input[1] != "" and node.input[2] != "": + self._propagate_shape_and_type(node, 0, 0) + else: + # combined qkv: (batch_size, sequence_length, num_heads * head_size + 2 * kv_num_heads * head_size) + assert node.input[1] == "" and node.input[2] == "" + num_heads = 
get_attribute(node, "num_heads")
+            kv_num_heads = get_attribute(node, "kv_num_heads")
+            query_shape = self._get_shape(node, 0)
+            if query_shape is not None:
+                hidden_size = query_shape[2]
+                if isinstance(hidden_size, int):
+                    head_size = int(hidden_size / (num_heads + 2 * kv_num_heads))
+                    query_shape[2] = num_heads * head_size
+                vi = self.known_vi_[node.output[0]]
+                vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, query_shape))
+
+    def _infer_SkipGroupNorm(self, node):  # noqa: N802
+        self._propagate_shape_and_type(node, 0, 0)
+        if len(node.output) > 1:
+            self._propagate_shape_and_type(node, 0, 1)
+
+    def _infer_BiasSplitGelu(self, node):  # noqa: N802
+        input_shape = self._get_shape(node, 0)
+        bias_shape = self._get_shape(node, 1)
+        if input_shape and bias_shape and isinstance(bias_shape[0], int):
+            output_shape = input_shape
+            output_shape[2] = int(bias_shape[0] / 2)
+            vi = self.known_vi_[node.output[0]]
+            output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+            vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, output_shape))
+
+    def _infer_BiasAdd(self, node):  # noqa: N802
+        self._propagate_shape_and_type(node)
+
+    def _infer_RotaryEmbedding(self, node):  # noqa: N802
+        if len(node.output) == 1:
+            self._propagate_shape_and_type(node)
+        elif len(node.output) == 2:
+            # Extraneous constant nodes output by a RotaryEmbedding function made with `export_modules_as_functions`
+            self._propagate_shape_and_type(node, input_index=1, output_index=0)
+            self._propagate_shape_and_type(node, input_index=0, output_index=1)  # true output
+        elif len(node.output) == 3:
+            # Extraneous constant nodes output by a RotaryEmbedding function made with `export_modules_as_functions`
+            self._propagate_shape_and_type(node, input_index=1, output_index=0)
+            self._propagate_shape_and_type(node, input_index=1, output_index=1)
+            self._propagate_shape_and_type(node, input_index=0, output_index=2)  # true output
+
+    def _infer_PythonOp(self, node):  # noqa: N802
+        output_tensor_types = get_attribute(node, "output_tensor_types")
+        assert output_tensor_types, f"PythonOp '{node.name}' has no output_tensor_types attribute."
+        output_tensor_ranks = get_attribute(node, "output_tensor_ranks")
+        assert output_tensor_ranks, f"PythonOp '{node.name}' has no output_tensor_ranks attribute."
+
+        from onnxruntime.capi._pybind_state import get_shape_inference_function
+
+        func_name = get_attribute(node, "func_name").decode()
+        shape_inferer = get_shape_inference_function(func_name)
+
+        # Set the context output separately.
+        # The first output is torch.autograd.Function's context.
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, []))
+
+        if shape_inferer is not None:
+            input_shapes = []
+            input_dtypes = []
+            for input_index in range(len(node.input)):
+                shape = self._get_shape(node, input_index)
+                input_shapes.append(shape)
+                input_dtype = self.known_vi_[node.input[input_index]].type.tensor_type.elem_type
+                input_dtypes.append(input_dtype)
+            output_shapes, output_dtypes = shape_inferer(node, input_shapes, input_dtypes)
+            assert len(output_shapes) == len(output_dtypes) == (len(node.output) - 1), (
+                f"PythonOp '{func_name}' returned {len(output_shapes)} shapes and {len(output_dtypes)} dtypes, "
+                f"but expected {len(node.output) - 1} outputs."
+            )
+            for i in range(len(node.output) - 1):
+                output_index = i + 1
+                vi = self.known_vi_[node.output[output_index]]
+                vi.CopyFrom(
+                    helper.make_tensor_value_info(node.output[output_index], output_dtypes[i], output_shapes[i])
+                )
+        else:
+            # General shape inference for PythonOp.
+            # Outputs after torch.autograd.Function's context are tensors.
+            # We assume their ranks are fixed for different model inputs.
+            for i in range(len(node.output) - 1):
+                # Process the i-th tensor output.
+                vi = self.known_vi_[node.output[i + 1]]
+                sympy_shape = self._new_symbolic_shape(output_tensor_ranks[i], node)
+                shape = get_shape_from_sympy_shape(sympy_shape)
+                value_info = helper.make_tensor_value_info(node.output[i + 1], output_tensor_types[i], shape)
+                vi.CopyFrom(value_info)
+
+    def _propagate_shape_and_type(self, node, input_index=0, output_index=0):
+        shape = self._get_shape(node, input_index)
+        output_dtype = self.known_vi_[node.input[input_index]].type.tensor_type.elem_type
+        vi = self.known_vi_[node.output[output_index]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[output_index], output_dtype, shape))
+
+    def _is_none_dim(self, dim_value):
+        if type(dim_value) != str:  # noqa: E721
+            return False
+        if "unk__" not in dim_value:
+            return False
+        if dim_value in self.symbolic_dims_:
+            return False
+        return True
+
+    def _is_shape_contains_none_dim(self, out_shape):
+        for out in out_shape:
+            if self._is_none_dim(out):
+                return out
+        return None
+
+    def _infer_impl(self, start_sympy_data=None):
+        self.sympy_data_ = start_sympy_data or {}
+        self.out_mp_.graph.ClearField("value_info")
+        self._apply_suggested_merge(graph_input_only=True)
+        self.input_symbols_ = set()
+        for i in self.out_mp_.graph.input:
+            input_shape = get_shape_from_value_info(i)
+            if input_shape is None:
+                continue
+
+            if is_sequence(i.type):
+                input_dims = i.type.sequence_type.elem_type.tensor_type.shape.dim
+            else:
+                input_dims = i.type.tensor_type.shape.dim
+
+            for i_dim, dim in enumerate(input_shape):
+                if dim is None:
+                    # some models use None for symbolic dims in inputs; replace it with a string
+                    input_dims[i_dim].dim_param = str(self._new_symbolic_dim(i.name, i_dim))
+
+            self.input_symbols_.update([d for d in input_shape if type(d) == str])  # noqa: E721
+
+        for s in self.input_symbols_:
+            if s in self.suggested_merge_:
+                s_merge = self.suggested_merge_[s]
+                assert s_merge in self.symbolic_dims_
+                self.symbolic_dims_[s] = self.symbolic_dims_[s_merge]
+            else:
+                # Since inputs are not produced by other ops, we can assume positivity
+                self.symbolic_dims_[s] = sympy.Symbol(s, integer=True, positive=True)
+
+        # create a temporary ModelProto for single-node inference
+        # note that we remove the initializers to get faster inference
+        # for tensor ops like Reshape/Tile/Expand that read initializers, we need to do sympy-based inference anyway
+        self.tmp_mp_ = onnx.ModelProto()
+        self.tmp_mp_.CopyFrom(self.out_mp_)
+        self.tmp_mp_.graph.ClearField("initializer")
+
+        # compute prerequisites for each node for topological sort
+        # nodes with subgraphs may depend on implicit inputs, which affects the topological sort
+        prereq_for_node = {}  # map from node to all its inputs, including implicit ones in subgraphs
+
+        def get_prereq(node):
+            names = {i for i in node.input if i}
+            subgraphs = []
+            if node.op_type == "If":
+                subgraphs = [
+                    get_attribute(node, "then_branch"),
+                    get_attribute(node, "else_branch"),
+                ]
+            elif node.op_type in ["Loop", "Scan"]:
+                subgraphs = [get_attribute(node, "body")]
+            for g in subgraphs:
+                g_outputs_and_initializers = {i.name for i in g.initializer}
+                g_prereq = set()
+                for n in g.node:
+                    g_outputs_and_initializers.update(n.output)
+                for n in g.node:
+                    g_prereq.update([i for i in get_prereq(n) if i not in g_outputs_and_initializers])
+                names.update(g_prereq)
+                # remove subgraph inputs from g_prereq since those are local-only
+                for i in g.input:
+                    if i.name in names:
+                        names.remove(i.name)
+            return names
+
+        for n in self.tmp_mp_.graph.node:
+            prereq_for_node[n.output[0]] = get_prereq(n)
+
+        # topologically sort the nodes; there might be dead nodes, so we check that all graph outputs are reached to terminate
+        sorted_nodes = []
+        sorted_known_vi = {i.name for i in list(self.out_mp_.graph.input) + list(self.out_mp_.graph.initializer)}
+        if any([o.name in sorted_known_vi for o in self.out_mp_.graph.output]):
+            # Loop/Scan will have some graph outputs in graph inputs, so don't do topological sort
+            sorted_nodes = self.out_mp_.graph.node
+        else:
+            while not all([o.name in sorted_known_vi for o in self.out_mp_.graph.output]):
+                old_sorted_nodes_len = len(sorted_nodes)
+                for node in self.out_mp_.graph.node:
+                    if (node.output[0] not in sorted_known_vi) and all(
+                        [i in sorted_known_vi for i in prereq_for_node[node.output[0]] if i]
+                    ):
+                        sorted_known_vi.update(node.output)
+                        sorted_nodes.append(node)
+                if old_sorted_nodes_len == len(sorted_nodes) and not all(
+                    [o.name in sorted_known_vi for o in self.out_mp_.graph.output]
+                ):
+                    raise Exception("Invalid model with cyclic graph")
+
+        for node in sorted_nodes:
+            assert all([i in self.known_vi_ for i in node.input if i])
+            self._onnx_infer_single_node(node)
+            known_aten_op = False
+            if node.op_type in self.dispatcher_:
+                self.dispatcher_[node.op_type](node)
+            elif node.op_type in ["ConvTranspose"]:
+                # onnx shape inference for ops like ConvTranspose may produce an empty shape for symbolic inputs;
+                # until symbolic compute is added for them,
+                # mark the output type as UNDEFINED to allow guessing the rank
+                vi = self.known_vi_[node.output[0]]
+                if len(vi.type.tensor_type.shape.dim) == 0:
+                    vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
+            elif node.op_type == "ATen" and node.domain == "org.pytorch.aten":
+                for attr in node.attribute:
+                    # TODO: Is overload_name needed?
+                    if attr.name == "operator":
+                        aten_op_name = attr.s.decode("utf-8") if isinstance(attr.s, bytes) else attr.s
+                        if aten_op_name in self.aten_op_dispatcher_:
+                            known_aten_op = True
+                            self.aten_op_dispatcher_[aten_op_name](node)
+                        break
+
+            if self.verbose_ > 2:
+                logger.debug(node.op_type + ": " + node.name)
+                for i, name in enumerate(node.input):
+                    logger.debug(
+                        "  Input {}: {} {}".format(i, name, "initializer" if name in self.initializers_ else "")
+                    )
+
+            # onnx automatically merges dims with values, i.e.
Mul(['aaa', 'bbb'], [1000, 1]) -> [1000, 'bbb'] + # symbolic shape inference needs to apply merge of 'aaa' -> 1000 in this case + if node.op_type in [ + "Add", + "Sub", + "Mul", + "Div", + "MatMul", + "MatMulInteger", + "MatMulInteger16", + "Where", + "Sum", + ]: + vi = self.known_vi_[node.output[0]] + out_rank = len(get_shape_from_type_proto(vi.type)) + in_shapes = [self._get_shape(node, i) for i in range(len(node.input))] + for d in range(out_rank - (2 if node.op_type in ["MatMul", "MatMulInteger", "MatMulInteger16"] else 0)): + in_dims = [s[len(s) - out_rank + d] for s in in_shapes if len(s) + d >= out_rank] + if len(in_dims) > 1: + self._check_merged_dims(in_dims, allow_broadcast=True) + + for i_o in range(len(node.output)): + # Special cases: + # 1) We do not care about the training related outputs of SkipLayerNormalization + # 2) We do not care about the extraneous constant outputs in RotaryEmbedding because + # the RotaryEmbedding op created during export can be replaced by the RotaryEmbedding + # contrib op + if ( + node.op_type == "SkipLayerNormalization" or node.op_type == "SkipSimplifiedLayerNormalization" + ) and i_o in [1, 2]: + continue + if node.op_type == "RotaryEmbedding" and len(node.output) > 1: + # Skip symbolic shape inference for RotaryEmbedding functions that have extraneous outputs + # generated by `export_modules_as_functions` + continue + + vi = self.known_vi_[node.output[i_o]] + out_type = vi.type + out_type_kind = out_type.WhichOneof("value") + + # do not process shape for non-tensors + if out_type_kind not in ["tensor_type", "sparse_tensor_type", None]: + if self.verbose_ > 2: + if out_type_kind == "sequence_type": + seq_cls_type = out_type.sequence_type.elem_type.WhichOneof("value") + if seq_cls_type == "tensor_type": + logger.debug( + " {}: sequence of {} {}".format( + node.output[i_o], + str(get_shape_from_value_info(vi)), + onnx.TensorProto.DataType.Name( + vi.type.sequence_type.elem_type.tensor_type.elem_type + ), + ) + ) + else: + logger.debug(f" {node.output[i_o]}: sequence of {seq_cls_type}") + else: + logger.debug(f" {node.output[i_o]}: {out_type_kind}") + continue + + out_shape = get_shape_from_value_info(vi) + out_type_undefined = out_type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED + if self.verbose_ > 2: + logger.debug( + " {}: {} {}".format( + node.output[i_o], + str(out_shape), + onnx.TensorProto.DataType.Name(vi.type.tensor_type.elem_type), + ) + ) + if node.output[i_o] in self.sympy_data_: + logger.debug(" Sympy Data: " + str(self.sympy_data_[node.output[i_o]])) + + # onnx >= 1.11.0, use unk__#index instead of None when the shape dim is uncertain + if ( + out_shape is not None and (None in out_shape or self._is_shape_contains_none_dim(out_shape)) + ) or out_type_undefined: + if self.auto_merge_: + if node.op_type in [ + "Add", + "Sub", + "Mul", + "Div", + "MatMul", + "MatMulInteger", + "MatMulInteger16", + "Concat", + "Where", + "Sum", + "Equal", + "Less", + "Greater", + "LessOrEqual", + "GreaterOrEqual", + "Min", + "Max", + ]: + shapes = [self._get_shape(node, i) for i in range(len(node.input))] + if node.op_type in [ + "MatMul", + "MatMulInteger", + "MatMulInteger16", + ]: + if None in out_shape or self._is_shape_contains_none_dim(out_shape): + if None in out_shape: + idx = out_shape.index(None) + else: + idx = out_shape.index(self._is_shape_contains_none_dim(out_shape)) + dim_idx = [len(s) - len(out_shape) + idx for s in shapes] + # only support auto merge for MatMul for dim < rank-2 when rank > 2 + assert len(shapes[0]) > 2 and 
dim_idx[0] < len(shapes[0]) - 2 + assert len(shapes[1]) > 2 and dim_idx[1] < len(shapes[1]) - 2 + elif node.op_type == "Expand": + # auto merge for cases like Expand([min(batch, 1), min(seq, 512)], [batch, seq]) + shapes = [ + self._get_shape(node, 0), + self._get_value(node, 1), + ] + else: + shapes = [] + + if shapes: + for idx in range(len(out_shape)): + if out_shape[idx] is not None and not self._is_none_dim(out_shape[idx]): + continue + # note that the broadcasting rule aligns from right to left + # if a tensor has a lower rank (dim_idx[idx] < 0), it would automatically broadcast and need no merge + dim_idx = [len(s) - len(out_shape) + idx for s in shapes] + if len(dim_idx) > 0: + self._add_suggested_merge( + [ + s[i] if is_literal(s[i]) else str(s[i]) + for s, i in zip(shapes, dim_idx) + if i >= 0 + ] + ) + self.run_ = True + else: + self.run_ = False + else: + self.run_ = False + + # create new dynamic dims for ops not handled by symbolic shape inference + if self.run_ is False and node.op_type not in self.dispatcher_ and not known_aten_op: + is_unknown_op = out_type_undefined and (out_shape is None or len(out_shape) == 0) + if is_unknown_op: + # unknown op to ONNX, maybe from higher opset or other domain + # only guess the output rank from input 0 when using guess_output_rank option + out_rank = self._get_shape_rank(node, 0) if self.guess_output_rank_ else -1 + else: + # valid ONNX op, but not handled by symbolic shape inference, just assign dynamic shape + out_rank = len(out_shape) + + if out_rank >= 0: + new_shape = self._new_symbolic_shape(out_rank, node, i_o) + if out_type_undefined: + # guess output data type from input vi if not defined + out_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + else: + # otherwise, use original data type + out_dtype = vi.type.tensor_type.elem_type + vi.CopyFrom( + helper.make_tensor_value_info( + vi.name, + out_dtype, + get_shape_from_sympy_shape(new_shape), + ) + ) + + if self.verbose_ > 0: + if is_unknown_op: + logger.debug( + "Possible unknown op: {} node: {}, guessing {} shape".format( + node.op_type, node.name, vi.name + ) + ) + if self.verbose_ > 2: + logger.debug( + " {}: {} {}".format( + node.output[i_o], + str(new_shape), + vi.type.tensor_type.elem_type, + ) + ) + + self.run_ = True + continue # continue the inference after guess, no need to stop as no merge is needed + + if self.verbose_ > 0 or not self.auto_merge_ or out_type_undefined: + logger.debug("Stopping at incomplete shape inference at " + node.op_type + ": " + node.name) + logger.debug("node inputs:") + for i in node.input: + if i in self.known_vi_: + logger.debug(self.known_vi_[i]) + else: + logger.debug(f"not in known_vi_ for {i}") + logger.debug("node outputs:") + for o in node.output: + if o in self.known_vi_: + logger.debug(self.known_vi_[o]) + else: + logger.debug(f"not in known_vi_ for {o}") + if self.auto_merge_ and not out_type_undefined: + logger.debug("Merging: " + str(self.suggested_merge_)) + return False + + self.run_ = False + return True + + def _update_output_from_vi(self): + for output in self.out_mp_.graph.output: + if output.name in self.known_vi_: + output.CopyFrom(self.known_vi_[output.name]) + + @staticmethod + def infer_shapes(in_mp, int_max=2**31 - 1, auto_merge=False, guess_output_rank=False, verbose=0): + onnx_opset = get_opset(in_mp) + if (not onnx_opset) or onnx_opset < 7: + logger.warning("Only support models of onnx opset 7 and above.") + return None + symbolic_shape_inference = SymbolicShapeInference(int_max, auto_merge, 
guess_output_rank, verbose)
+        all_shapes_inferred = False
+        symbolic_shape_inference._preprocess(in_mp)
+        while symbolic_shape_inference.run_:
+            all_shapes_inferred = symbolic_shape_inference._infer_impl()
+        symbolic_shape_inference._update_output_from_vi()
+        if not all_shapes_inferred:
+            onnx.save_model(symbolic_shape_inference.out_mp_, "sym_shape_infer_temp.onnx", save_as_external_data=True)
+            raise Exception("Incomplete symbolic shape inference")
+        return symbolic_shape_inference.out_mp_
+
+
+def parse_arguments():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--input", required=True, help="The input model file")
+    parser.add_argument("--output", help="The output model file")
+    parser.add_argument(
+        "--auto_merge",
+        help="Automatically merge symbolic dims when conflicts happen",
+        action="store_true",
+        default=False,
+    )
+    parser.add_argument(
+        "--int_max",
+        help="Maximum integer value to be treated as boundless for ops like Slice",
+        type=int,
+        default=2**31 - 1,
+    )
+    parser.add_argument(
+        "--guess_output_rank",
+        help="Guess the output rank to be the same as input 0 for unknown ops",
+        action="store_true",
+        default=False,
+    )
+    parser.add_argument(
+        "--verbose",
+        help="Verbosity of inference logging: 0 - off, 1 - warnings, 3 - detailed",
+        type=int,
+        default=0,
+    )
+    parser.add_argument(
+        "--save_as_external_data",
+        help="Save the ONNX model with external data",
+        action="store_true",
+        default=False,
+    )
+    parser.add_argument(
+        "--all_tensors_to_one_file",
+        help="Save all external data in one file",
+        action="store_true",
+        default=False,
+    )
+    parser.add_argument(
+        "--external_data_location",
+        help="The file location to save the external data",
+        default="./",
+    )
+    parser.add_argument(
+        "--external_data_size_threshold",
+        help="The size threshold for external data",
+        type=int,
+        default=1024,
+    )
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = parse_arguments()
+    logger.info("input model: " + args.input)
+    if args.output:
+        logger.info("output model: " + args.output)
+    logger.info("Doing symbolic shape inference...")
+    out_mp = SymbolicShapeInference.infer_shapes(
+        onnx.load(args.input),
+        args.int_max,
+        args.auto_merge,
+        args.guess_output_rank,
+        args.verbose,
+    )
+    if args.output and out_mp:
+        if args.save_as_external_data:
+            onnx.save_model(
+                out_mp,
+                args.output,
+                save_as_external_data=True,
+                all_tensors_to_one_file=args.all_tensors_to_one_file,
+                location=args.external_data_location,
+                size_threshold=args.external_data_size_threshold,
+                convert_attribute=False,
+            )
+        else:
+            onnx.save(out_mp, args.output)
+    logger.info("Done!")
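+
+# Usage sketch for the Python API (illustrative only; the file names below are placeholders):
+#
+#   import onnx
+#   from symbolic_shape_infer import SymbolicShapeInference
+#
+#   model = onnx.load("model.onnx")
+#   inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True, verbose=1)
+#   if inferred:
+#       onnx.save(inferred, "model_with_shapes.onnx")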