Skip to content

Commit

Permalink
[ONNX] Extend ONNX Frontend with Function LessOrEqual (openvinotoolki…
Browse files Browse the repository at this point in the history
…t#21102)

* Implement less_or_equal in set_1 and set_16

* Create tests for ONNX frontend less_or_equal operator
  • Loading branch information
YaritaiKoto authored Nov 20, 2023
1 parent a5e33f1 commit d9e04c3
Show file tree
Hide file tree
Showing 6 changed files with 207 additions and 8 deletions.
35 changes: 35 additions & 0 deletions src/frontends/onnx/frontend/src/op/less_or_equal.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "op/less_or_equal.hpp"
OPENVINO_SUPPRESS_DEPRECATED_START

#include "default_opset.hpp"

namespace ngraph {
namespace onnx_import {
namespace op {
namespace set_1 {
OutputVector less_or_equal(const Node& node) {
const auto& input = node.get_ng_inputs();
const auto a = input.at(0);
const auto b = input.at(1);
NGRAPH_CHECK(a.get_element_type() != ov::element::bf16 && b.get_element_type() != ov::element::bf16,
"The input data bfloat16 isn't supported in opset 12");
return {std::make_shared<default_opset::LessEqual>(a, b)};
}
} // namespace set_1

namespace set_16 {
OutputVector less_or_equal(const Node& node) {
const auto& input = node.get_ng_inputs();
const auto a = input.at(0);
const auto b = input.at(1);
return {std::make_shared<default_opset::LessEqual>(a, b)};
}
} // namespace set_16
} // namespace op
} // namespace onnx_import
} // namespace ngraph
OPENVINO_SUPPRESS_DEPRECATED_END
30 changes: 30 additions & 0 deletions src/frontends/onnx/frontend/src/op/less_or_equal.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/core/deprecated.hpp"
OPENVINO_SUPPRESS_DEPRECATED_START

#include "ngraph/node.hpp"
#include "onnx_import/core/node.hpp"

namespace ngraph {
namespace onnx_import {
namespace op {
namespace set_1 {

OutputVector less_or_equal(const Node& node);

} // namespace set_1

namespace set_16 {

OutputVector less_or_equal(const Node& node);

} // namespace set_16
} // namespace op
} // namespace onnx_import
} // namespace ngraph
OPENVINO_SUPPRESS_DEPRECATED_END
3 changes: 3 additions & 0 deletions src/frontends/onnx/frontend/src/ops_bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
#include "op/is_nan.hpp"
#include "op/leaky_relu.hpp"
#include "op/less.hpp"
#include "op/less_or_equal.hpp"
#include "op/log.hpp"
#include "op/log_softmax.hpp"
#include "op/loop.hpp"
Expand Down Expand Up @@ -416,6 +417,8 @@ OperatorsBridge::OperatorsBridge() {
REGISTER_OPERATOR("IsNaN", 1, is_nan)
REGISTER_OPERATOR("LeakyRelu", 1, leaky_relu);
REGISTER_OPERATOR("Less", 1, less);
REGISTER_OPERATOR("LessOrEqual", 1, less_or_equal);
REGISTER_OPERATOR("LessOrEqual", 16, less_or_equal);
REGISTER_OPERATOR("Log", 1, log);
REGISTER_OPERATOR("LogSoftmax", 1, log_softmax);
REGISTER_OPERATOR("LogSoftmax", 13, log_softmax);
Expand Down
53 changes: 53 additions & 0 deletions src/frontends/onnx/tests/models/less_or_equal.prototxt
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
ir_version: 7
graph {
node {
input: "a"
input: "b"
output: "output"
op_type: "LessOrEqual"
}
name: "LessEqualGraph"
input {
name: "a"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 5
}
}
}
}
}
input {
name: "b"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 5
}
}
}
}
}
output {
name: "output"
type {
tensor_type {
elem_type: 9
shape {
dim {
dim_value: 5
}
}
}
}
}
}
opset_import {
domain: ""
version: 16
}
53 changes: 53 additions & 0 deletions src/frontends/onnx/tests/models/less_or_equal_broadcast.prototxt
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
ir_version: 7
graph {
node {
input: "a"
input: "b"
output: "output"
op_type: "LessOrEqual"
}
name: "LessEqualGraph"
input {
name: "a"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 5
}
}
}
}
}
input {
name: "b"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
}
}
}
}
output {
name: "output"
type {
tensor_type {
elem_type: 9
shape {
dim {
dim_value: 5
}
}
}
}
}
}
opset_import {
domain: ""
version: 16
}
41 changes: 33 additions & 8 deletions src/frontends/onnx/tests/onnx_import.in.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,25 +23,25 @@
#define DEFAULT_DOUBLE_TOLERANCE_BITS ${BACKEND_NAME}_DOUBLE_TOLERANCE_BITS
#endif
// clang-format on
#include "common_test_utils/all_close.hpp"
#include "common_test_utils/file_utils.hpp"
#include "common_test_utils/ndarray.hpp"
#include "common_test_utils/ov_test_utils.hpp"
#include "ngraph/file_util.hpp"
#include "default_opset.hpp"
#include "openvino/opsets/opset12.hpp"
#include "common_test_utils/test_case.hpp"
#include "common_test_utils/test_control.hpp"
#include "common_test_utils/test_tools.hpp"
#include "common_test_utils/type_prop.hpp"
#include "default_opset.hpp"
#include "gtest/gtest.h"
#include "ngraph/file_util.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/constant_folding.hpp"
#include "ngraph/pass/manager.hpp"
#include "onnx_import/core/null_node.hpp"
#include "onnx_import/onnx.hpp"
#include "onnx_import/onnx_utils.hpp"
#include "common_test_utils/all_close.hpp"
#include "common_test_utils/ndarray.hpp"
#include "common_test_utils/test_control.hpp"
#include "common_test_utils/test_tools.hpp"
#include "common_test_utils/type_prop.hpp"
#include "onnx_utils.hpp"
#include "openvino/opsets/opset12.hpp"

OPENVINO_SUPPRESS_DEPRECATED_START

Expand Down Expand Up @@ -6977,6 +6977,31 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_model_mm_nms_rotated) {
test_case.run();
}

OPENVINO_TEST(${BACKEND_NAME}, onnx_model_less_or_equal) {
auto function = onnx_import::import_onnx_model(
file_util::path_join(ov::test::utils::getExecutableDirectory(), SERIALIZED_ZOO, "onnx/less_or_equal.onnx"));

auto test_case = ov::test::TestCase(function, s_device);
test_case.add_input<float>(Shape{5}, {1., 2., 3., 4., 5.});
test_case.add_input<float>(Shape{5}, {3., 3., 3., 3., 3.});
test_case.add_expected_output<bool>(Shape{5}, {true, true, true, false, false});

test_case.run();
}

OPENVINO_TEST(${BACKEND_NAME}, onnx_model_less_or_equal_broadcast) {
auto function = onnx_import::import_onnx_model(file_util::path_join(ov::test::utils::getExecutableDirectory(),
SERIALIZED_ZOO,
"onnx/less_or_equal_broadcast.onnx"));

auto test_case = ov::test::TestCase(function, s_device);
test_case.add_input<float>(Shape{5}, {1., 2., 3., 4., 5.});
test_case.add_input<float>(Shape{1}, {3.});
test_case.add_expected_output<bool>(Shape{5}, {true, true, true, false, false});

test_case.run();
}

OPENVINO_TEST(${BACKEND_NAME}, onnx_model_greater_or_equal_int) {
auto function = onnx_import::import_onnx_model(file_util::path_join(ov::test::utils::getExecutableDirectory(),
SERIALIZED_ZOO,
Expand Down

0 comments on commit d9e04c3

Please sign in to comment.