From 9e29c1ea613935a161e30c8d2dc2900a0cabf4ff Mon Sep 17 00:00:00 2001 From: Cai Yudong Date: Mon, 17 Jun 2024 18:54:16 +0800 Subject: [PATCH] Add data type conversion test Signed-off-by: Cai Yudong --- tests/ut/test_type.cc | 28 ++++++++++++++++++++++++++-- tests/ut/utils.h | 17 +++-------------- 2 files changed, 29 insertions(+), 16 deletions(-) diff --git a/tests/ut/test_type.cc b/tests/ut/test_type.cc index 8484bf616..3d68d528b 100644 --- a/tests/ut/test_type.cc +++ b/tests/ut/test_type.cc @@ -20,8 +20,8 @@ TEST_CASE("Test bf16 patch", "[bf16 patch]") { const int64_t nb = 1000, nq = 10; const int64_t dim = 128; - const auto train_ds = GenFloatDataSet(nb, dim); - const auto query_ds = GenFloatDataSet(nq, dim); + const auto train_ds = GenDataSet(nb, dim); + const auto query_ds = GenDataSet(nq, dim); auto train_tensor = reinterpret_cast(train_ds->GetTensor()); @@ -99,3 +99,27 @@ TEST_CASE("Test bf16 patch", "[bf16 patch]") { REQUIRE(ip_dist[i] == ip_dist_new[i]); } } + +template +void +check_data_type_accuracy(float accuracy) { + const int64_t nb = 100; + const int64_t dim = 16; + + auto fp32_base_ds = GenDataSet(nb, dim); + + auto type_base_ds = knowhere::data_type_conversion(*fp32_base_ds); + auto fp32_base_ds_2 = knowhere::data_type_conversion(*type_base_ds); + + auto bv1 = static_cast(fp32_base_ds->GetTensor()); + auto bv2 = static_cast(fp32_base_ds_2->GetTensor()); + + for (int64_t i = 0; i < nb * dim; i++) { + REQUIRE(std::abs(bv1[i] - bv2[i]) < accuracy); + } +} + +TEST_CASE("Test Float16", "[fp16]") { + check_data_type_accuracy(0.001); + check_data_type_accuracy(0.01); +} diff --git a/tests/ut/utils.h b/tests/ut/utils.h index 3f15fee63..ed5b630e1 100644 --- a/tests/ut/utils.h +++ b/tests/ut/utils.h @@ -39,20 +39,9 @@ struct DisPairLess { inline knowhere::DataSetPtr GenDataSet(int rows, int dim, int seed = 42) { std::mt19937 rng(seed); - std::uniform_int_distribution<> distrib(0.0, 100.0); - float* ts = new float[rows * dim]; - for (int i = 0; i < rows * dim; ++i) ts[i] = (float)distrib(rng); - auto ds = knowhere::GenDataSet(rows, dim, ts); - ds->SetIsOwner(true); - return ds; -} - -inline knowhere::DataSetPtr -GenFloatDataSet(int rows, int dim, int seed = 42) { - std::mt19937 rng(seed); - std::uniform_real_distribution<> distrib(0.0, 100.0); + std::uniform_real_distribution<> distrib(-1.0, 1.0); float* ts = new float[rows * dim]; - for (int i = 0; i < rows * dim; ++i) ts[i] = (float)distrib(rng); + for (int i = 0; i < rows * dim; ++i) ts[i] = distrib(rng); auto ds = knowhere::GenDataSet(rows, dim, ts); ds->SetIsOwner(true); return ds; @@ -149,7 +138,7 @@ GetKNNRecall(const knowhere::DataSet& ground_truth, const std::vector ids_0(gt_ids + i * gt_k, gt_ids + i * gt_k + gt_k); std::vector ids_1 = result[i];