Kano001's picture
Upload 2707 files
dc2106c verified
raw
history blame
14 kB
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include "onnx/defs/printer.h"
#include <iomanip>
#include <vector>
#include "onnx/defs/tensor_proto_util.h"
namespace ONNX_NAMESPACE {
using StringStringEntryProtos = google::protobuf::RepeatedPtrField<StringStringEntryProto>;
class ProtoPrinter {
public:
ProtoPrinter(std::ostream& os) : output_(os) {}
void print(const TensorShapeProto_Dimension& dim);
void print(const TensorShapeProto& shape);
void print(const TypeProto_Tensor& tensortype);
void print(const TypeProto& type);
void print(const TypeProto_Sequence& seqType);
void print(const TypeProto_Map& mapType);
void print(const TypeProto_Optional& optType);
void print(const TypeProto_SparseTensor& sparseType);
void print(const TensorProto& tensor, bool is_initializer = false);
void print(const ValueInfoProto& value_info);
void print(const ValueInfoList& vilist);
void print(const AttributeProto& attr);
void print(const AttrList& attrlist);
void print(const NodeProto& node);
void print(const NodeList& nodelist);
void print(const GraphProto& graph);
void print(const FunctionProto& fn);
void print(const ModelProto& model);
void print(const OperatorSetIdProto& opset);
void print(const OpsetIdList& opsets);
void print(const StringStringEntryProtos& stringStringProtos) {
printSet("[", ", ", "]", stringStringProtos);
}
void print(const StringStringEntryProto& metadata) {
printQuoted(metadata.key());
output_ << ": ";
printQuoted(metadata.value());
}
private:
template <typename T>
inline void print(T prim) {
output_ << prim;
}
void printQuoted(const std::string& str) {
output_ << "\"";
for (const char* p = str.c_str(); *p; ++p) {
if ((*p == '\\') || (*p == '"'))
output_ << '\\';
output_ << *p;
}
output_ << "\"";
}
template <typename T>
inline void printKeyValuePair(KeyWordMap::KeyWord key, const T& val, bool addsep = true) {
if (addsep)
output_ << "," << std::endl;
output_ << std::setw(indent_level) << ' ' << KeyWordMap::ToString(key) << ": ";
print(val);
}
inline void printKeyValuePair(KeyWordMap::KeyWord key, const std::string& val) {
output_ << "," << std::endl;
output_ << std::setw(indent_level) << ' ' << KeyWordMap::ToString(key) << ": ";
printQuoted(val);
}
template <typename Collection>
inline void printSet(const char* open, const char* separator, const char* close, Collection coll) {
const char* sep = "";
output_ << open;
for (auto& elt : coll) {
output_ << sep;
print(elt);
sep = separator;
}
output_ << close;
}
std::ostream& output_;
int indent_level = 3;
void indent() {
indent_level += 3;
}
void outdent() {
indent_level -= 3;
}
};
void ProtoPrinter::print(const TensorShapeProto_Dimension& dim) {
if (dim.has_dim_value())
output_ << dim.dim_value();
else if (dim.has_dim_param())
output_ << dim.dim_param();
else
output_ << "?";
}
void ProtoPrinter::print(const TensorShapeProto& shape) {
printSet("[", ",", "]", shape.dim());
}
void ProtoPrinter::print(const TypeProto_Tensor& tensortype) {
output_ << PrimitiveTypeNameMap::ToString(tensortype.elem_type());
if (tensortype.has_shape()) {
if (tensortype.shape().dim_size() > 0)
print(tensortype.shape());
} else
output_ << "[]";
}
void ProtoPrinter::print(const TypeProto_Sequence& seqType) {
output_ << "seq(";
print(seqType.elem_type());
output_ << ")";
}
void ProtoPrinter::print(const TypeProto_Map& mapType) {
output_ << "map(" << PrimitiveTypeNameMap::ToString(mapType.key_type()) << ", ";
print(mapType.value_type());
output_ << ")";
}
void ProtoPrinter::print(const TypeProto_Optional& optType) {
output_ << "optional(";
print(optType.elem_type());
output_ << ")";
}
void ProtoPrinter::print(const TypeProto_SparseTensor& sparseType) {
output_ << "sparse_tensor(" << PrimitiveTypeNameMap::ToString(sparseType.elem_type());
if (sparseType.has_shape()) {
if (sparseType.shape().dim_size() > 0)
print(sparseType.shape());
} else
output_ << "[]";
output_ << ")";
}
void ProtoPrinter::print(const TypeProto& type) {
if (type.has_tensor_type())
print(type.tensor_type());
else if (type.has_sequence_type())
print(type.sequence_type());
else if (type.has_map_type())
print(type.map_type());
else if (type.has_optional_type())
print(type.optional_type());
else if (type.has_sparse_tensor_type())
print(type.sparse_tensor_type());
}
void ProtoPrinter::print(const TensorProto& tensor, bool is_initializer) {
output_ << PrimitiveTypeNameMap::ToString(tensor.data_type());
if (tensor.dims_size() > 0)
printSet("[", ",", "]", tensor.dims());
if (!tensor.name().empty()) {
output_ << " " << tensor.name();
}
if (is_initializer) {
output_ << " = ";
}
// TODO: does not yet handle all types
if (tensor.has_data_location() && tensor.data_location() == TensorProto_DataLocation_EXTERNAL) {
print(tensor.external_data());
} else if (tensor.has_raw_data()) {
switch (static_cast<TensorProto::DataType>(tensor.data_type())) {
case TensorProto::DataType::TensorProto_DataType_INT32:
printSet(" {", ",", "}", ParseData<int32_t>(&tensor));
break;
case TensorProto::DataType::TensorProto_DataType_INT64:
printSet(" {", ",", "}", ParseData<int64_t>(&tensor));
break;
case TensorProto::DataType::TensorProto_DataType_FLOAT:
printSet(" {", ",", "}", ParseData<float>(&tensor));
break;
case TensorProto::DataType::TensorProto_DataType_DOUBLE:
printSet(" {", ",", "}", ParseData<double>(&tensor));
break;
default:
output_ << "..."; // ParseData not instantiated for other types.
break;
}
} else {
switch (static_cast<TensorProto::DataType>(tensor.data_type())) {
case TensorProto::DataType::TensorProto_DataType_INT8:
case TensorProto::DataType::TensorProto_DataType_INT16:
case TensorProto::DataType::TensorProto_DataType_INT32:
case TensorProto::DataType::TensorProto_DataType_UINT8:
case TensorProto::DataType::TensorProto_DataType_UINT16:
case TensorProto::DataType::TensorProto_DataType_BOOL:
printSet(" {", ",", "}", tensor.int32_data());
break;
case TensorProto::DataType::TensorProto_DataType_INT64:
printSet(" {", ",", "}", tensor.int64_data());
break;
case TensorProto::DataType::TensorProto_DataType_UINT32:
case TensorProto::DataType::TensorProto_DataType_UINT64:
printSet(" {", ",", "}", tensor.uint64_data());
break;
case TensorProto::DataType::TensorProto_DataType_FLOAT:
printSet(" {", ",", "}", tensor.float_data());
break;
case TensorProto::DataType::TensorProto_DataType_DOUBLE:
printSet(" {", ",", "}", tensor.double_data());
break;
case TensorProto::DataType::TensorProto_DataType_STRING: {
const char* sep = "{";
for (auto& elt : tensor.string_data()) {
output_ << sep;
printQuoted(elt);
sep = ", ";
}
output_ << "}";
break;
}
default:
break;
}
}
}
void ProtoPrinter::print(const ValueInfoProto& value_info) {
print(value_info.type());
output_ << " " << value_info.name();
}
void ProtoPrinter::print(const ValueInfoList& vilist) {
printSet("(", ", ", ")", vilist);
}
void ProtoPrinter::print(const AttributeProto& attr) {
// Special case of attr-ref:
if (attr.has_ref_attr_name()) {
output_ << attr.name() << ": " << AttributeTypeNameMap::ToString(attr.type()) << " = @" << attr.ref_attr_name();
return;
}
// General case:
output_ << attr.name() << ": " << AttributeTypeNameMap::ToString(attr.type()) << " = ";
switch (attr.type()) {
case AttributeProto_AttributeType_INT:
output_ << attr.i();
break;
case AttributeProto_AttributeType_INTS:
printSet("[", ", ", "]", attr.ints());
break;
case AttributeProto_AttributeType_FLOAT:
output_ << attr.f();
break;
case AttributeProto_AttributeType_FLOATS:
printSet("[", ", ", "]", attr.floats());
break;
case AttributeProto_AttributeType_STRING:
output_ << "\"" << attr.s() << "\"";
break;
case AttributeProto_AttributeType_STRINGS: {
const char* sep = "[";
for (auto& elt : attr.strings()) {
output_ << sep << "\"" << elt << "\"";
sep = ", ";
}
output_ << "]";
break;
}
case AttributeProto_AttributeType_GRAPH:
indent();
print(attr.g());
outdent();
break;
case AttributeProto_AttributeType_GRAPHS:
indent();
printSet("[", ", ", "]", attr.graphs());
outdent();
break;
case AttributeProto_AttributeType_TENSOR:
print(attr.t());
break;
case AttributeProto_AttributeType_TENSORS:
printSet("[", ", ", "]", attr.tensors());
break;
case AttributeProto_AttributeType_TYPE_PROTO:
print(attr.tp());
break;
case AttributeProto_AttributeType_TYPE_PROTOS:
printSet("[", ", ", "]", attr.type_protos());
break;
default:
break;
}
}
void ProtoPrinter::print(const AttrList& attrlist) {
printSet(" <", ", ", ">", attrlist);
}
void ProtoPrinter::print(const NodeProto& node) {
output_ << std::setw(indent_level) << ' ';
printSet("", ", ", "", node.output());
output_ << " = ";
if (node.domain() != "")
output_ << node.domain() << ".";
output_ << node.op_type();
if (node.overload() != "")
output_ << ":" << node.overload();
bool has_subgraph = false;
for (auto attr : node.attribute())
if (attr.has_g() || (attr.graphs_size() > 0))
has_subgraph = true;
if ((!has_subgraph) && (node.attribute_size() > 0))
print(node.attribute());
printSet(" (", ", ", ")", node.input());
if ((has_subgraph) && (node.attribute_size() > 0))
print(node.attribute());
output_ << "\n";
}
void ProtoPrinter::print(const NodeList& nodelist) {
output_ << "{\n";
for (auto& node : nodelist) {
print(node);
}
if (indent_level > 3)
output_ << std::setw(indent_level - 3) << " ";
output_ << "}";
}
void ProtoPrinter::print(const GraphProto& graph) {
output_ << graph.name() << " " << graph.input() << " => " << graph.output() << " ";
if ((graph.initializer_size() > 0) || (graph.value_info_size() > 0)) {
output_ << std::endl << std::setw(indent_level) << ' ' << '<';
const char* sep = "";
for (auto& init : graph.initializer()) {
output_ << sep;
print(init, true);
sep = ", ";
}
for (auto& vi : graph.value_info()) {
output_ << sep;
print(vi);
sep = ", ";
}
output_ << ">" << std::endl;
}
print(graph.node());
}
void ProtoPrinter::print(const ModelProto& model) {
output_ << "<\n";
printKeyValuePair(KeyWordMap::KeyWord::IR_VERSION, model.ir_version(), false);
printKeyValuePair(KeyWordMap::KeyWord::OPSET_IMPORT, model.opset_import());
if (model.has_producer_name())
printKeyValuePair(KeyWordMap::KeyWord::PRODUCER_NAME, model.producer_name());
if (model.has_producer_version())
printKeyValuePair(KeyWordMap::KeyWord::PRODUCER_VERSION, model.producer_version());
if (model.has_domain())
printKeyValuePair(KeyWordMap::KeyWord::DOMAIN_KW, model.domain());
if (model.has_model_version())
printKeyValuePair(KeyWordMap::KeyWord::MODEL_VERSION, model.model_version());
if (model.has_doc_string())
printKeyValuePair(KeyWordMap::KeyWord::DOC_STRING, model.doc_string());
if (model.metadata_props_size() > 0)
printKeyValuePair(KeyWordMap::KeyWord::METADATA_PROPS, model.metadata_props());
output_ << std::endl << ">" << std::endl;
print(model.graph());
for (const auto& fn : model.functions()) {
output_ << std::endl;
print(fn);
}
}
void ProtoPrinter::print(const OperatorSetIdProto& opset) {
output_ << "\"" << opset.domain() << "\" : " << opset.version();
}
void ProtoPrinter::print(const OpsetIdList& opsets) {
printSet("[", ", ", "]", opsets);
}
void ProtoPrinter::print(const FunctionProto& fn) {
output_ << "<\n";
output_ << " "
<< "domain: \"" << fn.domain() << "\",\n";
if (!fn.overload().empty())
output_ << " "
<< "overload: \"" << fn.overload() << "\",\n";
output_ << " "
<< "opset_import: ";
printSet("[", ",", "]", fn.opset_import());
output_ << "\n>\n";
output_ << fn.name() << " ";
if (fn.attribute_size() > 0)
printSet("<", ",", ">", fn.attribute());
printSet("(", ", ", ")", fn.input());
output_ << " => ";
printSet("(", ", ", ")", fn.output());
output_ << "\n";
print(fn.node());
}
#define DEF_OP(T) \
std::ostream& operator<<(std::ostream& os, const T& proto) { \
ProtoPrinter printer(os); \
printer.print(proto); \
return os; \
};
DEF_OP(TensorShapeProto_Dimension)
DEF_OP(TensorShapeProto)
DEF_OP(TypeProto_Tensor)
DEF_OP(TypeProto)
DEF_OP(TensorProto)
DEF_OP(ValueInfoProto)
DEF_OP(ValueInfoList)
DEF_OP(AttributeProto)
DEF_OP(AttrList)
DEF_OP(NodeProto)
DEF_OP(NodeList)
DEF_OP(GraphProto)
DEF_OP(FunctionProto)
DEF_OP(ModelProto)
} // namespace ONNX_NAMESPACE