Spaces:
Running
Running
/* | |
* SPDX-License-Identifier: Apache-2.0 | |
*/ | |
// Experimental language syntax and parser for ONNX. Please note that the syntax as formalized | |
// by this parser is preliminary and may change. | |
namespace ONNX_NAMESPACE { | |
// Parses a single literal token at the current position into `result`.
//
// Recognized forms:
//   - string literal: "..." where a backslash makes the following character be copied verbatim;
//     result.value holds the contents without the enclosing quotes or escape characters.
//   - special float literal: optional '-' followed by nan | inf | infinity (case-insensitive).
//   - numeric literal: optional '-', digits with at most one '.', and an optional exponent
//     (e|E)(+|-)?digits. A '.' or an exponent makes it a FLOAT_LITERAL, else an INT_LITERAL.
// For numeric and special float literals, result.value holds the source text (including the sign).
// Returns a parse error if none of these forms is found at the current position.
Status ParserBase::Parse(Literal& result) {
  auto nextch = NextChar();
  auto from = next_;
  if (nextch == '"') {
    ++next_;
    bool has_escape = false;
    while ((next_ < end_) && (*next_ != '"')) {
      if (*next_ == '\\') {
        has_escape = true;
        ++next_;
        if (next_ >= end_)
          return ParseError("Incomplete string literal.");
      }
      ++next_;
    }
    if (next_ >= end_)
      return ParseError("Incomplete string literal.");
    ++next_;
    result.type = LiteralType::STRING_LITERAL;
    if (has_escape) {
      std::string& target = result.value;
      target.clear();
      target.reserve(next_ - from - 2); // upper bound
      // *from is the starting quote. *(next_-1) is the ending quote.
      // Copy what is in-between, except for the escape character.
      while (++from < next_ - 1) {
        // Copy current char, if not escape, or next char otherwise.
        target.push_back(*from != '\\' ? (*from) : *(++from));
      }
    } else {
      result.value = std::string(from + 1, next_ - from - 2); // skip enclosing quotes
    }
    return Status::OK();
  }

  // Simplify the checks below by consuming a possible negative sign.
  if (nextch == '-') {
    ++next_;
    nextch = NextChar();
  }

  // Float literals that start with alphabetic characters: (-)?(nan|inf|infinity).
  if (isalpha(nextch)) {
    if (!NextIsValidFloatString())
      return ParseError("Encountered invalid float literal!");
    while (next_ < end_ && isalpha(*next_)) {
      ++next_;
    }
    // std::stof rejects e.g. "- inf" (whitespace between sign and name). The failure is
    // recorded in a flag: a Status returned from inside the exception handler would be
    // discarded, and the literal would wrongly be reported as parsed successfully.
    bool is_valid_float = true;
    ONNX_TRY {
      static_cast<void>(std::stof(std::string(from, next_ - from)));
    }
    ONNX_CATCH(...) {
      ONNX_HANDLE_EXCEPTION([&]() { is_valid_float = false; });
    }
    if (!is_valid_float)
      return ParseError("Encountered invalid float literal!");
    result.type = LiteralType::FLOAT_LITERAL;
    result.value = std::string(from, next_ - from);
    return Status::OK();
  }

  // Numeric int or float literal.
  if (isdigit(nextch)) {
    bool decimal_point = false;
    ++next_;
    while ((next_ < end_) && (isdigit(*next_) || (*next_ == '.'))) {
      if (*next_ == '.') {
        if (decimal_point)
          break; // Only one decimal point allowed in numeric literal
        decimal_point = true;
      }
      ++next_;
    }
    // Optional exponent syntax: (e|E)(+|-)?[0-9]+
    if ((next_ < end_) && ((*next_ == 'e') || (*next_ == 'E'))) {
      decimal_point = true; // treat as float-literal
      ++next_;
      if ((next_ < end_) && ((*next_ == '+') || (*next_ == '-')))
        ++next_;
      while ((next_ < end_) && (isdigit(*next_)))
        ++next_;
    }
    result.value = std::string(from, next_ - from);
    result.type = decimal_point ? LiteralType::FLOAT_LITERAL : LiteralType::INT_LITERAL;
    return Status::OK();
  }

  // Nothing recognizable: report it instead of returning OK with `result` left untouched.
  return ParseError("Value expected but not found.");
}
bool ParserBase::NextIsValidFloatString() { | |
auto nextch = NextChar(); | |
auto from = next_; | |
constexpr int INFINITY_LENGTH = 8; | |
if (isalpha(nextch)) { | |
while (next_ < end_ && isalpha(*next_) && (next_ - from) <= INFINITY_LENGTH) { | |
++next_; | |
} | |
if (isdigit(*next_)) { // No trailing digits | |
next_ = from; | |
return false; | |
} | |
std::string candidate = std::string(from, next_ - from); | |
// Reset parser location before continuing. | |
next_ = from; | |
std::transform( | |
candidate.begin(), candidate.end(), candidate.begin(), [](unsigned char c) { return std::tolower(c); }); | |
if (candidate == std::string("inf") || candidate == std::string("infinity") || candidate == std::string("nan")) { | |
return true; | |
} | |
} | |
return false; | |
} | |
// Parses a comma-separated list of identifiers into `idlist` (cleared first).
// If no identifier is present at all, the list is left empty. After a comma, whatever
// ParseOptionalIdentifier produces is appended, presumably an empty string when the element
// is omitted (e.g. "a, , b"). TODO confirm against ParseOptionalIdentifier.
Status OnnxParser::Parse(IdList& idlist) {
  idlist.Clear();
  std::string name;
  ParseOptionalIdentifier(name);
  if (name.empty())
    return Status::OK(); // Treat as empty list of identifiers
  for (;;) {
    *idlist.Add() = name;
    if (!Matches(','))
      break;
    ParseOptionalIdentifier(name);
  }
  return Status::OK();
}
// Parses an optional identifier list delimited by `open` ... `close` into `idlist`.
// When the next character is not `open`, nothing is consumed and `idlist` stays empty.
Status OnnxParser::Parse(char open, IdList& idlist, char close) {
  idlist.Clear();
  if (!Matches(open))
    return Status::OK();
  PARSE(idlist);
  MATCH(close);
  return Status::OK();
}
// Parses a comma-separated list whose elements are either plain identifiers or attribute
// definitions. An identifier followed by ':' or '=' starts an attribute definition
// ("name: type = value" or "name = value") and is parsed into `attrlist`. Any other
// identifier is appended to `idlist`. Both lists are cleared first.
Status OnnxParser::Parse(IdList& idlist, AttrList& attrlist) {
  idlist.Clear();
  attrlist.Clear();
  do {
    std::string id;
    // NOTE(review): ParseIdentifier's status is not checked, so a missing identifier yields
    // an empty name in idlist. Kept as-is to preserve existing behavior.
    ParseIdentifier(id);
    auto next = NextChar();
    if (next == ':' || next == '=') {
      // Propagate errors from the attribute definition instead of silently dropping them.
      CHECK_PARSER_STATUS(Parse(*attrlist.Add(), id));
    } else {
      *idlist.Add() = id;
    }
  } while (Matches(','));
  return Status::OK();
}
// Parses an optional list of identifiers and attribute definitions delimited by
// `open` ... `close` (see Parse(IdList&, AttrList&)). Both lists are cleared first.
// They stay empty when the next character is not `open`, and also when the delimiters enclose
// nothing (e.g. "<>"). Previously "<>" produced a single empty identifier.
Status OnnxParser::Parse(char open, IdList& idlist, AttrList& attrlist, char close) {
  idlist.Clear();
  attrlist.Clear();
  if (Matches(open) && !Matches(close)) {
    PARSE(idlist, attrlist);
    MATCH(close);
  }
  return Status::OK();
}
// Parses a comma-separated list of dimensions into `shape` (existing dims are removed).
// Each dimension is one of:
//   ?           unknown dimension (neither dim_value nor dim_param set)
//   identifier  symbolic dimension (dim_param)
//   integer     fixed dimension (dim_value)
Status OnnxParser::Parse(TensorShapeProto& shape) {
  shape.clear_dim();
  do {
    if (Matches('?')) {
      shape.add_dim();
      continue; // proceeds to the ',' check below
    }
    std::string symbol;
    CHECK_PARSER_STATUS(ParseOptionalIdentifier(symbol));
    if (symbol.empty()) {
      int64_t extent = 0;
      PARSE_TOKEN(extent);
      shape.add_dim()->set_dim_value(extent);
    } else {
      shape.add_dim()->set_dim_param(symbol);
    }
  } while (Matches(','));
  return Status::OK();
}
// Parses a type expression into `typeProto`. Grammar:
//   prim-type                 tensor of rank 0 (scalar)
//   prim-type []              tensor of unknown rank (not a zero-rank tensor)
//   prim-type [ dims ]        tensor of known rank > 0
//   seq ( type )
//   map ( prim-type , type )
//   optional ( type )
//   sparse_tensor ( prim-type [optional "[...]" suffix as for tensors] )
Status OnnxParser::Parse(TypeProto& typeProto) {
  // Shared by tensor and sparse-tensor types. It resets the shape and then parses the optional
  // "[...]" suffix. With no suffix, an empty shape is created (scalar). With "[]", the shape is
  // left unset (unknown rank).
  auto parse_shape_suffix = [this](auto* tensor_like) -> Status {
    tensor_like->clear_shape();
    if (!Matches('[')) {
      (void)(tensor_like->mutable_shape());
      return Status::OK();
    }
    if (!Matches(']')) {
      PARSE(*tensor_like->mutable_shape());
      MATCH(']');
    }
    return Status::OK();
  };

  std::string type_name;
  CHECK_PARSER_STATUS(ParseIdentifier(type_name));
  int elem_type = PrimitiveTypeNameMap::Lookup(type_name);
  if (elem_type != 0) {
    auto* tensor_type = typeProto.mutable_tensor_type();
    tensor_type->set_elem_type(elem_type);
    return parse_shape_suffix(tensor_type);
  }

  switch (KeyWordMap::Lookup(type_name)) {
    case KeyWordMap::KeyWord::SEQ_TYPE:
      MATCH('(');
      PARSE(*typeProto.mutable_sequence_type()->mutable_elem_type());
      MATCH(')');
      return Status::OK();

    case KeyWordMap::KeyWord::MAP_TYPE: {
      MATCH('(');
      auto* map_type = typeProto.mutable_map_type();
      CHECK_PARSER_STATUS(ParseIdentifier(type_name));
      int key_type = PrimitiveTypeNameMap::Lookup(type_name);
      if (key_type == 0)
        return ParseError("Expecting primitive type as map key type.");
      map_type->set_key_type(key_type);
      MATCH(',');
      PARSE(*map_type->mutable_value_type());
      MATCH(')');
      return Status::OK();
    }

    case KeyWordMap::KeyWord::OPTIONAL_TYPE:
      MATCH('(');
      PARSE(*typeProto.mutable_optional_type()->mutable_elem_type());
      MATCH(')');
      return Status::OK();

    case KeyWordMap::KeyWord::SPARSE_TENSOR_TYPE: {
      MATCH('(');
      CHECK_PARSER_STATUS(ParseIdentifier(type_name));
      int sparse_elem_type = PrimitiveTypeNameMap::Lookup(type_name);
      if (sparse_elem_type == 0)
        return ParseError("Unexpected type in sparse-tensor element type.");
      auto* sparse_type = typeProto.mutable_sparse_tensor_type();
      sparse_type->set_elem_type(sparse_elem_type);
      CHECK_PARSER_STATUS(parse_shape_suffix(sparse_type));
      MATCH(')');
      return Status::OK();
    }

    default:
      return ParseError("Unexpected type.");
  }
}
// Parses a value-info "[type] name" into `valueinfo`; the type is optional.
Status OnnxParser::Parse(ValueInfoProto& valueinfo) {
  if (NextIsType()) {
    PARSE(*valueinfo.mutable_type());
  }
  std::string value_name;
  CHECK_PARSER_STATUS(ParseIdentifier(value_name));
  valueinfo.set_name(value_name);
  return Status::OK();
}
// Parses "open [value-info {, value-info}] close", appending each entry to `vilist`.
// Existing entries in `vilist` are kept.
Status OnnxParser::Parse(char open, ValueInfoList& vilist, char close) {
  MATCH(open);
  if (Matches(close))
    return Status::OK(); // empty list
  do {
    PARSE(*vilist.Add());
  } while (Matches(','));
  MATCH(close);
  return Status::OK();
}
// Parses a parenthesized graph input/output list "( value-info, ... )".
// The previous contents of `vilist` are discarded.
Status OnnxParser::ParseGraphInputOutput(ValueInfoList& vilist) {
  vilist.Clear();
  return Parse('(', vilist, ')');
}
// Parses a parenthesized list of function inputs or outputs, where each element is
// "name" or "type name". Every name is appended to `idlist` (cleared first). For a typed
// element, a value-info with that type and name is appended to `vilist`. `vilist` is not
// cleared because it accumulates entries across the input and output lists.
Status OnnxParser::ParseFunctionInputOutput(IdList& idlist, ValueInfoList& vilist) {
  idlist.Clear();
  MATCH('(');
  if (Matches(')'))
    return Status::OK();
  do {
    std::string* param_name = idlist.Add();
    ValueInfoProto* typed_entry = NextIsType() ? vilist.Add() : nullptr;
    if (typed_entry != nullptr) {
      PARSE(*(typed_entry->mutable_type()));
    }
    CHECK_PARSER_STATUS(ParseIdentifier(*param_name));
    if (typed_entry != nullptr) {
      typed_entry->set_name(*param_name);
    }
  } while (Matches(','));
  MATCH(')');
  return Status::OK();
}
// Parses the optional graph input list "( value-info [= initial-value], ... )".
// Every value-info is appended to `inputs` (cleared first). An element written as
// "value-info = initial-value" additionally appends a TensorProto with the same name to
// `initializers`. `initializers` is not cleared here.
Status OnnxParser::ParseInput(ValueInfoList& inputs, TensorList& initializers) {
  inputs.Clear();
  if (!Matches('(') || Matches(')'))
    return Status::OK();
  do {
    ValueInfoProto input_info;
    PARSE(input_info);
    *inputs.Add() = input_info;
    if (!Matches('='))
      continue; // no default value; proceeds to the ',' check
    // Default value for the input.
    TensorProto& initial_value = *initializers.Add();
    initial_value.set_name(input_info.name());
    CHECK_PARSER_STATUS(Parse(initial_value, input_info.type()));
  } while (Matches(','));
  MATCH(')');
  return Status::OK();
}
// Parses the optional graph value-info list "< element, ... >". Unlike graph inputs, each
// element is EITHER a plain value-info, appended to `value_infos` (cleared first), OR,
// when followed by "= value", an initializer appended to `initializers` (not cleared here).
Status OnnxParser::ParseValueInfo(ValueInfoList& value_infos, TensorList& initializers) {
  value_infos.Clear();
  if (!Matches('<') || Matches('>'))
    return Status::OK();
  do {
    ValueInfoProto entry;
    PARSE(entry);
    if (!Matches('=')) {
      *value_infos.Add() = entry;
      continue; // proceeds to the ',' check
    }
    TensorProto& init = *initializers.Add();
    init.set_name(entry.name());
    CHECK_PARSER_STATUS(Parse(init, entry.type()));
  } while (Matches(','));
  MATCH('>');
  return Status::OK();
}
// Parses a non-empty, comma-separated list of "key : value" pairs. Key and value are each
// read as a string token. One entry is appended to `stringStringList` per pair.
Status OnnxParser::Parse(StringStringList& stringStringList) {
  for (std::string text;;) {
    auto* entry = stringStringList.Add();
    PARSE_TOKEN(text);
    entry->set_key(text);
    MATCH(':');
    PARSE_TOKEN(text);
    entry->set_value(text);
    if (!Matches(','))
      return Status::OK();
  }
}
// Parses a standalone tensor "type [name] [=] value" into `tensorProto` (reset first).
// The type must be a tensor type with numeric dimensions; that check and the value parsing
// are done by Parse(TensorProto&, const TypeProto&).
Status OnnxParser::Parse(TensorProto& tensorProto) {
  tensorProto = TensorProto();
  TypeProto declared_type;
  PARSE(declared_type);
  ParseOptionalIdentifier(*tensorProto.mutable_name());
  // The '=' is optional, so initializers and tensor values in other contexts share this syntax.
  (void)Matches('=');
  return Parse(tensorProto, declared_type);
}
// Parse TensorProto data given its type: | |
Status OnnxParser::Parse(TensorProto& tensorProto, const TypeProto& tensorTypeProto) { | |
if (!tensorTypeProto.has_tensor_type()) | |
return ParseError("Error parsing TensorProto (expected a tensor type)."); | |
auto elem_type = tensorTypeProto.tensor_type().elem_type(); | |
tensorProto.set_data_type(elem_type); | |
if (!tensorTypeProto.tensor_type().has_shape()) | |
return ParseError("Error parsing TensorProto (expected a tensor shape)."); | |
for (auto& dim : tensorTypeProto.tensor_type().shape().dim()) { | |
if (!dim.has_dim_value()) | |
return ParseError("Error parsing TensorProto shape (expected numeric dimension)."); | |
auto dimval = dim.dim_value(); | |
tensorProto.add_dims(dimval); | |
} | |
// tensorProto.mutable_int64_data()->Reserve(n); | |
// Parse the actual values: | |
int64_t intval; | |
uint64_t uintval = 0; | |
float floatval = 0.0; | |
double dblval = 0.0; | |
std::string strval; | |
if (Matches('{')) { | |
if (!Matches('}')) { | |
do { | |
switch (static_cast<TensorProto::DataType>(elem_type)) { | |
case TensorProto::DataType::TensorProto_DataType_INT8: | |
case TensorProto::DataType::TensorProto_DataType_INT16: | |
case TensorProto::DataType::TensorProto_DataType_INT32: | |
case TensorProto::DataType::TensorProto_DataType_UINT8: | |
case TensorProto::DataType::TensorProto_DataType_UINT16: | |
case TensorProto::DataType::TensorProto_DataType_BOOL: | |
PARSE_TOKEN(intval); | |
// TODO: check values are in the correct range. | |
tensorProto.add_int32_data(intval); | |
break; | |
case TensorProto::DataType::TensorProto_DataType_INT64: | |
PARSE_TOKEN(intval); | |
tensorProto.add_int64_data(intval); | |
break; | |
case TensorProto::DataType::TensorProto_DataType_UINT32: | |
case TensorProto::DataType::TensorProto_DataType_UINT64: | |
PARSE_TOKEN(uintval); | |
tensorProto.add_uint64_data(uintval); | |
break; | |
case TensorProto::DataType::TensorProto_DataType_FLOAT: | |
PARSE_TOKEN(floatval); | |
tensorProto.add_float_data(floatval); | |
break; | |
case TensorProto::DataType::TensorProto_DataType_DOUBLE: | |
PARSE_TOKEN(dblval); | |
tensorProto.add_double_data(dblval); | |
break; | |
case TensorProto::DataType::TensorProto_DataType_STRING: | |
PARSE_TOKEN(strval); | |
tensorProto.add_string_data(strval); | |
break; | |
default: | |
return ParseError("Unhandled type: %d", elem_type); | |
} | |
} while (Matches(',')); | |
MATCH('}'); | |
} | |
} else if (Matches('[')) { | |
tensorProto.set_data_location(TensorProto::DataLocation::TensorProto_DataLocation_EXTERNAL); | |
auto& externalData = *tensorProto.mutable_external_data(); | |
PARSE(externalData); | |
MATCH(']'); | |
} | |
return Status::OK(); | |
} | |
// Returns true if an identifier starts at the current position. Uses PeekIdentifier, so
// presumably no input is consumed.
bool OnnxParser::NextIsIdentifier() {
  std::string peeked;
  (void)PeekIdentifier(peeked);
  return !peeked.empty();
}
bool OnnxParser::NextIsType() { | |
std::string id(""); | |
(void)PeekIdentifier(id); | |
if (PrimitiveTypeNameMap::IsTypeName(id)) | |
return true; | |
switch (KeyWordMap::Lookup(id)) { | |
case KeyWordMap::KeyWord::SEQ_TYPE: | |
case KeyWordMap::KeyWord::MAP_TYPE: | |
case KeyWordMap::KeyWord::OPTIONAL_TYPE: | |
case KeyWordMap::KeyWord::SPARSE_TENSOR_TYPE: | |
return true; | |
default: | |
return false; | |
} | |
} | |
// Parses one (non-list) attribute value into `attr` and sets attr's type from its form:
//   type followed by '{', '=' or an identifier  -> TENSOR (type, optional name, data)
//   type otherwise                              -> TYPE_PROTO
//   nan | inf | infinity                        -> FLOAT
//   any other identifier                        -> GRAPH
//   @name                                       -> sets ref_attr_name (type left as is)
//   int / float / string literal                -> INT / FLOAT / STRING
// If `expected` is not UNDEFINED, the parsed type must equal it. The only implicit
// conversion is an INT value for an expected FLOAT.
Status OnnxParser::ParseSingleAttributeValue(AttributeProto& attr, AttributeProto_AttributeType expected) {
  auto next = NextChar();
  if (isalpha(next) || next == '_') {
    if (NextIsType()) {
      TypeProto typeProto;
      PARSE(typeProto);
      next = NextChar();
      if ((next == '{') || (next == '=') || (NextIsIdentifier())) {
        attr.set_type(AttributeProto_AttributeType_TENSOR);
        auto& tensorProto = *attr.mutable_t();
        CHECK_PARSER_STATUS(ParseOptionalIdentifier(*tensorProto.mutable_name()));
        (void)Matches('='); // Optional, to unify handling of initializers
        CHECK_PARSER_STATUS(Parse(tensorProto, typeProto));
      } else {
        attr.set_type(AttributeProto_AttributeType_TYPE_PROTO);
        attr.mutable_tp()->CopyFrom(typeProto);
      }
    } else if (NextIsValidFloatString()) {
      Literal literal;
      PARSE_TOKEN(literal);
      attr.set_type(AttributeProto_AttributeType_FLOAT);
      attr.set_f(std::stof(literal.value));
    } else {
      attr.set_type(AttributeProto_AttributeType_GRAPH);
      PARSE(*attr.mutable_g());
    }
  } else if (Matches('@')) {
    std::string name;
    CHECK_PARSER_STATUS(ParseIdentifier(name));
    attr.set_ref_attr_name(name);
  } else {
    Literal literal;
    PARSE_TOKEN(literal);
    switch (literal.type) {
      case LiteralType::INT_LITERAL:
        attr.set_type(AttributeProto_AttributeType_INT);
        // std::stoll rather than std::stol: long is 32 bits on some platforms (e.g. Windows),
        // where std::stol throws for int64 values outside the 32-bit range.
        attr.set_i(std::stoll(literal.value));
        break;
      case LiteralType::FLOAT_LITERAL:
        attr.set_type(AttributeProto_AttributeType_FLOAT);
        attr.set_f(std::stof(literal.value));
        break;
      case LiteralType::STRING_LITERAL:
        attr.set_type(AttributeProto_AttributeType_STRING);
        attr.set_s(literal.value);
        break;
    }
  }
  if ((expected != AttributeProto_AttributeType_UNDEFINED) && (expected != attr.type())) {
    // Mismatch between type-annotation and attribute-value. We do an implicit cast
    // only in the special case of FLOAT type and integral value like 2.
    if ((expected == AttributeProto_AttributeType_FLOAT) && (attr.type() == AttributeProto_AttributeType_INT)) {
      attr.set_type(AttributeProto_AttributeType_FLOAT);
      attr.set_f(static_cast<float>(attr.i()));
    } else {
      return ParseError(
          "Mismatch between expected type ",
          AttributeProto_AttributeType_Name(expected),
          " and specified value's type ",
          AttributeProto_AttributeType_Name(attr.type()));
    }
  }
  return Status::OK();
}
// Parses a complete attribute definition "name [: type] = value" into `attr`, which is
// cleared first.
Status OnnxParser::Parse(AttributeProto& attr) {
  attr.Clear();
  std::string attr_name;
  CHECK_PARSER_STATUS(ParseIdentifier(attr_name));
  return Parse(attr, attr_name);
}
// Returns true for attribute types that hold exactly one value: FLOAT, INT, STRING, TENSOR,
// GRAPH, SPARSE_TENSOR and TYPE_PROTO. All other types, such as UNDEFINED and the list
// types (INTS, FLOATS, ...), yield false.
bool IsSingletonAttribute(AttributeProto_AttributeType type) {
  return type == AttributeProto_AttributeType_FLOAT || type == AttributeProto_AttributeType_INT ||
      type == AttributeProto_AttributeType_STRING || type == AttributeProto_AttributeType_TENSOR ||
      type == AttributeProto_AttributeType_GRAPH || type == AttributeProto_AttributeType_SPARSE_TENSOR ||
      type == AttributeProto_AttributeType_TYPE_PROTO;
}
// Maps a list attribute type (e.g. INTS) to the type of its elements (e.g. INT).
// Any other type, including singleton types and UNDEFINED, is returned unchanged.
AttributeProto_AttributeType ToSingletonType(AttributeProto_AttributeType type) {
  struct ListToElement {
    AttributeProto_AttributeType list_type;
    AttributeProto_AttributeType element_type;
  };
  static constexpr ListToElement kListToElement[] = {
      {AttributeProto_AttributeType_FLOATS, AttributeProto_AttributeType_FLOAT},
      {AttributeProto_AttributeType_INTS, AttributeProto_AttributeType_INT},
      {AttributeProto_AttributeType_STRINGS, AttributeProto_AttributeType_STRING},
      {AttributeProto_AttributeType_TENSORS, AttributeProto_AttributeType_TENSOR},
      {AttributeProto_AttributeType_GRAPHS, AttributeProto_AttributeType_GRAPH},
      {AttributeProto_AttributeType_SPARSE_TENSORS, AttributeProto_AttributeType_SPARSE_TENSOR},
      {AttributeProto_AttributeType_TYPE_PROTOS, AttributeProto_AttributeType_TYPE_PROTO},
  };
  for (const auto& entry : kListToElement) {
    if (entry.list_type == type)
      return entry.element_type;
  }
  return type;
}
// Parses the remainder of an attribute definition whose name has already been read:
//   [: attr-type] = value
//   [: attr-type] = [ value, value, ... ]
// `name` becomes the attribute's name. For a list value, the element type is taken from the
// annotation (if any) or from the first element. Later elements must have the same type
// (INT values are converted when FLOAT is expected). An empty list requires a list-type
// annotation. A singleton-type annotation is rejected for any list value.
Status OnnxParser::Parse(AttributeProto& attr, std::string& name) {
  attr.set_name(name);
  if (Matches(':')) {
    // Read the type name into a local so the caller's `name` is not overwritten.
    std::string type_name;
    CHECK_PARSER_STATUS(ParseIdentifier(type_name));
    int attrtype = AttributeTypeNameMap::Lookup(type_name);
    if (attrtype == 0)
      return ParseError("Unexpected attribute type.");
    attr.set_type(static_cast<AttributeProto_AttributeType>(attrtype));
  }
  MATCH('=');
  if (NextChar() != '[') {
    return ParseSingleAttributeValue(attr, attr.type());
  }

  // List value.
  MATCH('[');
  if (IsSingletonAttribute(attr.type()))
    return ParseError("Singleton attribute value cannot be specified as a list.");
  if (NextChar() == ']') {
    if (attr.type() == AttributeProto_AttributeType_UNDEFINED)
      return ParseError("Empty list attribute value requires type annotation.");
  } else {
    do {
      AttributeProto element;
      // After the first element, attr.type() is a list type, so every later element is
      // checked against the same element type.
      CHECK_PARSER_STATUS(ParseSingleAttributeValue(element, ToSingletonType(attr.type())));
      switch (element.type()) {
        case AttributeProto_AttributeType_INT:
          attr.set_type(AttributeProto_AttributeType_INTS);
          attr.add_ints(element.i());
          break;
        case AttributeProto_AttributeType_FLOAT:
          attr.set_type(AttributeProto_AttributeType_FLOATS);
          attr.add_floats(element.f());
          break;
        case AttributeProto_AttributeType_STRING:
          attr.set_type(AttributeProto_AttributeType_STRINGS);
          attr.add_strings(element.s());
          break;
        case AttributeProto_AttributeType_TENSOR:
          attr.set_type(AttributeProto_AttributeType_TENSORS);
          attr.add_tensors()->CopyFrom(element.t());
          break;
        case AttributeProto_AttributeType_GRAPH:
          attr.set_type(AttributeProto_AttributeType_GRAPHS);
          attr.add_graphs()->CopyFrom(element.g());
          break;
        case AttributeProto_AttributeType_TYPE_PROTO:
          attr.set_type(AttributeProto_AttributeType_TYPE_PROTOS);
          attr.add_type_protos()->CopyFrom(element.tp());
          break;
        default:
          // E.g. an attribute reference (@name) inside a list. Previously such elements were
          // silently dropped.
          return ParseError("Unexpected value in list attribute.");
      }
    } while (Matches(','));
  }
  MATCH(']');
  return Status::OK();
}
// Parses an optional attribute list "< attr, attr, ... >" into `attrlist` (cleared first).
// If the next character is not '<', the list stays empty.
Status OnnxParser::Parse(AttrList& attrlist) {
  attrlist.Clear();
  if (!Matches('<'))
    return Status::OK();
  do {
    PARSE(*attrlist.Add());
  } while (Matches(','));
  MATCH('>');
  return Status::OK();
}
// Parses a node:
//   outputs = [domain.]op_type[:overload] [<attrs>] (inputs) [<attrs>]
// The domain may contain several dot-separated identifiers (e.g. "com.microsoft.Op"). Every
// identifier before the last '.' becomes part of the domain. Attributes may be written before
// or after the inputs; the trailing list is parsed only if no leading attributes were found.
Status OnnxParser::Parse(NodeProto& node) {
  PARSE(*node.mutable_output());
  MATCH('=');
  std::string domain;
  std::string id;
  // Missing identifiers are errors; previously they were ignored and produced nodes with an
  // empty op_type (e.g. "y = (x)").
  CHECK_PARSER_STATUS(ParseIdentifier(id));
  while (Matches('.')) {
    if (!domain.empty())
      domain += ".";
    domain += id;
    CHECK_PARSER_STATUS(ParseIdentifier(id));
  }
  node.set_domain(domain);
  node.set_op_type(id);
  if (Matches(':')) {
    std::string overload;
    CHECK_PARSER_STATUS(ParseIdentifier(overload));
    node.set_overload(overload);
  }
  PARSE(*node.mutable_attribute());
  MATCH('(');
  PARSE(*node.mutable_input());
  MATCH(')');
  if (node.attribute_size() == 0) {
    // Permit attributes to be specified before or after parameters.
    PARSE(*node.mutable_attribute());
  }
  return Status::OK();
}
// Parses a brace-enclosed sequence of nodes "{ node node ... }" into `nodelist`,
// which is cleared first.
Status OnnxParser::Parse(NodeList& nodelist) {
  nodelist.Clear();
  MATCH('{');
  for (;;) {
    if (Matches('}'))
      return Status::OK();
    PARSE(*nodelist.Add());
  }
}
// Parses a named graph: "name [(inputs)] => (outputs) [<value-infos>] { nodes }".
// NOTE(review): ParseIdentifier's status is ignored, so a missing graph name is not reported
// here; the graph is parsed with whatever name was read. Confirm this leniency is intended.
Status OnnxParser::Parse(GraphProto& graph) {
  std::string graph_name;
  (void)ParseIdentifier(graph_name);
  return Parse(graph_name, graph);
}
// Parses the rest of a graph whose name is `name`:
//   [(inputs)] => (outputs) [<value-infos>] { nodes }
// Initializers may come both from inputs with a default ("x = value") and from the
// value-info list. Both append to graph.initializer, which is cleared once up front.
Status OnnxParser::Parse(std::string name, GraphProto& graph) {
  graph.set_name(name);
  auto& initializers = *graph.mutable_initializer();
  initializers.Clear();
  CHECK_PARSER_STATUS(ParseInput(*graph.mutable_input(), initializers));
  // "=>": the `false` argument presumably disables whitespace skipping, so '=' and '>'
  // must be adjacent.
  MATCH('=');
  MATCH('>', false);
  CHECK_PARSER_STATUS(ParseGraphInputOutput(*graph.mutable_output()));
  CHECK_PARSER_STATUS(ParseValueInfo(*graph.mutable_value_info(), initializers));
  return Parse(*graph.mutable_node());
}
// Parses a function definition into `fn` (cleared first):
//   [< key : value, ... >] name [< attrs >] (inputs) => (outputs) [< value-infos >] { nodes }
// The optional header accepts the keywords OPSET_IMPORT, DOC_STRING, DOMAIN_KW and
// OVERLOAD_KW. The attribute list holds attribute names and attribute definitions with
// defaults. Inputs/outputs may be typed ("type name"), and those types are recorded in
// fn.value_info together with the entries of the optional value-info list.
Status OnnxParser::Parse(FunctionProto& fn) {
  fn.Clear();
  std::string strval;
  if (Matches('<')) {
    do {
      KeyWordMap::KeyWord keyword = KeyWordMap::KeyWord::NONE;
      PARSE_TOKEN(keyword);
      MATCH(':');
      switch (keyword) {
        case KeyWordMap::KeyWord::OPSET_IMPORT:
          PARSE(*fn.mutable_opset_import());
          break;
        case KeyWordMap::KeyWord::DOC_STRING:
          PARSE_TOKEN(strval);
          fn.set_doc_string(strval);
          break;
        case KeyWordMap::KeyWord::DOMAIN_KW:
          PARSE_TOKEN(strval);
          fn.set_domain(strval);
          break;
        case KeyWordMap::KeyWord::OVERLOAD_KW:
          PARSE_TOKEN(strval);
          fn.set_overload(strval);
          break;
        default:
          return ParseError("Unhandled keyword.");
      }
    } while (Matches(','));
    MATCH('>');
  }
  std::string id;
  // A function must be named. Report a missing name here instead of continuing with "".
  CHECK_PARSER_STATUS(ParseIdentifier(id));
  fn.set_name(id);
  PARSE('<', *fn.mutable_attribute(), *fn.mutable_attribute_proto(), '>');
  // value_info accumulates typed inputs, typed outputs and the explicit value-info list.
  fn.mutable_value_info()->Clear();
  CHECK_PARSER_STATUS(ParseFunctionInputOutput(*fn.mutable_input(), *fn.mutable_value_info()));
  MATCH('=');
  MATCH('>', false);
  CHECK_PARSER_STATUS(ParseFunctionInputOutput(*fn.mutable_output(), *fn.mutable_value_info()));
  if (NextChar() == '<') {
    PARSE('<', *fn.mutable_value_info(), '>');
  }
  return Parse(*fn.mutable_node());
}
// Parses a bracketed list of opset imports "[ domain : version, ... ]". One entry is appended
// to `opsets` per element. An empty list "[]" is allowed.
Status OnnxParser::Parse(OpsetIdList& opsets) {
  MATCH('[');
  if (Matches(']'))
    return Status::OK();
  std::string domain_name;
  int64_t version = 0;
  do {
    auto* opset = opsets.Add();
    PARSE_TOKEN(domain_name);
    opset->set_domain(domain_name);
    MATCH(':');
    PARSE_TOKEN(version);
    opset->set_version(version);
  } while (Matches(','));
  MATCH(']');
  return Status::OK();
}
// Parses a model into `model` (cleared first). The model consists of an optional header
// "< key : value, ... >", the main graph, and then function definitions until end of input.
// Header keywords: IR_VERSION, OPSET_IMPORT, PRODUCER_NAME, PRODUCER_VERSION, DOMAIN_KW,
// MODEL_VERSION, DOC_STRING and METADATA_PROPS (a bracketed, possibly empty "k : v" list).
Status OnnxParser::Parse(ModelProto& model) {
  model.Clear();
  if (Matches('<')) {
    std::string text;
    int64_t number = 0;
    do {
      KeyWordMap::KeyWord key = KeyWordMap::KeyWord::NONE;
      PARSE_TOKEN(key);
      MATCH(':');
      switch (key) {
        case KeyWordMap::KeyWord::IR_VERSION:
          PARSE_TOKEN(number);
          model.set_ir_version(number);
          break;
        case KeyWordMap::KeyWord::OPSET_IMPORT:
          PARSE(*model.mutable_opset_import());
          break;
        case KeyWordMap::KeyWord::PRODUCER_NAME:
          PARSE_TOKEN(text);
          model.set_producer_name(text);
          break;
        case KeyWordMap::KeyWord::PRODUCER_VERSION:
          PARSE_TOKEN(text);
          model.set_producer_version(text);
          break;
        case KeyWordMap::KeyWord::DOMAIN_KW:
          PARSE_TOKEN(text);
          model.set_domain(text);
          break;
        case KeyWordMap::KeyWord::MODEL_VERSION:
          PARSE_TOKEN(number);
          model.set_model_version(number);
          break;
        case KeyWordMap::KeyWord::DOC_STRING:
          PARSE_TOKEN(text);
          model.set_doc_string(text);
          break;
        case KeyWordMap::KeyWord::METADATA_PROPS:
          MATCH('[');
          if (!Matches(']')) {
            PARSE(*model.mutable_metadata_props());
            MATCH(']');
          }
          break;
        default:
          return ParseError("Unhandled keyword.");
      }
    } while (Matches(','));
    MATCH('>');
  }

  PARSE(*model.mutable_graph());

  // Everything after the main graph is a sequence of function definitions.
  auto& functions = *model.mutable_functions();
  while (!EndOfInput()) {
    PARSE(*functions.Add());
  }
  return Status::OK();
}
} // namespace ONNX_NAMESPACE | |