Kano001's picture
Upload 2707 files
dc2106c verified
raw
history blame
28.1 kB
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Experimental language syntax and parser for ONNX. Please note that the syntax as formalized
// by this parser is preliminary and may change.
#include "onnx/defs/parser.h"
#include <cctype>
#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>
#include "onnx/common/common.h"
#include "onnx/onnx_pb.h"
#include "onnx/string_utils.h"
#define PARSE_TOKEN(x) CHECK_PARSER_STATUS(ParserBase::Parse(x))
#define PARSE(...) CHECK_PARSER_STATUS(Parse(__VA_ARGS__))
#define MATCH(...) CHECK_PARSER_STATUS(Match(__VA_ARGS__))
namespace ONNX_NAMESPACE {
// Parses one literal token into `result`, setting both `result.type` and `result.value`.
//
// Accepted forms, tried in this order:
//   1. String literal: text enclosed in double quotes. A backslash causes the character
//      that follows it to be copied verbatim; the enclosing quotes and the backslashes
//      themselves are not part of `result.value`.
//   2. Special float literal: an optional '-' followed by a spelling accepted by
//      NextIsValidFloatString(). `result.value` holds the source text (sign included).
//   3. Numeric literal: an optional '-', digits with at most one '.', and an optional
//      exponent part (e|E)(+|-)?[0-9]*. A '.' or an exponent makes it a FLOAT_LITERAL,
//      otherwise it is an INT_LITERAL. `result.value` holds the source text.
//
// Returns a parse error if a string literal is unterminated, if an alphabetic token is
// not a valid special float, or if no literal is present at all (so callers never see an
// OK status with an unpopulated `result`).
Status ParserBase::Parse(Literal& result) {
  auto nextch = NextChar();
  auto from = next_;

  if (nextch == '"') {
    ++next_;
    bool has_escape = false;
    while ((next_ < end_) && (*next_ != '"')) {
      if (*next_ == '\\') {
        has_escape = true;
        ++next_;
        if (next_ >= end_)
          return ParseError("Incomplete string literal.");
      }
      ++next_;
    }
    if (next_ >= end_)
      return ParseError("Incomplete string literal.");
    ++next_;
    result.type = LiteralType::STRING_LITERAL;
    if (has_escape) {
      std::string& target = result.value;
      target.clear();
      target.reserve(next_ - from - 2); // upper bound
      // *from is the starting quote. *(next_-1) is the ending quote.
      // Copy what is in-between, except for the escape character
      while (++from < next_ - 1) {
        // Copy current char, if not escape, or next char otherwise.
        target.push_back(*from != '\\' ? (*from) : *(++from));
      }
    } else {
      result.value = std::string(from + 1, next_ - from - 2); // skip enclosing quotes
    }
    return Status::OK();
  }

  // Simplify the remaining checks by consuming a possible negative sign.
  // `from` still points at the sign, so it stays part of the literal text.
  if (nextch == '-') {
    ++next_;
    nextch = NextChar();
  }

  // The <cctype> classifiers require a value representable as unsigned char.
  // Float literals that start with alphabet characters: (-)?(nan|inf|infinity).
  if (isalpha(static_cast<unsigned char>(nextch))) {
    if (!NextIsValidFloatString())
      return ParseError("Encountered invalid float literal!");
    while (next_ < end_ && isalpha(static_cast<unsigned char>(*next_))) {
      ++next_;
    }
    const std::string text(from, next_ - from);
    // Record a conversion failure in a flag: a `return` written inside the handler
    // lambda only leaves the lambda, so it cannot be relied on to leave this function.
    bool conversion_failed = false;
    ONNX_TRY {
      static_cast<void>(std::stof(text));
    }
    ONNX_CATCH(...) {
      ONNX_HANDLE_EXCEPTION([&]() { conversion_failed = true; });
    }
    if (conversion_failed)
      return ParseError("Encountered invalid float literal!");
    result.type = LiteralType::FLOAT_LITERAL;
    result.value = text;
    return Status::OK();
  }

  // Anything that is not a digit at this point is not a literal.
  if (!isdigit(static_cast<unsigned char>(nextch)))
    return ParseError("Value expected but not found.");

  // Numeric int or float literal.
  bool decimal_point = false;
  ++next_;
  while ((next_ < end_) && (isdigit(static_cast<unsigned char>(*next_)) || (*next_ == '.'))) {
    if (*next_ == '.') {
      if (decimal_point)
        break; // Only one decimal point allowed in numeric literal
      decimal_point = true;
    }
    ++next_;
  }
  // Optional exponent syntax: (e|E)(+|-)?[0-9]*
  if ((next_ < end_) && ((*next_ == 'e') || (*next_ == 'E'))) {
    decimal_point = true; // treat as float-literal
    ++next_;
    if ((next_ < end_) && ((*next_ == '+') || (*next_ == '-')))
      ++next_;
    while ((next_ < end_) && isdigit(static_cast<unsigned char>(*next_)))
      ++next_;
  }
  result.value = std::string(from, next_ - from);
  result.type = decimal_point ? LiteralType::FLOAT_LITERAL : LiteralType::INT_LITERAL;
  return Status::OK();
}
// Reports whether the input at the position returned by NextChar() spells one of the
// special float names "inf", "infinity" or "nan" (compared case-insensitively) and is not
// immediately followed by a digit.
//
// This is a look-ahead only: `next_` is restored to where scanning started before
// returning, whatever the outcome.
bool ParserBase::NextIsValidFloatString() {
  auto nextch = NextChar();
  auto from = next_;
  constexpr int INFINITY_LENGTH = 8;

  // The <cctype> classifiers require a value representable as unsigned char.
  if (!isalpha(static_cast<unsigned char>(nextch)))
    return false;

  // Scan the run of letters, but no further than is needed to recognise "infinity".
  while (next_ < end_ && isalpha(static_cast<unsigned char>(*next_)) && (next_ - from) <= INFINITY_LENGTH) {
    ++next_;
  }
  // Only look at the following character if there is one; the scan may have stopped
  // exactly at end_.
  const bool trailing_digit = (next_ < end_) && isdigit(static_cast<unsigned char>(*next_));
  std::string candidate(from, next_ - from);
  // Reset parser location before continuing.
  next_ = from;
  if (trailing_digit)
    return false; // No trailing digits

  for (char& c : candidate) {
    c = static_cast<char>(std::tolower(static_cast<unsigned char>(c)));
  }
  return candidate == "inf" || candidate == "infinity" || candidate == "nan";
}
// Parses a comma-separated list of identifiers into `idlist` (which is cleared first).
//
// If no identifier is found at the start, the list is left empty and OK is returned.
// After a comma, whatever ParseOptionalIdentifier produces is appended, so an entry
// following a comma may be an empty string.
// Any error reported by ParseOptionalIdentifier is propagated to the caller.
Status OnnxParser::Parse(IdList& idlist) {
  idlist.Clear();
  std::string id;
  CHECK_PARSER_STATUS(ParseOptionalIdentifier(id));
  if (id.empty())
    return Status::OK(); // Treat as empty list of identifiers
  *idlist.Add() = id;
  while (Matches(',')) {
    CHECK_PARSER_STATUS(ParseOptionalIdentifier(id));
    *idlist.Add() = id;
  }
  return Status::OK();
}
// Parses an optional, delimited identifier list: `open` id-list `close`.
// `idlist` is always cleared. When Matches(open) fails, the list stays empty and OK is
// returned; otherwise the identifiers are parsed and `close` is required.
Status OnnxParser::Parse(char open, IdList& idlist, char close) {
  idlist.Clear();
  if (!Matches(open))
    return Status::OK();

  PARSE(idlist);
  MATCH(close);
  return Status::OK();
}
// Parses a comma-separated list in which each element is either a bare identifier or an
// attribute definition (an identifier followed by ':' or '=').
//
// Bare identifiers are appended to `idlist`; attribute definitions are parsed into a new
// entry of `attrlist`. Both lists are cleared first. A missing identifier or a malformed
// attribute definition is reported to the caller instead of being silently ignored.
Status OnnxParser::Parse(IdList& idlist, AttrList& attrlist) {
  idlist.Clear();
  attrlist.Clear();
  do {
    std::string id;
    CHECK_PARSER_STATUS(ParseIdentifier(id));
    auto next = NextChar();
    if (next == ':' || next == '=') {
      // The identifier is the attribute's name; parse the rest of its definition.
      CHECK_PARSER_STATUS(Parse(*attrlist.Add(), id));
    } else {
      *idlist.Add() = id;
    }
  } while (Matches(','));
  return Status::OK();
}
// Parses an optional, delimited mixed list: `open` (identifier | attribute)... `close`.
// When Matches(open) fails, both output lists are cleared and OK is returned.
Status OnnxParser::Parse(char open, IdList& idlist, AttrList& attrlist, char close) {
  if (!Matches(open)) {
    // No list present: report empty results.
    idlist.Clear();
    attrlist.Clear();
    return Status::OK();
  }

  PARSE(idlist, attrlist);
  MATCH(close);
  return Status::OK();
}
// Parses a comma-separated list of dimensions into `shape` (existing dims are removed).
// Each dimension is one of:
//   '?'         -> a dimension with neither value nor parameter set
//   identifier  -> a symbolic dimension (dim_param)
//   integer     -> a fixed dimension (dim_value)
Status OnnxParser::Parse(TensorShapeProto& shape) {
  shape.clear_dim();
  do {
    if (Matches('?')) {
      shape.add_dim();
      continue; // proceeds to the loop condition (next comma check)
    }

    // Try a symbolic name first ...
    std::string symbol;
    CHECK_PARSER_STATUS(ParseOptionalIdentifier(symbol));
    if (symbol.empty()) {
      // ... and fall back to an integer value.
      int64_t extent = 0;
      PARSE_TOKEN(extent);
      shape.add_dim()->set_dim_value(extent);
    } else {
      shape.add_dim()->set_dim_param(symbol);
    }
  } while (Matches(','));
  return Status::OK();
}
// Parses a type expression into `typeProto`.
//
// Grammar handled here:
//   prim-type                      tensor, scalar (shape with zero dimensions)
//   prim-type []                   tensor of unknown rank (no shape set)
//   prim-type [dims]               tensor with the given dimensions
//   seq ( type )                   sequence
//   map ( prim-type , type )       map
//   optional ( type )              optional
//   sparse_tensor ( tensor-type )  sparse tensor, same shape rules as a tensor
Status OnnxParser::Parse(TypeProto& typeProto) {
  std::string id;
  CHECK_PARSER_STATUS(ParseIdentifier(id));
  int dtype = PrimitiveTypeNameMap::Lookup(id);

  // Primitive element type: this is a (dense) tensor type.
  if (dtype != 0) {
    auto* tensortype = typeProto.mutable_tensor_type();
    tensortype->set_elem_type(dtype);
    tensortype->clear_shape();
    if (!Matches('[')) {
      // No brackets: scalar, represented by a shape with zero dimensions.
      (void)(tensortype->mutable_shape());
    } else if (!Matches(']')) {
      // Non-empty brackets: explicit dimensions. ("[]" leaves the shape unset.)
      PARSE(*tensortype->mutable_shape());
      MATCH(']');
    }
    return Status::OK();
  }

  // Otherwise the identifier must be one of the composite-type keywords.
  switch (KeyWordMap::Lookup(id)) {
    case KeyWordMap::KeyWord::SEQ_TYPE: {
      // Grammar: seq ( type )
      MATCH('(');
      auto* seqtype = typeProto.mutable_sequence_type();
      PARSE(*seqtype->mutable_elem_type());
      MATCH(')');
      break;
    }
    case KeyWordMap::KeyWord::MAP_TYPE: {
      // Grammar: map ( prim-type , type )
      MATCH('(');
      auto* maptype = typeProto.mutable_map_type();
      CHECK_PARSER_STATUS(ParseIdentifier(id));
      dtype = PrimitiveTypeNameMap::Lookup(id);
      if (dtype == 0)
        return ParseError("Expecting primitive type as map key type.");
      maptype->set_key_type(dtype);
      MATCH(',');
      PARSE(*maptype->mutable_value_type());
      MATCH(')');
      break;
    }
    case KeyWordMap::KeyWord::OPTIONAL_TYPE: {
      // Grammar: optional ( type )
      MATCH('(');
      auto* opttype = typeProto.mutable_optional_type();
      PARSE(*opttype->mutable_elem_type());
      MATCH(')');
      break;
    }
    case KeyWordMap::KeyWord::SPARSE_TENSOR_TYPE: {
      // Grammar: sparse_tensor ( tensor-type )
      MATCH('(');
      CHECK_PARSER_STATUS(ParseIdentifier(id));
      dtype = PrimitiveTypeNameMap::Lookup(id);
      if (dtype == 0)
        return ParseError("Unexpected type in sparse-tensor element type.");
      auto* sparsetype = typeProto.mutable_sparse_tensor_type();
      sparsetype->set_elem_type(dtype);
      sparsetype->clear_shape();
      if (!Matches('[')) {
        // No brackets: scalar, represented by a shape with zero dimensions.
        (void)(sparsetype->mutable_shape());
      } else if (!Matches(']')) {
        // Non-empty brackets: explicit dimensions. ("[]" leaves the shape unset.)
        PARSE(*sparsetype->mutable_shape());
        MATCH(']');
      }
      MATCH(')');
      break;
    }
    default:
      return ParseError("Unexpected type.");
  }
  return Status::OK();
}
// Parses a value-info of the form `[type] name`: the type is parsed only when
// NextIsType() reports one, the name is mandatory.
Status OnnxParser::Parse(ValueInfoProto& valueinfo) {
  if (NextIsType()) {
    PARSE(*valueinfo.mutable_type());
  }

  std::string parsed_name;
  CHECK_PARSER_STATUS(ParseIdentifier(parsed_name));
  valueinfo.set_name(parsed_name);
  return Status::OK();
}
// Parses a mandatory, delimited list of value-infos: `open` [vi {, vi}] `close`.
// Parsed entries are appended to `vilist`; the list is not cleared here.
Status OnnxParser::Parse(char open, ValueInfoList& vilist, char close) {
  MATCH(open);
  if (Matches(close))
    return Status::OK(); // empty list

  do {
    PARSE(*vilist.Add());
  } while (Matches(','));
  MATCH(close);
  return Status::OK();
}
// Parses a parenthesized value-info list, replacing any previous content of `vilist`.
Status OnnxParser::ParseGraphInputOutput(ValueInfoList& vilist) {
  vilist.Clear();
  return Parse('(', vilist, ')');
}
// Parses a parenthesized list of function parameters, each written `Name` or `Type Name`.
// Every name is appended to `idlist` (cleared first). For typed parameters a value-info
// carrying the type and name is additionally appended to `vilist`, which is deliberately
// not cleared because it accumulates entries across inputs and outputs.
Status OnnxParser::ParseFunctionInputOutput(IdList& idlist, ValueInfoList& vilist) {
  idlist.Clear();
  MATCH('(');
  if (Matches(')'))
    return Status::OK(); // no parameters

  do {
    std::string* name = idlist.Add();
    if (NextIsType()) {
      // Typed parameter: record type and name in a new value-info as well.
      ValueInfoProto* typed = vilist.Add();
      PARSE(*(typed->mutable_type()));
      CHECK_PARSER_STATUS(ParseIdentifier(*name));
      typed->set_name(*name);
    } else {
      // Untyped parameter: only the name is recorded.
      CHECK_PARSER_STATUS(ParseIdentifier(*name));
    }
  } while (Matches(','));
  MATCH(')');
  return Status::OK();
}
// Parses an optional parenthesized list of graph inputs.
// Every element is a value-info, optionally followed by "= initial-value". The value-info
// always goes to `inputs` (cleared first); when an initial value is present, a tensor
// named after the input is also appended to `initializers`.
Status OnnxParser::ParseInput(ValueInfoList& inputs, TensorList& initializers) {
  inputs.Clear();
  if (!Matches('(') || Matches(')'))
    return Status::OK(); // list absent or empty

  do {
    ValueInfoProto vi;
    PARSE(vi);
    *inputs.Add() = vi;
    if (!Matches('='))
      continue; // no default value for this input

    TensorProto& tp = *initializers.Add();
    tp.set_name(vi.name());
    CHECK_PARSER_STATUS(Parse(tp, vi.type()));
  } while (Matches(','));
  MATCH(')');
  return Status::OK();
}
// Parses an optional '<' ... '>' list whose elements are handled differently from inputs:
// an element followed by "= value" is an initializer and goes ONLY to `initializers`;
// any other element is a plain value-info and goes to `value_infos` (cleared first).
Status OnnxParser::ParseValueInfo(ValueInfoList& value_infos, TensorList& initializers) {
  value_infos.Clear();
  if (!Matches('<') || Matches('>'))
    return Status::OK(); // list absent or empty

  do {
    ValueInfoProto vi;
    PARSE(vi);
    if (!Matches('=')) {
      // valueinfo
      *value_infos.Add() = vi;
      continue;
    }
    // initializer
    TensorProto& tp = *initializers.Add();
    tp.set_name(vi.name());
    CHECK_PARSER_STATUS(Parse(tp, vi.type()));
  } while (Matches(','));
  MATCH('>');
  return Status::OK();
}
// Parses one or more comma-separated `key : value` pairs, where both key and value are
// parsed as string tokens, and appends one entry per pair to `stringStringList`.
Status OnnxParser::Parse(StringStringList& stringStringList) {
  std::string text;
  do {
    auto* entry = stringStringList.Add();

    PARSE_TOKEN(text);
    entry->set_key(text);

    MATCH(':');

    PARSE_TOKEN(text);
    entry->set_value(text);
  } while (Matches(','));
  return Status::OK();
}
// Parses a complete tensor: `type [name] [=] data`.
// `tensorProto` is reset first. The type must be a concrete tensor type with numeric
// dimensions (enforced by the two-argument overload that parses the data).
// Errors from parsing the optional name are propagated rather than dropped.
Status OnnxParser::Parse(TensorProto& tensorProto) {
  tensorProto = TensorProto();
  // Parse the concrete tensor-type with numeric dimensions:
  TypeProto typeProto;
  PARSE(typeProto);
  CHECK_PARSER_STATUS(ParseOptionalIdentifier(*tensorProto.mutable_name()));
  (void)Matches('='); // Optional, to unify handling of initializers as well as tensor-protos in other contexts
  return Parse(tensorProto, typeProto);
}
// Parses the data of a tensor whose type is already known.
//
// `tensorTypeProto` must be a tensor type that has a shape in which every dimension is
// numeric; element type and dimensions are copied into `tensorProto`. The data is then
// either
//   '{' value {, value} '}'   inline values, stored in the field matching the element type
//   '[' key:value, ... ']'    external data; data_location is set to EXTERNAL
// or absent, in which case only type and dims are set.
// Returns a parse error for a non-tensor type, a missing shape, a non-numeric dimension,
// or an element type with no inline-value handling.
Status OnnxParser::Parse(TensorProto& tensorProto, const TypeProto& tensorTypeProto) {
  if (!tensorTypeProto.has_tensor_type())
    return ParseError("Error parsing TensorProto (expected a tensor type).");
  auto elem_type = tensorTypeProto.tensor_type().elem_type();
  tensorProto.set_data_type(elem_type);
  if (!tensorTypeProto.tensor_type().has_shape())
    return ParseError("Error parsing TensorProto (expected a tensor shape).");
  for (auto& dim : tensorTypeProto.tensor_type().shape().dim()) {
    if (!dim.has_dim_value())
      return ParseError("Error parsing TensorProto shape (expected numeric dimension).");
    auto dimval = dim.dim_value();
    tensorProto.add_dims(dimval);
  }

  // Parse the actual values:
  int64_t intval = 0;
  uint64_t uintval = 0;
  float floatval = 0.0;
  double dblval = 0.0;
  std::string strval;
  if (Matches('{')) {
    if (!Matches('}')) {
      do {
        switch (static_cast<TensorProto::DataType>(elem_type)) {
          case TensorProto::DataType::TensorProto_DataType_INT8:
          case TensorProto::DataType::TensorProto_DataType_INT16:
          case TensorProto::DataType::TensorProto_DataType_INT32:
          case TensorProto::DataType::TensorProto_DataType_UINT8:
          case TensorProto::DataType::TensorProto_DataType_UINT16:
          case TensorProto::DataType::TensorProto_DataType_BOOL:
            // These narrow types all share the int32_data field.
            PARSE_TOKEN(intval);
            // TODO: check values are in the correct range.
            tensorProto.add_int32_data(intval);
            break;
          case TensorProto::DataType::TensorProto_DataType_INT64:
            PARSE_TOKEN(intval);
            tensorProto.add_int64_data(intval);
            break;
          case TensorProto::DataType::TensorProto_DataType_UINT32:
          case TensorProto::DataType::TensorProto_DataType_UINT64:
            PARSE_TOKEN(uintval);
            tensorProto.add_uint64_data(uintval);
            break;
          case TensorProto::DataType::TensorProto_DataType_FLOAT:
            PARSE_TOKEN(floatval);
            tensorProto.add_float_data(floatval);
            break;
          case TensorProto::DataType::TensorProto_DataType_DOUBLE:
            PARSE_TOKEN(dblval);
            tensorProto.add_double_data(dblval);
            break;
          case TensorProto::DataType::TensorProto_DataType_STRING:
            PARSE_TOKEN(strval);
            tensorProto.add_string_data(strval);
            break;
          default:
            // ParseError is used elsewhere in this file as a concatenating helper
            // (several message pieces passed as separate arguments), not a printf-style
            // formatter, so the value is passed as its own argument.
            return ParseError("Unhandled type: ", elem_type);
        }
      } while (Matches(','));
      MATCH('}');
    }
  } else if (Matches('[')) {
    tensorProto.set_data_location(TensorProto::DataLocation::TensorProto_DataLocation_EXTERNAL);
    auto& externalData = *tensorProto.mutable_external_data();
    PARSE(externalData);
    MATCH(']');
  }
  return Status::OK();
}
// Look-ahead: true when PeekIdentifier yields a non-empty identifier.
bool OnnxParser::NextIsIdentifier() {
  std::string peeked;
  (void)PeekIdentifier(peeked);
  return !peeked.empty();
}
// Look-ahead: true when the peeked identifier starts a type expression, i.e. it is a
// primitive type name or one of the keywords seq / map / optional / sparse_tensor.
bool OnnxParser::NextIsType() {
  std::string peeked;
  (void)PeekIdentifier(peeked);
  if (PrimitiveTypeNameMap::IsTypeName(peeked))
    return true;

  const auto keyword = KeyWordMap::Lookup(peeked);
  return keyword == KeyWordMap::KeyWord::SEQ_TYPE || keyword == KeyWordMap::KeyWord::MAP_TYPE ||
      keyword == KeyWordMap::KeyWord::OPTIONAL_TYPE || keyword == KeyWordMap::KeyWord::SPARSE_TENSOR_TYPE;
}
// Parses a single (non-list) attribute value into `attr` and sets `attr`'s type from the
// kind of value found:
//   type followed by '{', '=' or an identifier -> TENSOR
//   type alone                                 -> TYPE_PROTO
//   inf / infinity / nan                       -> FLOAT
//   any other identifier                       -> GRAPH
//   '@' name                                   -> reference to another attribute
//                                                 (only ref_attr_name is set)
//   int / float / string literal               -> INT / FLOAT / STRING
//
// `expected` is the annotated type, or UNDEFINED when there is none. If it differs from
// the parsed type, the only implicit conversion is INT -> FLOAT; anything else is a
// parse error. Errors from nested parsers are propagated to the caller.
Status OnnxParser::ParseSingleAttributeValue(AttributeProto& attr, AttributeProto_AttributeType expected) {
  // Parse a single-value
  auto next = NextChar();
  // The <cctype> classifiers require a value representable as unsigned char.
  if (isalpha(static_cast<unsigned char>(next)) || next == '_') {
    if (NextIsType()) {
      TypeProto typeProto;
      CHECK_PARSER_STATUS(Parse(typeProto));
      next = NextChar();
      if ((next == '{') || (next == '=') || (NextIsIdentifier())) {
        attr.set_type(AttributeProto_AttributeType_TENSOR);
        auto& tensorProto = *attr.mutable_t();
        CHECK_PARSER_STATUS(ParseOptionalIdentifier(*tensorProto.mutable_name()));
        (void)Matches('='); // Optional, to unify handling of initializers
        CHECK_PARSER_STATUS(Parse(tensorProto, typeProto));
      } else {
        attr.set_type(AttributeProto_AttributeType_TYPE_PROTO);
        attr.mutable_tp()->CopyFrom(typeProto);
      }
    } else if (NextIsValidFloatString()) {
      Literal literal;
      PARSE_TOKEN(literal);
      attr.set_type(AttributeProto_AttributeType_FLOAT);
      attr.set_f(static_cast<float>(std::stof(literal.value)));
    } else {
      attr.set_type(AttributeProto_AttributeType_GRAPH);
      PARSE(*attr.mutable_g());
    }
  } else if (Matches('@')) {
    std::string name;
    CHECK_PARSER_STATUS(ParseIdentifier(name));
    attr.set_ref_attr_name(name);
  } else {
    Literal literal;
    PARSE_TOKEN(literal);
    switch (literal.type) {
      case LiteralType::INT_LITERAL:
        attr.set_type(AttributeProto_AttributeType_INT);
        // stoll, not stol: `long` is only 32 bits wide on some platforms.
        attr.set_i(std::stoll(literal.value));
        break;
      case LiteralType::FLOAT_LITERAL:
        attr.set_type(AttributeProto_AttributeType_FLOAT);
        attr.set_f(static_cast<float>(std::stof(literal.value)));
        break;
      case LiteralType::STRING_LITERAL:
        attr.set_type(AttributeProto_AttributeType_STRING);
        attr.set_s(literal.value);
        break;
    }
  }
  if ((expected != AttributeProto_AttributeType_UNDEFINED) && (expected != attr.type())) {
    // Mismatch between type-annotation and attribute-value. We do an implicit cast
    // only in the special case of FLOAT type and integral value like 2
    if ((expected == AttributeProto_AttributeType_FLOAT) && (attr.type() == AttributeProto_AttributeType_INT)) {
      attr.set_type(AttributeProto_AttributeType_FLOAT);
      attr.set_f(static_cast<float>(attr.i()));
    } else {
      return ParseError(
          "Mismatch between expected type ",
          AttributeProto_AttributeType_Name(expected),
          " and specified value's type ",
          AttributeProto_AttributeType_Name(attr.type()));
    }
  }
  return Status::OK();
}
// Parses a full attribute definition: the leading identifier is the attribute name, the
// remainder is handled by the (attr, name) overload. `attr` is cleared first.
Status OnnxParser::Parse(AttributeProto& attr) {
  attr.Clear();

  std::string attr_name;
  CHECK_PARSER_STATUS(ParseIdentifier(attr_name));
  return Parse(attr, attr_name);
}
// True for the attribute types that hold exactly one value (FLOAT, INT, STRING, TENSOR,
// GRAPH, SPARSE_TENSOR, TYPE_PROTO); false for every other type.
bool IsSingletonAttribute(AttributeProto_AttributeType type) {
  return type == AttributeProto_AttributeType_FLOAT || type == AttributeProto_AttributeType_INT ||
      type == AttributeProto_AttributeType_STRING || type == AttributeProto_AttributeType_TENSOR ||
      type == AttributeProto_AttributeType_GRAPH || type == AttributeProto_AttributeType_SPARSE_TENSOR ||
      type == AttributeProto_AttributeType_TYPE_PROTO;
}
// Maps a list attribute type to the type of its elements (e.g. FLOATS -> FLOAT).
// Any type that is not one of the list types handled here is returned unchanged.
AttributeProto_AttributeType ToSingletonType(AttributeProto_AttributeType type) {
  if (type == AttributeProto_AttributeType_FLOATS)
    return AttributeProto_AttributeType_FLOAT;
  if (type == AttributeProto_AttributeType_INTS)
    return AttributeProto_AttributeType_INT;
  if (type == AttributeProto_AttributeType_STRINGS)
    return AttributeProto_AttributeType_STRING;
  if (type == AttributeProto_AttributeType_TENSORS)
    return AttributeProto_AttributeType_TENSOR;
  if (type == AttributeProto_AttributeType_GRAPHS)
    return AttributeProto_AttributeType_GRAPH;
  if (type == AttributeProto_AttributeType_SPARSE_TENSORS)
    return AttributeProto_AttributeType_SPARSE_TENSOR;
  if (type == AttributeProto_AttributeType_TYPE_PROTOS)
    return AttributeProto_AttributeType_TYPE_PROTO;
  return type;
}
// Parses the remainder of an attribute definition whose name has already been read:
//   [ ':' attr-type ] '=' ( single-value | '[' [ value {, value} ] ']' )
//
// `name` supplies the attribute name on entry. NOTE: it is also used as scratch storage
// for the type annotation, so it is overwritten when a ':' annotation is present.
//
// List values: each element is parsed as a single value (checked against the annotated
// element type, if any) and appended to the repeated field matching its type, which also
// determines the attribute's list type. INT, FLOAT, STRING, TENSOR, GRAPH and TYPE_PROTO
// elements are stored; elements of any other kind are ignored. An empty list requires a
// list-type annotation.
Status OnnxParser::Parse(AttributeProto& attr, std::string& name) {
  attr.set_name(name);
  if (Matches(':')) {
    CHECK_PARSER_STATUS(ParseIdentifier(name));
    int attrtype = AttributeTypeNameMap::Lookup(name);
    if (attrtype != 0) {
      attr.set_type(static_cast<AttributeProto_AttributeType>(attrtype));
    } else {
      return ParseError("Unexpected attribute type.");
    }
  }
  MATCH('=');
  if (NextChar() == '[') {
    // Parse a list of values. For an empty list, the type MUST be specified
    // using the type-annotation syntax of ": type".
    MATCH('[');
    if (NextChar() != ']') {
      do {
        AttributeProto nextval;
        auto expected_type = ToSingletonType(attr.type());
        CHECK_PARSER_STATUS(ParseSingleAttributeValue(nextval, expected_type));
        switch (nextval.type()) {
          case AttributeProto_AttributeType_INT:
            attr.set_type(AttributeProto_AttributeType_INTS);
            attr.add_ints(nextval.i());
            break;
          case AttributeProto_AttributeType_FLOAT:
            attr.set_type(AttributeProto_AttributeType_FLOATS);
            attr.add_floats(nextval.f());
            break;
          case AttributeProto_AttributeType_STRING:
            attr.add_strings(nextval.s());
            attr.set_type(AttributeProto_AttributeType_STRINGS);
            break;
          case AttributeProto_AttributeType_TENSOR:
            attr.set_type(AttributeProto_AttributeType_TENSORS);
            attr.add_tensors()->CopyFrom(nextval.t());
            break;
          case AttributeProto_AttributeType_GRAPH:
            attr.set_type(AttributeProto_AttributeType_GRAPHS);
            attr.add_graphs()->CopyFrom(nextval.g());
            break;
          case AttributeProto_AttributeType_TYPE_PROTO:
            attr.set_type(AttributeProto_AttributeType_TYPE_PROTOS);
            attr.add_type_protos()->CopyFrom(nextval.tp());
            break;
          default:
            // Other kinds of element (e.g. an '@' attribute reference, which sets no
            // type) have no list representation here and are ignored.
            break;
        }
      } while (Matches(','));
    } else {
      if (attr.type() == AttributeProto_AttributeType_UNDEFINED)
        return ParseError("Empty list attribute value requires type annotation.");
      if (IsSingletonAttribute(attr.type()))
        return ParseError("Singleton attribute value cannot be specified as a list.");
    }
    MATCH(']');
  } else {
    CHECK_PARSER_STATUS(ParseSingleAttributeValue(attr, attr.type()));
  }
  return Status::OK();
}
// Parses an optional attribute list: '<' attr {, attr} '>'.
// `attrlist` is cleared first and stays empty when Matches('<') fails.
Status OnnxParser::Parse(AttrList& attrlist) {
  attrlist.Clear();
  if (!Matches('<'))
    return Status::OK();

  do {
    PARSE(*attrlist.Add());
  } while (Matches(','));
  MATCH('>');
  return Status::OK();
}
// Parses one node:
//   outputs '=' [domain '.'] op-type [':' overload] [attributes] '(' inputs ')' [attributes]
//
// The op name may be dotted: the last component is the op type and everything before it,
// joined with '.', is the domain. Attributes may appear before or after the input list;
// the trailing position is tried only when none were given in the leading position.
// A missing op-type, domain component or overload identifier is reported as an error
// instead of being silently ignored.
Status OnnxParser::Parse(NodeProto& node) {
  PARSE(*node.mutable_output());
  MATCH('=');
  std::string domain("");
  std::string id;
  CHECK_PARSER_STATUS(ParseIdentifier(id));
  while (Matches('.')) {
    // `id` was a domain component, not the op type: fold it into the domain.
    if (!domain.empty())
      domain += ".";
    domain += id;
    CHECK_PARSER_STATUS(ParseIdentifier(id));
  }
  node.set_domain(domain);
  node.set_op_type(id);
  if (Matches(':')) {
    std::string overload;
    CHECK_PARSER_STATUS(ParseIdentifier(overload));
    node.set_overload(overload);
  }
  PARSE(*node.mutable_attribute());
  MATCH('(');
  PARSE(*node.mutable_input());
  MATCH(')');
  if (node.attribute_size() == 0) {
    // Permit attributes to be specified before or after parameters.
    PARSE(*node.mutable_attribute());
  }
  return Status::OK();
}
// Parses a brace-delimited sequence of nodes: '{' node* '}'.
// `nodelist` is cleared first; nodes are appended until Matches('}') succeeds.
Status OnnxParser::Parse(NodeList& nodelist) {
  nodelist.Clear();
  MATCH('{');
  for (;;) {
    if (Matches('}'))
      break;
    PARSE(*nodelist.Add());
  }
  return Status::OK();
}
// Parses a graph definition that starts with the graph's name, then delegates the rest
// to the (name, graph) overload. A missing name is reported as an error, consistent with
// the other ParseIdentifier call sites that check the returned status.
Status OnnxParser::Parse(GraphProto& graph) {
  std::string id;
  CHECK_PARSER_STATUS(ParseIdentifier(id));
  return Parse(id, graph);
}
// Parses the body of a graph whose name has already been read:
//   inputs '=>' outputs [ '<' value-infos / initializers '>' ] '{' nodes '}'
// Initializers are collected both from input default values and from the '<...>' section,
// so the initializer list is cleared once up front.
Status OnnxParser::Parse(std::string name, GraphProto& graph) {
  graph.set_name(name);

  auto* initializers = graph.mutable_initializer();
  initializers->Clear();

  CHECK_PARSER_STATUS(ParseInput(*graph.mutable_input(), *initializers));

  // The "=>" arrow, matched as two consecutive characters.
  MATCH('=');
  MATCH('>', false);

  CHECK_PARSER_STATUS(ParseGraphInputOutput(*graph.mutable_output()));
  CHECK_PARSER_STATUS(ParseValueInfo(*graph.mutable_value_info(), *initializers));
  return Parse(*graph.mutable_node());
}
// Parses a function definition:
//   [ '<' keyword ':' value {, keyword ':' value} '>' ]
//   name [ '<' attribute-names / attribute-defaults '>' ]
//   '(' inputs ')' '=>' '(' outputs ')' [ '<' value-infos '>' ] '{' nodes '}'
//
// Supported header keywords: opset_import, doc_string, domain, overload; any other
// keyword is an error. Typed inputs/outputs contribute entries to the function's
// value_info list, which is therefore cleared once before the inputs are parsed.
// A missing function name is reported as an error instead of being silently ignored.
Status OnnxParser::Parse(FunctionProto& fn) {
  fn.Clear();
  std::string strval;
  if (Matches('<')) {
    do {
      KeyWordMap::KeyWord keyword = KeyWordMap::KeyWord::NONE;
      PARSE_TOKEN(keyword);
      MATCH(':');
      switch (keyword) {
        case KeyWordMap::KeyWord::OPSET_IMPORT:
          PARSE(*fn.mutable_opset_import());
          break;
        case KeyWordMap::KeyWord::DOC_STRING:
          PARSE_TOKEN(strval);
          fn.set_doc_string(strval);
          break;
        case KeyWordMap::KeyWord::DOMAIN_KW:
          PARSE_TOKEN(strval);
          fn.set_domain(strval);
          break;
        case KeyWordMap::KeyWord::OVERLOAD_KW:
          PARSE_TOKEN(strval);
          fn.set_overload(strval);
          break;
        default:
          return ParseError("Unhandled keyword.");
      }
    } while (Matches(','));
    MATCH('>');
  }
  std::string id;
  CHECK_PARSER_STATUS(ParseIdentifier(id));
  fn.set_name(id);
  // Bare names go to `attribute`, names with a default value go to `attribute_proto`.
  PARSE('<', *fn.mutable_attribute(), *fn.mutable_attribute_proto(), '>');
  fn.mutable_value_info()->Clear();
  CHECK_PARSER_STATUS(ParseFunctionInputOutput(*fn.mutable_input(), *fn.mutable_value_info()));
  // The "=>" arrow, matched as two consecutive characters.
  MATCH('=');
  MATCH('>', false);
  CHECK_PARSER_STATUS(ParseFunctionInputOutput(*fn.mutable_output(), *fn.mutable_value_info()));
  if (NextChar() == '<') {
    PARSE('<', *fn.mutable_value_info(), '>');
  }
  return Parse(*fn.mutable_node());
}
// Parses an opset-import list: '[' [ domain ':' version {, domain ':' version} ] ']',
// where domain is a string token and version an integer token. One entry per pair is
// appended to `opsets`; the list is not cleared here.
Status OnnxParser::Parse(OpsetIdList& opsets) {
  std::string domain;
  int64_t version = 0;

  MATCH('[');
  if (Matches(']'))
    return Status::OK(); // empty list

  do {
    auto* entry = opsets.Add();
    PARSE_TOKEN(domain);
    entry->set_domain(domain);
    MATCH(':');
    PARSE_TOKEN(version);
    entry->set_version(version);
  } while (Matches(','));
  MATCH(']');
  return Status::OK();
}
// Parses a complete model:
//   [ '<' keyword ':' value {, keyword ':' value} '>' ] graph function*
//
// Supported header keywords: ir_version, opset_import, producer_name, producer_version,
// domain, model_version, doc_string, metadata_props; any other keyword is an error.
// After the main graph, function definitions are parsed until EndOfInput() is true.
Status OnnxParser::Parse(ModelProto& model) {
  model.Clear();
  std::string text;
  int64_t number = 0;

  if (Matches('<')) {
    do {
      KeyWordMap::KeyWord keyword = KeyWordMap::KeyWord::NONE;
      PARSE_TOKEN(keyword);
      MATCH(':');
      switch (keyword) {
        // Integer-valued fields.
        case KeyWordMap::KeyWord::IR_VERSION:
          PARSE_TOKEN(number);
          model.set_ir_version(number);
          break;
        case KeyWordMap::KeyWord::MODEL_VERSION:
          PARSE_TOKEN(number);
          model.set_model_version(number);
          break;
        // String-valued fields.
        case KeyWordMap::KeyWord::PRODUCER_NAME:
          PARSE_TOKEN(text);
          model.set_producer_name(text);
          break;
        case KeyWordMap::KeyWord::PRODUCER_VERSION:
          PARSE_TOKEN(text);
          model.set_producer_version(text);
          break;
        case KeyWordMap::KeyWord::DOMAIN_KW:
          PARSE_TOKEN(text);
          model.set_domain(text);
          break;
        case KeyWordMap::KeyWord::DOC_STRING:
          PARSE_TOKEN(text);
          model.set_doc_string(text);
          break;
        // Structured fields.
        case KeyWordMap::KeyWord::OPSET_IMPORT:
          PARSE(*model.mutable_opset_import());
          break;
        case KeyWordMap::KeyWord::METADATA_PROPS: {
          auto& props = *model.mutable_metadata_props();
          MATCH('[');
          if (!Matches(']')) {
            PARSE(props);
            MATCH(']');
          }
          break;
        }
        default:
          return ParseError("Unhandled keyword.");
      }
    } while (Matches(','));
    MATCH('>');
  }

  PARSE(*model.mutable_graph());

  // Everything that follows the main graph is a sequence of function definitions.
  for (auto* functions = model.mutable_functions(); !EndOfInput();) {
    PARSE(*functions->Add());
  }
  return Status::OK();
}
} // namespace ONNX_NAMESPACE