Spaces:
Running
Running
File size: 6,338 Bytes
dc2106c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 |
// Copyright (c) ONNX Project Contributors
//
// SPDX-License-Identifier: Apache-2.0
#pragma once
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include "onnx/defs/function.h"
#include "onnx/defs/schema.h"
#include "onnx/onnx-data_pb.h"
#include "onnx/onnx-operators_pb.h"
#include "onnx/onnx_pb.h"
#include "onnx/string_utils.h"
namespace ONNX_NAMESPACE {
namespace checker {
// Error type raised by the checker when a proto fails validation.
//
// The message given at construction is held by std::runtime_error and never
// changes. AppendContext() stores a separate string of the form
// "<original message>\n\n==> Context: <context>", and what() returns that
// string once AppendContext() has been called. Each call starts again from the
// original message, so a later context replaces an earlier one rather than
// being appended after it.
class ValidationError final : public std::runtime_error {
 public:
  using std::runtime_error::runtime_error;

  const char* what() const noexcept override {
    return expanded_message_.empty() ? std::runtime_error::what() : expanded_message_.c_str();
  }

  void AppendContext(const std::string& context) {
    const std::string base_message = std::runtime_error::what();
    expanded_message_ = base_message + "\n\n==> Context: " + context;
  }

 private:
  // Original message plus the most recent context; empty until AppendContext().
  std::string expanded_message_;
};
// Throws a checker::ValidationError whose message is built from all macro
// arguments via ONNX_NAMESPACE::MakeString (which, judging by its use in
// ValidationError::AppendContext, concatenates its arguments), e.g.
//   fail_check("Field '", name, "' is required but missing.");
// The throw itself is delegated to ONNX_THROW_EX, defined elsewhere.
// NOTE(review): ONNX_THROW_EX presumably has a non-throwing variant for
// no-exception builds; confirm before relying on unwinding here.
// NOTE: the expansion already ends with ';'. A call written with its own
// trailing ';' therefore produces an extra empty statement. In an unbraced
// `if (...) fail_check(...); else ...` that stray ';' ends the if-statement
// and the `else` no longer compiles, so always brace such branches.
#define fail_check(...) \
ONNX_THROW_EX(ONNX_NAMESPACE::checker::ValidationError(ONNX_NAMESPACE::MakeString(__VA_ARGS__)));
// Bundle of settings passed (by const reference) to the check_* functions
// declared in this header. It holds:
//   - the IR version and opset imports being checked against;
//   - the schema registry used to look up operators;
//   - the model directory;
//   - switches for optional checks.
class CheckerContext final {
 public:
  // Every member has an in-class default (see the private section), so the
  // constructor has nothing left to do. It is deliberately not `explicit`:
  // `explicit` adds nothing on a nullary constructor and would only forbid
  // copy-list-initialization such as `CheckerContext ctx = {};`.
  CheckerContext() = default;

  int get_ir_version() const {
    return ir_version_;
  }
  void set_ir_version(int v) {
    ir_version_ = v;
  }

  // Imported opsets; NOTE(review): presumably keyed by operator domain with the
  // imported opset version as value. Confirm against the callers.
  const std::unordered_map<std::string, int>& get_opset_imports() const {
    return opset_imports_;
  }
  // Takes the map by value and moves it in, so callers passing an rvalue avoid a copy.
  void set_opset_imports(std::unordered_map<std::string, int> imps) {
    opset_imports_ = std::move(imps);
  }

  // Whether the graph being checked is the top-level graph (default: true).
  bool is_main_graph() const {
    return is_main_graph_;
  }
  void set_is_main_graph(bool is_main_graph) {
    is_main_graph_ = is_main_graph;
  }

  // Schema registry used for operator lookup; defaults to
  // OpSchemaRegistry::Instance(). Stored as a non-owning pointer (this class
  // never deletes it), so the registry must outlive any use of this context.
  void set_schema_registry(const ISchemaRegistry* schema_registry) {
    schema_registry_ = schema_registry;
  }
  const ISchemaRegistry* get_schema_registry() const {
    return schema_registry_;
  }

  // Directory associated with the model being checked (empty by default).
  // NOTE(review): presumably the base directory for resolving external-data
  // locations (cf. resolve_external_data_location). Confirm in checker.cc.
  void set_model_dir(const std::string& model_dir) {
    model_dir_ = model_dir;
  }
  // Returns a copy of the stored directory.
  std::string get_model_dir() const {
    return model_dir_;
  }

  // Whether check_opset_compatibility-style checks should be skipped (default: false).
  bool skip_opset_compatibility_check() const {
    return skip_opset_compatibility_check_;
  }
  void set_skip_opset_compatibility_check(bool value) {
    skip_opset_compatibility_check_ = value;
  }

  // Whether operators in custom (non-default) domains are checked too (default: false).
  bool check_custom_domain() const {
    return check_custom_domain_;
  }
  void set_check_custom_domain(bool value) {
    check_custom_domain_ = value;
  }

 private:
  // -1 until set_ir_version() is called (presumably meaning "not set").
  int ir_version_ = -1;
  std::unordered_map<std::string, int> opset_imports_;
  bool is_main_graph_ = true;
  const ISchemaRegistry* schema_registry_ = OpSchemaRegistry::Instance();
  std::string model_dir_;
  bool skip_opset_compatibility_check_ = false;
  bool check_custom_domain_ = false;
};
// Tracks the value names defined in one graph's lexical scope, optionally
// chained to the scope of an enclosing graph.
//
// NOTE: the copy constructor and copy-assignment operator do NOT copy. They
// store a (non-owning) pointer to their argument as the parent scope and leave
// output_names untouched. Because a copy constructor is user-declared, no
// move constructor is implicitly declared. Moving an instance, including a
// non-elided return of a named local, therefore also goes through the
// parent-linking constructor. The result then points at the moved-from object
// and holds none of its names.
class LexicalScopeContext final {
 public:
  LexicalScopeContext() = default;

  // Construct an instance with the lexical scope from the parent graph to allow
  // lookup of names from that scope via this_or_ancestor_graph_has.
  // The caller must ensure parent_context remains valid for the entire lifetime
  // of the new instance. Alternatively, if that cannot be guaranteed, create an
  // instance with the default constructor and populate output_names with the
  // values from the parent scope so the values are copied instead.
  LexicalScopeContext(const LexicalScopeContext& parent_context) : parent_context_{&parent_context} {}

  // Re-parents this instance onto parent_context. The lifetime requirement is
  // the same as for the constructor above, and this instance keeps its own
  // output_names. Self-assignment is a no-op, because linking an instance to
  // itself would make this_or_ancestor_graph_has never terminate for a name it
  // does not contain. Longer cycles (e.g. `a = b; b = a;`) must be avoided by
  // the caller.
  LexicalScopeContext& operator=(const LexicalScopeContext& parent_context) {
    if (&parent_context != this) {
      parent_context_ = &parent_context;
    }
    return *this;
  }

  // Records `name` as defined in this scope.
  void add(const std::string& name) {
    output_names.insert(name);
  }

  // True if `name` was added to this scope itself; ancestors are not consulted.
  bool this_graph_has(const std::string& name) const {
    return output_names.find(name) != output_names.cend();
  }

  // True if `name` is defined in this scope or in any ancestor scope.
  // Walks the parent chain iteratively so the call stack does not grow with
  // the nesting depth of subgraphs.
  bool this_or_ancestor_graph_has(const std::string& name) const {
    for (const LexicalScopeContext* scope = this; scope != nullptr; scope = scope->parent_context_) {
      if (scope->this_graph_has(name)) {
        return true;
      }
    }
    return false;
  }

  // public for backwards compatibility. please prefer the public interface of
  // this class over directly changing output_names
  std::unordered_set<std::string> output_names;

 private:
  // Non-owning pointer to the enclosing scope; nullptr for a root scope.
  const LexicalScopeContext* parent_context_{nullptr};
};
// Type of the Version::IR_VERSION enumerator, i.e. the `Version` enum type
// (NOTE(review): presumably declared in the generated onnx/onnx_pb.h).
using IR_VERSION_TYPE = decltype(Version::IR_VERSION);
// Validators for individual protos, each driven by the settings in the given
// CheckerContext. NOTE(review): failures are presumably reported by throwing
// ValidationError via fail_check; confirm in the implementation file.
void check_value_info(const ValueInfoProto& value_info, const CheckerContext&);
void check_tensor(const TensorProto& tensor, const CheckerContext&);
void check_sparse_tensor(const SparseTensorProto& sparse_tensor, const CheckerContext&);
void check_sequence(const SequenceProto& sequence, const CheckerContext&);
void check_map(const MapProto& map, const CheckerContext&);
void check_optional(const OptionalProto& opt, const CheckerContext&);
// These validators additionally take a LexicalScopeContext describing the
// value names visible from the enclosing graph(s).
void check_attribute(const AttributeProto& attr, const CheckerContext&, const LexicalScopeContext&);
void check_node(const NodeProto& node, const CheckerContext&, const LexicalScopeContext&);
void check_graph(const GraphProto& graph, const CheckerContext&, const LexicalScopeContext&);
void check_function(const FunctionProto& function, const CheckerContext&, const LexicalScopeContext&);
// Checks that `node`'s operator has the same schema under the opset version
// imported by the function (func_opset_imports) and under the one imported by
// the model (model_opset_imports). The two are compatible exactly when the
// operator's schema did not change between those versions.
void check_opset_compatibility(
const NodeProto& node,
const CheckerContext& ctx,
const std::unordered_map<std::string, int>& func_opset_imports,
const std::unordered_map<std::string, int>& model_opset_imports);
// Checks every model-local function present in `model`.
void check_model_local_functions(
const ModelProto& model,
const CheckerContext& ctx,
const LexicalScopeContext& parent_lex);
// Checks a whole model, given either as an in-memory ModelProto or as a path
// to a model file.
// NOTE(review): full_check presumably enables additional (e.g. shape
// inference) checks; skip_opset_compatibility_check and check_custom_domain
// presumably feed the same-named CheckerContext settings. Confirm in checker.cc.
void check_model(
const ModelProto& model,
bool full_check = false,
bool skip_opset_compatibility_check = false,
bool check_custom_domain = false);
void check_model(
const std::string& model_path,
bool full_check = false,
bool skip_opset_compatibility_check = false,
bool check_custom_domain = false);
// Resolves an external-data `location` against `base_dir` and returns the
// resulting path. NOTE(review): tensor_name is presumably used only for error
// messages; confirm.
std::string resolve_external_data_location(
const std::string& base_dir,
const std::string& location,
const std::string& tensor_name);
// NOTE(review): judging by the name, returns true when `node` uses an
// experimental operator; confirm in the implementation.
bool check_is_experimental_op(const NodeProto& node);
} // namespace checker
} // namespace ONNX_NAMESPACE
|