File size: 6,338 Bytes
dc2106c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
// Copyright (c) ONNX Project Contributors
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <stdexcept>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "onnx/defs/function.h"
#include "onnx/defs/schema.h"
#include "onnx/onnx-data_pb.h"
#include "onnx/onnx-operators_pb.h"
#include "onnx/onnx_pb.h"
#include "onnx/string_utils.h"

namespace ONNX_NAMESPACE {
namespace checker {
// Exception raised by the checker when a model (or a piece of one) fails
// validation. It is constructed from a message exactly like
// std::runtime_error; extra context can be attached later via AppendContext.
class ValidationError final : public std::runtime_error {
 public:
  using std::runtime_error::runtime_error;

  // Returns the context-expanded message once AppendContext has been called,
  // and the plain constructor message before that.
  const char* what() const noexcept override {
    return expanded_message_.empty() ? std::runtime_error::what() : expanded_message_.c_str();
  }

  // Attaches `context` after the original message, separated by
  // "\n\n==> Context: ". The expanded text is always rebuilt from the
  // original constructor message, so calling this again replaces (rather
  // than accumulates onto) any context attached by an earlier call.
  void AppendContext(const std::string& context) {
    const char* const original_message = std::runtime_error::what();
    expanded_message_ = ONNX_NAMESPACE::MakeString(original_message, "\n\n==> Context: ", context);
  }

 private:
  // Original message plus attached context; empty until AppendContext runs.
  std::string expanded_message_;
};

// Builds a message by passing every macro argument to
// ONNX_NAMESPACE::MakeString, wraps it in a checker::ValidationError, and
// hands that exception to ONNX_THROW_EX.
// NOTE(review): the expansion already ends in ';'. The usual call form
// `fail_check(...);` therefore yields an extra empty statement, and an
// unbraced `if (cond) fail_check(...); else ...` will not compile. Use braces
// around such branches.
#define fail_check(...) \
  ONNX_THROW_EX(ONNX_NAMESPACE::checker::ValidationError(ONNX_NAMESPACE::MakeString(__VA_ARGS__)));

// Settings and model-level state consulted by the check_* functions while
// validating a model, graph, node, or function.
class CheckerContext final {
 public:
  // IR version of the model being checked; -1 until set_ir_version is called.
  int get_ir_version() const {
    return ir_version_;
  }
  void set_ir_version(int v) {
    ir_version_ = v;
  }

  // Opset imports in effect (presumably domain -> opset version; confirm
  // against the callers that populate it).
  const std::unordered_map<std::string, int>& get_opset_imports() const {
    return opset_imports_;
  }
  // Taken by value so callers can move a temporary map in without copying.
  void set_opset_imports(std::unordered_map<std::string, int> imps) {
    opset_imports_ = std::move(imps);
  }

  // Whether the graph currently being checked is the main graph; defaults to true.
  bool is_main_graph() const {
    return is_main_graph_;
  }
  void set_is_main_graph(bool is_main_graph) {
    is_main_graph_ = is_main_graph;
  }

  // Registry used to look up operator schemas; defaults to
  // OpSchemaRegistry::Instance(). CheckerContext never deletes this pointer,
  // so the caller must keep the registry alive while the context is in use.
  void set_schema_registry(const ISchemaRegistry* schema_registry) {
    schema_registry_ = schema_registry;
  }

  const ISchemaRegistry* get_schema_registry() const {
    return schema_registry_;
  }

  // Directory of the model being checked (presumably the base directory for
  // resolving external tensor data; confirm in checker.cc). Empty by default.
  void set_model_dir(const std::string& model_dir) {
    model_dir_ = model_dir;
  }

  // Returned by const reference to avoid copying the string on every call.
  // Callers that need an independent copy can still assign it to a std::string.
  const std::string& get_model_dir() const {
    return model_dir_;
  }

  // Defaults to false.
  bool skip_opset_compatibility_check() const {
    return skip_opset_compatibility_check_;
  }

  void set_skip_opset_compatibility_check(bool value) {
    skip_opset_compatibility_check_ = value;
  }

  // Defaults to false.
  bool check_custom_domain() const {
    return check_custom_domain_;
  }

  void set_check_custom_domain(bool value) {
    check_custom_domain_ = value;
  }

  // All members carry in-class defaults, so the defaulted constructor yields
  // the same state as the previous hand-written one (ir_version_ == -1).
  explicit CheckerContext() = default;

 private:
  int ir_version_ = -1;
  std::unordered_map<std::string, int> opset_imports_;
  bool is_main_graph_ = true;
  const ISchemaRegistry* schema_registry_ = OpSchemaRegistry::Instance();
  std::string model_dir_;
  bool skip_opset_compatibility_check_ = false;
  bool check_custom_domain_ = false;
};

// Set of value names visible in one graph scope, optionally chained to the
// scope of an enclosing (parent) graph for name lookup.
class LexicalScopeContext final {
 public:
  // Creates a root scope with no parent and no names.
  LexicalScopeContext() = default;

  // Construct an instance with the lexical scope from the parent graph to allow
  // lookup of names from that scope via this_or_ancestor_graph_has.
  // NOTE: although this is the copy constructor, it does NOT copy
  // parent_context.output_names; the new instance starts empty and only refers
  // to parent_context. No move constructor is declared, so "moving" an
  // instance also goes through here and creates a child of the source.
  // The caller must ensure parent_context remains valid for the entire lifetime
  // of the new instance. Alternatively, if that cannot be guaranteed, create an
  // instance with the default constructor and populate output_names with the
  // values from the parent scope so the values are copied instead.
  LexicalScopeContext(const LexicalScopeContext& parent_context) : parent_context_{&parent_context} {}

  // Re-parents this scope onto parent_context. Names already added to this
  // scope are kept; only the parent link changes, with the same lifetime
  // requirement as the constructor above. Self-assignment is a no-op: making a
  // scope its own parent would make this_or_ancestor_graph_has loop forever
  // for any name not found locally.
  LexicalScopeContext& operator=(const LexicalScopeContext& parent_context) {
    if (this != &parent_context) {
      parent_context_ = &parent_context;
    }
    return *this;
  }

  // Records `name` as defined in this scope.
  void add(const std::string& name) {
    output_names.insert(name);
  }

  // True if `name` was added to this scope itself (parents are not consulted).
  bool this_graph_has(const std::string& name) const {
    return output_names.find(name) != output_names.cend();
  }

  // True if `name` is defined in this scope or any ancestor scope. Walks the
  // parent chain iteratively so deeply nested subgraphs cannot exhaust the stack.
  bool this_or_ancestor_graph_has(const std::string& name) const {
    for (const LexicalScopeContext* scope = this; scope != nullptr; scope = scope->parent_context_) {
      if (scope->this_graph_has(name)) {
        return true;
      }
    }
    return false;
  }

  // public for backwards compatibility. please prefer the public interface of
  // this class over directly changing output_names
  std::unordered_set<std::string> output_names;

 private:
  // Non-owning link to the enclosing scope; nullptr for a root scope.
  const LexicalScopeContext* parent_context_{nullptr};
};

// Type of the Version::IR_VERSION enumerator.
using IR_VERSION_TYPE = decltype(Version::IR_VERSION);

// Validators for individual protobuf messages. They presumably report
// failures by throwing ValidationError via fail_check (confirm in checker.cc).
void check_value_info(const ValueInfoProto& value_info, const CheckerContext& ctx);
void check_tensor(const TensorProto& tensor, const CheckerContext& ctx);
void check_sparse_tensor(const SparseTensorProto& sparse_tensor, const CheckerContext& ctx);
void check_sequence(const SequenceProto& sequence, const CheckerContext& ctx);
void check_map(const MapProto& map, const CheckerContext& ctx);
void check_optional(const OptionalProto& opt, const CheckerContext& ctx);

// Validators that additionally receive a LexicalScopeContext describing the
// value names visible from the enclosing scope(s).
void check_attribute(const AttributeProto& attr, const CheckerContext& ctx, const LexicalScopeContext& lex_ctx);
void check_node(const NodeProto& node, const CheckerContext& ctx, const LexicalScopeContext& lex_ctx);
void check_graph(const GraphProto& graph, const CheckerContext& ctx, const LexicalScopeContext& lex_ctx);
void check_function(const FunctionProto& function, const CheckerContext& ctx, const LexicalScopeContext& lex_ctx);

// Checks schema compatibility of `node` between two opset versions: the one
// selected by `func_opset_imports` and the one selected by
// `model_opset_imports`. The two are compatible when the operator schema is
// the same in both versions, i.e. it did not change between them.
void check_opset_compatibility(
    const NodeProto& node,
    const CheckerContext& ctx,
    const std::unordered_map<std::string, int>& func_opset_imports,
    const std::unordered_map<std::string, int>& model_opset_imports);

// Checks every model-local function contained in `model`, using `parent_lex`
// as the enclosing lexical scope.
void check_model_local_functions(
    const ModelProto& model,
    const CheckerContext& ctx,
    const LexicalScopeContext& parent_lex);

// Validates a whole model, given either as an in-memory ModelProto or as a
// path to a model file. All flags default to false. skip_opset_compatibility_check
// and check_custom_domain presumably populate the CheckerContext settings of
// the same names, and full_check presumably enables extra, more expensive
// checks; confirm both in checker.cc.
void check_model(
    const ModelProto& model,
    bool full_check = false,
    bool skip_opset_compatibility_check = false,
    bool check_custom_domain = false);
void check_model(
    const std::string& model_path,
    bool full_check = false,
    bool skip_opset_compatibility_check = false,
    bool check_custom_domain = false);
// Resolves the external-data `location` of tensor `tensor_name` against
// `base_dir` and returns the resulting path (exact validation rules live in
// the definition; confirm there).
std::string resolve_external_data_location(
    const std::string& base_dir,
    const std::string& location,
    const std::string& tensor_name);

// Returns whether `node` refers to an experimental operator.
bool check_is_experimental_op(const NodeProto& node);

} // namespace checker
} // namespace ONNX_NAMESPACE