-
Notifications
You must be signed in to change notification settings - Fork 612
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #32320 from vespa-engine/havardpe/optimize-sum-max-inv-hamming: optimize sum max inv hamming operation
- Loading branch information
Showing 7 changed files with 337 additions and 0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
9 changes: 9 additions & 0 deletions
9
eval/src/tests/instruction/sum_max_inv_hamming_function/CMakeLists.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. | ||
vespa_add_executable(eval_sum_max_inv_hamming_function_test_app TEST | ||
SOURCES | ||
sum_max_inv_hamming_function_test.cpp | ||
DEPENDS | ||
vespaeval | ||
GTest::GTest | ||
) | ||
vespa_add_test(NAME eval_sum_max_inv_hamming_function_test_app COMMAND eval_sum_max_inv_hamming_function_test_app) |
136 changes: 136 additions & 0 deletions
136
.../src/tests/instruction/sum_max_inv_hamming_function/sum_max_inv_hamming_function_test.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. | ||
|
||
#include <vespa/eval/eval/fast_value.h> | ||
#include <vespa/eval/eval/tensor_function.h> | ||
#include <vespa/eval/eval/test/eval_fixture.h> | ||
#include <vespa/eval/eval/test/gen_spec.h> | ||
#include <vespa/eval/instruction/sum_max_inv_hamming_function.h> | ||
#include <vespa/vespalib/gtest/gtest.h> | ||
|
||
using namespace vespalib; | ||
using namespace vespalib::eval; | ||
using namespace vespalib::eval::test; | ||
|
||
const ValueBuilderFactory &prod_factory = FastValueBuilderFactory::get();

// The two expression forms the optimizer should recognize: sum over
// query vectors (x) of the max over document vectors (y) of the
// inverted hamming distance 1/(1+d) computed along dimension z.
// alt_expr only differs in the operand order of the inner addition.
std::string main_expr = "reduce(reduce(1/(1+reduce(hamming(a,b),sum,z)),max,y),sum,x)";
std::string alt_expr = "reduce(reduce(1/(reduce(hamming(a,b),sum,z)+1),max,y),sum,x)";
|
||
//----------------------------------------------------------------------------- | ||
|
||
void assert_optimized(const TensorSpec &a, const TensorSpec &b, size_t vec_size, const std::string &expr = main_expr) { | ||
EvalFixture::ParamRepo param_repo; | ||
param_repo.add("a", a); | ||
param_repo.add("b", b); | ||
EvalFixture slow_fixture(prod_factory, expr, param_repo, false); | ||
EvalFixture fast_fixture(prod_factory, expr, param_repo, true); | ||
EXPECT_EQ(slow_fixture.result(), EvalFixture::ref(main_expr, param_repo)); | ||
EXPECT_EQ(fast_fixture.result(), EvalFixture::ref(main_expr, param_repo)); | ||
auto info = fast_fixture.find_all<SumMaxInvHammingFunction>(); | ||
if (info.size() == 1) { | ||
EXPECT_TRUE(info[0]->result_is_mutable()); | ||
EXPECT_EQ(info[0]->vec_size(), vec_size); | ||
} | ||
EXPECT_EQ(info.size(), 1); | ||
} | ||
|
||
void assert_not_optimized(const TensorSpec &a, const TensorSpec &b, const std::string &expr = main_expr) { | ||
EvalFixture::ParamRepo param_repo; | ||
param_repo.add("a", a); | ||
param_repo.add("b", b); | ||
EvalFixture slow_fixture(prod_factory, expr, param_repo, false); | ||
EvalFixture fast_fixture(prod_factory, expr, param_repo, true); | ||
EXPECT_EQ(slow_fixture.result(), EvalFixture::ref(expr, param_repo)); | ||
EXPECT_EQ(fast_fixture.result(), EvalFixture::ref(expr, param_repo)); | ||
auto info = fast_fixture.find_all<SumMaxInvHammingFunction>(); | ||
EXPECT_EQ(info.size(), 0); | ||
} | ||
|
||
//----------------------------------------------------------------------------- | ||
|
||
// Build a tensor spec with the given layout description, filling the
// cells from a fixed byte sequence so hamming distances are
// deterministic and non-trivial.
GenSpec make_spec(const std::string &desc, CellType cell_type) {
    Seq byte_pattern({0x1f, 0x2e, 0x3d, 0x4c, 0x5b, 0x6a, 0x79, 0x88,
                      0x97, 0xa6, 0xb5, 0xc4, 0xd3, 0xe2, 0xf1});
    return GenSpec::from_desc(desc).cells(cell_type).seq(byte_pattern);
}
|
||
// Default test inputs: 3 mapped query vectors and 5 mapped document
// vectors, each with 7 int8 cells along the dense hamming dimension z.
GenSpec query = make_spec("x3_1z7", CellType::INT8);
GenSpec document = make_spec("y5_1z7", CellType::INT8);
|
||
// The canonical expression form over int8 inputs triggers the optimization.
TEST(SumMaxInvHamming, expression_can_be_optimized) {
    assert_optimized(query, document, 7);
}
|
||
// hamming(a,b) is symmetric; the optimizer must also match when the
// document appears as the first operand and the query as the second.
TEST(SumMaxInvHamming, input_values_can_be_reordered) {
    assert_optimized(document, query, 7);
}
|
||
// The alternative form 1/(x+1) must be matched as well as 1/(1+x),
// in both operand orders.
TEST(SumMaxInvHamming, expression_can_have_alternative_form) {
    assert_optimized(query, document, 7, alt_expr);
    assert_optimized(document, query, 7, alt_expr);
}
|
||
// Mapped dimensions may be empty; the optimized operation must still
// be selected and produce the correct (zero) result.
TEST(SumMaxInvHamming, optimization_works_with_empty_tensors) {
    auto empty_query = make_spec("x0_0z7", CellType::INT8);
    auto empty_document = make_spec("y0_0z7", CellType::INT8);
    assert_optimized(empty_query, document, 7);
    assert_optimized(query, empty_document, 7);
    assert_optimized(empty_query, empty_document, 7);
}
|
||
// A hamming dimension of size 1 is a degenerate but valid case.
TEST(SumMaxInvHamming, the_hamming_dimension_may_be_trivial) {
    GenSpec trivial_query = make_spec("x3_1z1", CellType::INT8);
    GenSpec trivial_document = make_spec("y5_1z1", CellType::INT8);
    assert_optimized(trivial_query, trivial_document, 1);
}
|
||
//----------------------------------------------------------------------------- | ||
|
||
// The non-hamming dimensions may be indexed (dense) as long as the
// hamming dimension is innermost (stride 1) in both inputs.
TEST(SumMaxInvHamming, other_dimensions_may_be_indexed_as_long_as_hamming_dimension_has_stride_1) {
    auto dense_query = make_spec("x3z7", CellType::INT8);
    auto dense_document = make_spec("y5z7", CellType::INT8);
    assert_optimized(dense_query, dense_document, 7);

    // Here the hamming dimension (y) has stride 5 in the document
    // (y7z5), so the optimization must not trigger. Bug fix:
    // outer_expr was declared but never passed to the assertion, so
    // the stride requirement was not actually exercised.
    std::string outer_expr = "reduce(reduce(1/(1+reduce(hamming(a,b),sum,y)),max,x),sum,z)";
    auto dense_query2 = make_spec("x3y7", CellType::INT8);
    auto dense_document2 = make_spec("y7z5", CellType::INT8);
    assert_not_optimized(dense_query2, dense_document2, outer_expr);
}
|
||
//----------------------------------------------------------------------------- | ||
|
||
// The optimization requires int8 cells on both sides; any other cell
// type on either (or both) inputs must leave the expression alone.
TEST(SumMaxInvHamming, all_cells_must_be_int8) {
    for (auto ct: CellTypeUtils::list_types()) {
        if (ct == CellType::INT8) {
            continue;
        }
        assert_not_optimized(query.cpy().cells(ct), document);
        assert_not_optimized(query, document.cpy().cells(ct));
        assert_not_optimized(query.cpy().cells(ct), document.cpy().cells(ct));
    }
}
|
||
// Inputs must be exactly 2-dimensional; an extra sparse (a1_1) or
// trivial dense (w1) dimension on either side blocks the optimization.
TEST(SumMaxInvHamming, extra_dimensions_are_not_allowed) {
    GenSpec query_es = make_spec("a1_1x3_1z7", CellType::INT8);
    GenSpec query_ed = make_spec("x3_1w1z7", CellType::INT8);
    GenSpec document_es = make_spec("a1_1y5_1z7", CellType::INT8);
    GenSpec document_ed = make_spec("y5_1w1z7", CellType::INT8);
    assert_not_optimized(query_es, document);
    assert_not_optimized(query, document_es);
    assert_not_optimized(query_ed, document);
    assert_not_optimized(query, document_ed);
    assert_not_optimized(query_es, document_es);
    assert_not_optimized(query_ed, document_ed);
}
|
||
// Each expression below breaks the required pattern in exactly one
// place (an operator, an aggregator, or a reduced dimension) and must
// therefore not be optimized.
TEST(SumMaxInvHamming, similar_expressions_are_not_optimized) {
    assert_not_optimized(query, document, "reduce(reduce(1*(1+reduce(hamming(a,b),sum,z)),max,y),sum,x)");
    assert_not_optimized(query, document, "reduce(reduce(1/(1-reduce(hamming(a,b),sum,z)),max,y),sum,x)");
    assert_not_optimized(query, document, "reduce(reduce(1/(1+reduce(hamming(a,b),max,z)),max,y),sum,x)");
    assert_not_optimized(query, document, "reduce(reduce(1/(1+reduce(hamming(a,b),sum,z)),sum,y),sum,x)");
    assert_not_optimized(query, document, "reduce(reduce(1/(1+reduce(hamming(a,b),sum,z)),max,y),max,x)");
    assert_not_optimized(query, document, "reduce(reduce(1/(1+reduce(hamming(a,b),sum,y)),max,z),sum,x)");
    assert_not_optimized(query, document, "reduce(reduce(1/(1+reduce(hamming(a,b),sum,x)),max,y),sum,z)");
}
|
||
//----------------------------------------------------------------------------- | ||
|
||
GTEST_MAIN_RUN_ALL_TESTS() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
141 changes: 141 additions & 0 deletions
141
eval/src/vespa/eval/instruction/sum_max_inv_hamming_function.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. | ||
|
||
#include "sum_max_inv_hamming_function.h" | ||
#include <vespa/eval/eval/value.h> | ||
#include <vespa/vespalib/util/binary_hamming_distance.h> | ||
|
||
namespace vespalib::eval { | ||
|
||
using namespace tensor_function; | ||
using namespace operation; | ||
|
||
namespace { | ||
|
||
// Low-level instruction: for each query vector, find the maximum
// inverted hamming distance 1/(1+d) against all document vectors,
// then sum those maxima into a single double result. Both inputs are
// flat int8 cell arrays holding consecutive vectors of vec_size cells.
void my_sum_max_inv_hamming_op(InterpretedFunction::State &state, uint64_t vec_size) {
    auto query_cells = state.peek(1).cells().unsafe_typify<int8_t>();
    auto document_cells = state.peek(0).cells().unsafe_typify<int8_t>();
    double result = 0.0;
    if ((query_cells.size() > 0) && (document_cells.size() > 0)) {
        for (size_t q = 0; q < query_cells.size(); q += vec_size) {
            float best = aggr::Max<float>::null_value();
            for (size_t d = 0; d < document_cells.size(); d += vec_size) {
                float inv = 1.0f / (1.0f + binary_hamming_distance(query_cells.data() + q,
                                                                   document_cells.data() + d,
                                                                   vec_size));
                best = aggr::Max<float>::combine(best, inv);
            }
            result += best;
        }
    }
    state.pop_pop_push(state.stash.create<DoubleValue>(result));
}
|
||
// Return expr as a Reduce node iff it reduces exactly one dimension
// with the given aggregator; nullptr otherwise.
const Reduce *check_reduce(const TensorFunction &expr, Aggr aggr) {
    const Reduce *reduce = as<Reduce>(expr);
    bool match = (reduce != nullptr) &&
                 (reduce->aggr() == aggr) &&
                 (reduce->dimensions().size() == 1);
    return match ? reduce : nullptr;
}
|
||
// Return expr as a Join node iff it joins with the given function;
// nullptr otherwise.
const Join *check_join(const TensorFunction &expr, op2_t op) {
    const Join *join = as<Join>(expr);
    if ((join == nullptr) || (join->function() != op)) {
        return nullptr;
    }
    return join;
}
|
||
bool is_one(const TensorFunction &expr) { | ||
if (expr.result_type().is_double()) { | ||
if (auto const_value = as<ConstValue>(expr)) { | ||
return (const_value->value().as_double() == 1.0); | ||
} | ||
} | ||
return false; | ||
} | ||
|
||
// 1/(1+x) -> x | ||
// 1/(x+1) -> x | ||
const TensorFunction *check_inv(const TensorFunction &expr) { | ||
if (auto div = check_join(expr, Div::f)) { | ||
if (is_one(div->lhs())) { | ||
if (auto add = check_join(div->rhs(), Add::f)) { | ||
if (is_one(add->lhs())) { | ||
return &add->rhs(); | ||
} | ||
if (is_one(add->rhs())) { | ||
return &add->lhs(); | ||
} | ||
} | ||
} | ||
} | ||
return nullptr; | ||
} | ||
|
||
bool check_params(const ValueType &res_type, const ValueType &query, const ValueType &document, | ||
const std::string &sum_dim, const std::string &max_dim, const std::string &ham_dim) | ||
{ | ||
return (res_type.is_double() && | ||
(query.dimensions().size() == 2) && (query.cell_type() == CellType::INT8) && | ||
(document.dimensions().size() == 2) && (document.cell_type() == CellType::INT8) && | ||
query.has_dimension(sum_dim) && (query.stride_of(ham_dim) == 1) && | ||
document.has_dimension(max_dim) && (document.stride_of(ham_dim) == 1)); | ||
} | ||
|
||
size_t get_dim_size(const ValueType &type, const std::string &dim) { | ||
size_t npos = ValueType::Dimension::npos; | ||
size_t idx = type.dimension_index(dim); | ||
assert(idx != npos); | ||
return type.dimensions()[idx].size; | ||
} | ||
|
||
} // namespace <unnamed> | ||
|
||
// Construct the optimized node; query and document are the int8
// tensor operands matched by optimize(), and vec_size is the size of
// their shared (innermost) hamming dimension.
SumMaxInvHammingFunction::SumMaxInvHammingFunction(const ValueType &res_type_in,
                                                   const TensorFunction &query,
                                                   const TensorFunction &document,
                                                   size_t vec_size)
  : tensor_function::Op2(res_type_in, query, document),
    _vec_size(vec_size)
{
}
|
||
// Compile into a single instruction; the hamming vector size is
// passed as the instruction's 64-bit parameter.
InterpretedFunction::Instruction
SumMaxInvHammingFunction::compile_self(const ValueBuilderFactory &, Stash &) const
{
    return InterpretedFunction::Instruction(my_sum_max_inv_hamming_op, _vec_size);
}
|
||
// Try to match the expression pattern
//   reduce(reduce(1/(1+reduce(hamming(a,b),sum,ham_dim)),max,max_dim),sum,sum_dim)
// (also accepting 1/(x+1) via check_inv) and replace it with a single
// SumMaxInvHammingFunction node. Both operand orders of hamming(a,b)
// are tried: the operand holding sum_dim acts as the query, the one
// holding max_dim as the document. Returns expr unchanged on no match.
const TensorFunction &
SumMaxInvHammingFunction::optimize(const TensorFunction &expr, Stash &stash)
{
    if (auto sum_reduce = check_reduce(expr, Aggr::SUM)) {
        if (auto max_reduce = check_reduce(sum_reduce->child(), Aggr::MAX)) {
            if (auto inverted = check_inv(max_reduce->child())) {
                if (auto ham_reduce = check_reduce(*inverted, Aggr::SUM)) {
                    if (auto ham = check_join(ham_reduce->child(), Hamming::f)) {
                        const auto &sum_dim = sum_reduce->dimensions()[0];
                        const auto &max_dim = max_reduce->dimensions()[0];
                        const auto &ham_dim = ham_reduce->dimensions()[0];
                        // first try lhs as query, rhs as document ...
                        if (check_params(expr.result_type(), ham->lhs().result_type(), ham->rhs().result_type(),
                                         sum_dim, max_dim, ham_dim))
                        {
                            size_t vec_size = get_dim_size(ham->lhs().result_type(), ham_dim);
                            return stash.create<SumMaxInvHammingFunction>(expr.result_type(), ham->lhs(), ham->rhs(), vec_size);
                        }
                        // ... then the swapped operand order
                        if (check_params(expr.result_type(), ham->rhs().result_type(), ham->lhs().result_type(),
                                         sum_dim, max_dim, ham_dim))
                        {
                            size_t vec_size = get_dim_size(ham->rhs().result_type(), ham_dim);
                            return stash.create<SumMaxInvHammingFunction>(expr.result_type(), ham->rhs(), ham->lhs(), vec_size);
                        }
                    }
                }
            }
        }
    }
    return expr;
}
|
||
} // namespace |
47 changes: 47 additions & 0 deletions
47
eval/src/vespa/eval/instruction/sum_max_inv_hamming_function.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. | ||
|
||
#pragma once | ||
|
||
#include <vespa/eval/eval/tensor_function.h> | ||
|
||
namespace vespalib::eval { | ||
|
||
/** | ||
* Tensor function combining multiple inverted hamming distances with | ||
* multiple layers of aggregation, resulting in a single scalar | ||
* result. | ||
* | ||
* inputs: | ||
* query: tensor<int8>(qt{},x[32]) | ||
* document: tensor<int8>(dt{},x[32]) | ||
* | ||
* expression: | ||
* reduce( | ||
* reduce( | ||
* 1/(1+reduce(hamming(query, document), sum, x)), | ||
* max, dt | ||
* ), | ||
* sum, qt | ||
* ) | ||
* | ||
* Both query and document contains a collection of binary int8 | ||
* vectors. For each query vector, take the inverted hamming distance | ||
* against all document vectors and select the maximum result. Sum | ||
* these partial results into the final result value. | ||
**/ | ||
class SumMaxInvHammingFunction : public tensor_function::Op2
{
private:
    size_t _vec_size; // cells per vector in the (innermost) hamming dimension
public:
    SumMaxInvHammingFunction(const ValueType &res_type_in,
                             const TensorFunction &query,
                             const TensorFunction &document,
                             size_t vec_size);
    InterpretedFunction::Instruction compile_self(const ValueBuilderFactory &factory, Stash &stash) const override;
    // number of int8 cells compared per hamming distance
    size_t vec_size() const { return _vec_size; }
    // the scalar result is created fresh in the stash
    bool result_is_mutable() const override { return true; }
    // returns an optimized replacement for expr, or expr itself when
    // the required expression pattern does not match
    static const TensorFunction &optimize(const TensorFunction &expr, Stash &stash);
};
|
||
} // namespace |