/*
 * SPDX-License-Identifier: Apache-2.0
 */

namespace ONNX_NAMESPACE {
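// ReduceMax and ReduceMin (opset 20) are generated from the shared ReduceOpGenerator helper.
// The EMPTY_MIN / EMPTY_MAX arguments document what reducing an empty set of values yields:
// the identity of the reduction (minus infinity / the minimum value of the type for ReduceMax,
// plus infinity / the maximum value for ReduceMin).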
ONNX_OPERATOR_SET_SCHEMA(
    ReduceMax,
    20,
    OpSchema().FillUsing(ReduceOpGenerator("max", EMPTY_MIN, true, true, nullptr, nullptr, true)));

ONNX_OPERATOR_SET_SCHEMA(
    ReduceMin,
    20,
    OpSchema().FillUsing(ReduceOpGenerator("min", EMPTY_MAX, true, true, nullptr, nullptr, true)));
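// ReduceSum (opset 13) takes its reduction axes as an optional second input rather than as an
// attribute, which is what the ReduceOpDynamicAxes generator captures; EMPTY_ZERO documents
// that reducing an empty set of values yields 0.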
ONNX_OPERATOR_SET_SCHEMA(ReduceSum, 13, OpSchema().FillUsing(ReduceOpDynamicAxes("sum", EMPTY_ZERO)));
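// ReduceSumSquare is defined as an ONNX function: square the input elementwise, then sum along
// the requested axes. For example, with data = [[1, 2], [3, 4]], axes = [1] and keepdims = 1,
// the squared values are [[1, 4], [9, 16]] and the result is [[5], [25]].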
const char* reduce_sum_square_func_body = R"ONNX(
{
  data_square = Mul(data, data)
  reduced = ReduceSum<keepdims: int = @keepdims>(data_square, axes)
}
)ONNX";

ONNX_OPERATOR_SET_SCHEMA(
    ReduceSumSquare,
    18,
    OpSchema().FillUsing(ReduceFunctionOp("sum square", EMPTY_ZERO, reduce_sum_square_func_body)));
ONNX_OPERATOR_SET_SCHEMA(ReduceMean, 18, OpSchema().FillUsing(ReduceOpDynamicAxes("mean", EMPTY_UNDEFINED)));

ONNX_OPERATOR_SET_SCHEMA(ReduceProd, 18, OpSchema().FillUsing(ReduceOpDynamicAxes("product", EMPTY_ONE)));
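// ReduceLogSum computes log(sum(data)) over the selected axes. The empty-set result is
// documented as minus infinity, matching log(0).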
const char* reduce_log_sum_func_body = R"ONNX(
{
  reduced_sum = ReduceSum<keepdims: int = @keepdims>(data, axes)
  reduced = Log (reduced_sum)
}
)ONNX";

ONNX_OPERATOR_SET_SCHEMA(
    ReduceLogSum,
    18,
    OpSchema().FillUsing(ReduceFunctionOp("log sum", EMPTY_MINUS_INF, reduce_log_sum_func_body)));
const char* reduce_log_sum_exp_func_body = R"ONNX(
{
  data_double = Cast<to = 11>(data)
  data_exp = Exp (data_double)
  reduced_sum = ReduceSum<keepdims: int = @keepdims>(data_exp, axes)
  reduced_double = Log (reduced_sum)
  reduced = CastLike(reduced_double, data)
}
)ONNX";

ONNX_OPERATOR_SET_SCHEMA(
    ReduceLogSumExp,
    18,
    OpSchema().FillUsing(ReduceFunctionOp("log sum exponent", EMPTY_MINUS_INF, reduce_log_sum_exp_func_body)));
const char* reduce_l1_func_body = R"ONNX(
{
  data_abs = Abs(data)
  reduced = ReduceSum<keepdims: int = @keepdims>(data_abs, axes)
}
)ONNX";

ONNX_OPERATOR_SET_SCHEMA(
    ReduceL1,
    18,
    OpSchema().FillUsing(ReduceFunctionOp("L1 norm", EMPTY_ZERO, reduce_l1_func_body)));
const char* reduce_l2_func_body = R"ONNX(
{
  data_square = Mul(data, data)
  sum_square = ReduceSum<keepdims: int = @keepdims>(data_square, axes)
  sum_square_dbl = Cast <to = 1>(sum_square)
  sqrt = Sqrt(sum_square_dbl)
  reduced = CastLike(sqrt, data)
}
)ONNX";

ONNX_OPERATOR_SET_SCHEMA(
    ReduceL2,
    18,
    OpSchema().FillUsing(ReduceFunctionOp("L2 norm", EMPTY_ZERO, reduce_l2_func_body)));
std::function<void(OpSchema&)> ArgReduceDocGenerator(const char* name) {
  return [=](OpSchema& schema) {
    std::string doc;
    POPULATE_OP_DOC_STR(doc = R"DOC(
Computes the indices of the {name} elements of the input tensor along the
provided axis. The resulting tensor has the same rank as the input if keepdims equals 1.
If keepdims equals 0, then the resulting tensor has the reduced dimension pruned.
If select_last_index is True (default False), the index of the last occurrence of the {name}
is selected if the {name} appears more than once in the input. Otherwise the index of the
first occurrence is selected.
The type of the output tensor is integer.)DOC";
                        ReplaceAll(doc, "{name}", name););
    schema.SetDoc(doc.c_str());
    schema.Attr(
        "axis",
        "The axis in which to compute the arg indices. Accepted range is [-r, r-1] where r = rank(data).",
        AttributeProto::INT,
        static_cast<int64_t>(0));
    schema.Attr(
        "keepdims",
        "Keep the reduced dimension or not, default 1 means keep reduced dimension.",
        AttributeProto::INT,
        static_cast<int64_t>(1));
    schema.Attr(
        "select_last_index",
        "Whether to select the last index or the first index if the {name} appears in multiple indices, default is False (first index).",
        AttributeProto::INT,
        static_cast<int64_t>(0));
    schema.Input(0, "data", "An input tensor.", "T", OpSchema::Single, true, 1, OpSchema::NonDifferentiable);
    schema.Output(
        0,
        "reduced",
        "Reduced output tensor with integer data type.",
        "tensor(int64)",
        OpSchema::Single,
        true,
        1,
        OpSchema::NonDifferentiable);
    schema.TypeConstraint(
        "T", OpSchema::all_numeric_types_ir4(), "Constrain input and output types to all numeric tensors.");
    schema.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
      // set output element type to int64
      updateOutputElemType(ctx, 0, TensorProto_DataType_INT64);
      if (!hasNInputShapes(ctx, 1)) {
        return;
      }
      auto& input_shape = ctx.getInputType(0)->tensor_type().shape();
      auto output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
      int64_t input_ndim = input_shape.dim_size();
      int64_t axis = 0; // default to 0
      auto axis_proto = ctx.getAttribute("axis");
      if (axis_proto) {
        axis = axis_proto->i();
        if (axis < -input_ndim || axis >= input_ndim) {
          fail_shape_inference("'axis' must be in [-rank(indices), rank(indices)-1]");
        }
        if (axis < 0)
          axis += input_ndim;
      }
      int64_t keep_dims = 1;
      auto attr_proto = ctx.getAttribute("keepdims");
      if (attr_proto) {
        keep_dims = attr_proto->i();
      }
      // Negative axis values were already normalized to [0, input_ndim) above.
      for (int i = 0; i < input_ndim; ++i) {
        if (i != axis) {
          auto dim = output_shape->add_dim();
          dim->CopyFrom(input_shape.dim(i));
        } else {
          if (keep_dims == 1) {
            auto dim = output_shape->add_dim();
            dim->set_dim_value(1);
          }
        }
      }
    });
  };
}
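// Example (assuming default keepdims = 1, select_last_index = 0): ArgMax over
// data = [[2, 1], [3, 10]] with axis = 1 yields [[0], [1]], and ArgMin yields [[1], [0]].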
ONNX_OPERATOR_SET_SCHEMA(ArgMax, 13, OpSchema().FillUsing(ArgReduceDocGenerator("max")));

ONNX_OPERATOR_SET_SCHEMA(ArgMin, 13, OpSchema().FillUsing(ArgReduceDocGenerator("min")));

} // namespace ONNX_NAMESPACE