Skip to content

Commit

Permalink
refactor memory management
Browse files Browse the repository at this point in the history
Signed-off-by: Alexandr Guzhva <[email protected]>
  • Loading branch information
alexanderguzhva committed May 15, 2024
1 parent e0c9c41 commit 613f86e
Show file tree
Hide file tree
Showing 10 changed files with 183 additions and 100 deletions.
102 changes: 102 additions & 0 deletions include/knowhere/dataset.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <variant>

#include "comp/index_param.h"
#include "knowhere/range_util.h"
#include "knowhere/sparse_utils.h"

namespace knowhere {
Expand Down Expand Up @@ -71,18 +72,46 @@ class DataSet : public std::enable_shared_from_this<const DataSet> {
this->data_[meta::DISTANCE] = Var(std::in_place_index<0>, dis);
}

/**
 * Stores a heap-allocated distance array under meta::DISTANCE.
 *
 * The raw pointer is released from `dis` and kept in the dataset as variant
 * alternative 0. mutex_ is held for the duration of the update.
 *
 * The map slot is resolved BEFORE the pointer is released: the right-hand side
 * of an assignment is evaluated first, so releasing inline would leak the
 * buffer if the slot lookup/insertion threw. With this order the caller's
 * unique_ptr still owns the buffer on failure.
 *
 * NOTE(review): whether the dataset later frees this buffer presumably depends
 * on the owner flag (see SetIsOwner) - confirm against the destructor.
 */
void
SetDistance(std::unique_ptr<float[]>&& dis) {
    std::unique_lock lock(mutex_);
    auto& slot = this->data_[meta::DISTANCE];
    slot = Var(std::in_place_index<0>, dis.release());
}

void
SetLims(const size_t* lims) {
std::unique_lock lock(mutex_);
this->data_[meta::LIMS] = Var(std::in_place_index<1>, lims);
}

/**
 * Stores a heap-allocated lims array under meta::LIMS.
 *
 * The raw pointer is released from `lims` and kept in the dataset as variant
 * alternative 1. mutex_ is held for the duration of the update.
 *
 * The map slot is resolved BEFORE the pointer is released, so that a throwing
 * lookup/insertion leaves the buffer owned by the caller's unique_ptr instead
 * of leaking it.
 *
 * NOTE(review): whether the dataset later frees this buffer presumably depends
 * on the owner flag (see SetIsOwner) - confirm against the destructor.
 */
void
SetLims(std::unique_ptr<size_t[]>&& lims) {
    std::unique_lock lock(mutex_);
    auto& slot = this->data_[meta::LIMS];
    slot = Var(std::in_place_index<1>, lims.release());
}

void
SetIds(const int64_t* ids) {
std::unique_lock lock(mutex_);
this->data_[meta::IDS] = Var(std::in_place_index<2>, ids);
}

/**
 * Stores a heap-allocated id array (element type `long int`) under meta::IDS.
 *
 * The raw pointer is released from `ids`, reinterpreted as int64_t*, and kept
 * in the dataset as variant alternative 2. The static_assert guards the size
 * assumption behind that reinterpretation. mutex_ is held for the update.
 *
 * The map slot is resolved BEFORE the pointer is released, so that a throwing
 * lookup/insertion leaves the buffer owned by the caller's unique_ptr instead
 * of leaking it.
 *
 * NOTE(review): whether the dataset later frees this buffer presumably depends
 * on the owner flag (see SetIsOwner) - confirm against the destructor.
 */
void
SetIds(std::unique_ptr<long int[]>&& ids) {
    static_assert(sizeof(long int) == sizeof(int64_t));

    std::unique_lock lock(mutex_);
    auto& slot = this->data_[meta::IDS];
    slot = Var(std::in_place_index<2>, reinterpret_cast<int64_t*>(ids.release()));
}

/**
 * Stores a heap-allocated id array (element type `long long int`) under meta::IDS.
 *
 * The raw pointer is released from `ids`, reinterpreted as int64_t*, and kept
 * in the dataset as variant alternative 2. The static_assert guards the size
 * assumption behind that reinterpretation. mutex_ is held for the update.
 *
 * The map slot is resolved BEFORE the pointer is released, so that a throwing
 * lookup/insertion leaves the buffer owned by the caller's unique_ptr instead
 * of leaking it.
 *
 * NOTE(review): whether the dataset later frees this buffer presumably depends
 * on the owner flag (see SetIsOwner) - confirm against the destructor.
 */
void
SetIds(std::unique_ptr<long long int[]>&& ids) {
    static_assert(sizeof(long long int) == sizeof(int64_t));

    std::unique_lock lock(mutex_);
    auto& slot = this->data_[meta::IDS];
    slot = Var(std::in_place_index<2>, reinterpret_cast<int64_t*>(ids.release()));
}

/**
* For dense float vector, tensor is a rows * dim float array
* For sparse float vector, tensor is pointer to sparse::Sparse<float>*
Expand All @@ -94,6 +123,18 @@ class DataSet : public std::enable_shared_from_this<const DataSet> {
this->data_[meta::TENSOR] = Var(std::in_place_index<3>, tensor);
}

/**
 * Stores a heap-allocated byte tensor under meta::TENSOR.
 *
 * The raw pointer is released from `tensor` and kept in the dataset as variant
 * alternative 3. mutex_ is held for the duration of the update.
 *
 * The map slot is resolved BEFORE the pointer is released, so that a throwing
 * lookup/insertion leaves the buffer owned by the caller's unique_ptr instead
 * of leaking it.
 *
 * NOTE(review): whether the dataset later frees this buffer presumably depends
 * on the owner flag (see SetIsOwner) - confirm against the destructor.
 */
void
SetTensor(std::unique_ptr<uint8_t[]>&& tensor) {
    std::unique_lock lock(mutex_);
    auto& slot = this->data_[meta::TENSOR];
    slot = Var(std::in_place_index<3>, tensor.release());
}

/**
 * Stores a heap-allocated float tensor under meta::TENSOR.
 *
 * The raw pointer is released from `tensor` and kept in the dataset as variant
 * alternative 3. mutex_ is held for the duration of the update.
 *
 * The map slot is resolved BEFORE the pointer is released, so that a throwing
 * lookup/insertion leaves the buffer owned by the caller's unique_ptr instead
 * of leaking it.
 *
 * NOTE(review): whether the dataset later frees this buffer presumably depends
 * on the owner flag (see SetIsOwner) - confirm against the destructor.
 */
void
SetTensor(std::unique_ptr<float[]>&& tensor) {
    std::unique_lock lock(mutex_);
    auto& slot = this->data_[meta::TENSOR];
    slot = Var(std::in_place_index<3>, tensor.release());
}

void
SetRows(const int64_t rows) {
std::unique_lock lock(mutex_);
Expand Down Expand Up @@ -284,12 +325,34 @@ GenResultDataSet(const int64_t rows, const int64_t dim, const void* tensor) {
return ret_ds;
}

/**
 * Builds a result DataSet from a byte tensor handed over via unique_ptr.
 *
 * Sets rows and dim, passes the tensor to DataSet::SetTensor, and finally
 * flags the dataset as owner of its data.
 */
inline DataSetPtr
GenResultDataSet(const int64_t rows, const int64_t dim, std::unique_ptr<uint8_t[]>&& tensor) {
    auto result = std::make_shared<DataSet>();
    DataSet& ds = *result;
    ds.SetRows(rows);
    ds.SetDim(dim);
    ds.SetTensor(std::move(tensor));
    ds.SetIsOwner(true);
    return result;
}

/**
 * Builds a result DataSet from a float tensor handed over via unique_ptr.
 *
 * Sets rows and dim, passes the tensor to DataSet::SetTensor, and finally
 * flags the dataset as owner of its data.
 */
inline DataSetPtr
GenResultDataSet(const int64_t rows, const int64_t dim, std::unique_ptr<float[]>&& tensor) {
    auto result = std::make_shared<DataSet>();
    DataSet& ds = *result;
    ds.SetRows(rows);
    ds.SetDim(dim);
    ds.SetTensor(std::move(tensor));
    ds.SetIsOwner(true);
    return result;
}

inline DataSetPtr
#ifdef NOT_COMPILE_FOR_SWIG
GenResultDataSet(const int64_t nq, const int64_t topk, const int64_t* ids, const float* distance) {
#else
GenResultDataSet(const int64_t nq, const int64_t topk, const void* ids, const float* distance) {
#endif
static_assert(sizeof(int64_t) == sizeof(long long int));

auto ret_ds = std::make_shared<DataSet>();
ret_ds->SetRows(nq);
ret_ds->SetDim(topk);
Expand All @@ -299,6 +362,34 @@ GenResultDataSet(const int64_t nq, const int64_t topk, const void* ids, const fl
return ret_ds;
}

/**
 * Builds a top-k search result DataSet from id (`long int`) and distance
 * arrays handed over via unique_ptr.
 *
 * Rows is set to nq and dim to topk; ids and distances are passed to the
 * corresponding DataSet setters, then the dataset is flagged as owner.
 */
inline DataSetPtr
GenResultDataSet(const int64_t nq, const int64_t topk, std::unique_ptr<long int[]>&& ids,
                 std::unique_ptr<float[]>&& distance) {
    static_assert(sizeof(int64_t) == sizeof(long int));

    auto result = std::make_shared<DataSet>();
    DataSet& ds = *result;
    ds.SetRows(nq);
    ds.SetDim(topk);
    ds.SetIds(std::move(ids));
    ds.SetDistance(std::move(distance));
    ds.SetIsOwner(true);
    return result;
}

/**
 * Builds a top-k search result DataSet from id (`long long int`) and distance
 * arrays handed over via unique_ptr.
 *
 * Rows is set to nq and dim to topk; ids and distances are passed to the
 * corresponding DataSet setters, then the dataset is flagged as owner.
 */
inline DataSetPtr
GenResultDataSet(const int64_t nq, const int64_t topk, std::unique_ptr<long long int[]>&& ids,
                 std::unique_ptr<float[]>&& distance) {
    static_assert(sizeof(int64_t) == sizeof(long long int));

    auto result = std::make_shared<DataSet>();
    DataSet& ds = *result;
    ds.SetRows(nq);
    ds.SetDim(topk);
    ds.SetIds(std::move(ids));
    ds.SetDistance(std::move(distance));
    ds.SetIsOwner(true);
    return result;
}

inline DataSetPtr
#ifdef NOT_COMPILE_FOR_SWIG
GenResultDataSet(const int64_t nq, const int64_t* ids, const float* distance, const size_t* lims) {
Expand All @@ -314,6 +405,17 @@ GenResultDataSet(const int64_t nq, const void* ids, const float* distance, const
return ret_ds;
}

/**
 * Builds a range-search result DataSet from a RangeSearchResult.
 *
 * Rows is set to nq; the labels, distances and lims arrays are moved out of
 * range_search_result into the dataset's ids, distance and lims entries, and
 * the dataset is then flagged as owner. Dim is not set here.
 */
inline DataSetPtr
GenResultDataSet(const int64_t nq, RangeSearchResult&& range_search_result) {
    auto result = std::make_shared<DataSet>();
    DataSet& ds = *result;
    ds.SetRows(nq);
    ds.SetIds(std::move(range_search_result.labels));
    ds.SetDistance(std::move(range_search_result.distances));
    ds.SetLims(std::move(range_search_result.lims));
    ds.SetIsOwner(true);
    return result;
}

inline DataSetPtr
GenResultDataSet(const std::string& json_info, const std::string& json_id_set) {
auto ret_ds = std::make_shared<DataSet>();
Expand Down
9 changes: 3 additions & 6 deletions include/knowhere/index/index_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,9 @@ class IndexNode : public Object {
}
#endif

int64_t* ids = nullptr;
float* dis = nullptr;
size_t* lims = nullptr;
GetRangeSearchResult(result_dist_array, result_id_array, similarity_metric, nq, radius, range_filter, dis, ids,
lims);
return GenResultDataSet(nq, ids, dis, lims);
auto range_search_result =
GetRangeSearchResult(result_dist_array, result_id_array, similarity_metric, nq, radius, range_filter);
return GenResultDataSet(nq, std::move(range_search_result));
}

virtual expected<DataSetPtr>
Expand Down
12 changes: 10 additions & 2 deletions include/knowhere/range_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

#pragma once

#include <cstddef>
#include <memory>
#include <vector>

#include "knowhere/bitsetview.h"
Expand All @@ -26,9 +28,15 @@ void
FilterRangeSearchResultForOneNq(std::vector<float>& distances, std::vector<int64_t>& labels, const bool is_ip,
const float radius, const float range_filter);

void
// Owning bundle of the three arrays returned by GetRangeSearchResult().
//
// As filled in by GetRangeSearchResult(): `lims` has nq + 1 entries with
// lims[0] == 0, and the results of query i occupy the half-open index range
// [lims[i], lims[i + 1]) in both `distances` and `labels`.
//
// NOTE: member order matters - the struct is built with designated
// initializers (.distances, .labels, .lims), which must follow declaration order.
struct RangeSearchResult {
// distances of all kept results, concatenated query by query
std::unique_ptr<float[]> distances;
// labels (ids) matching `distances` element by element
std::unique_ptr<int64_t[]> labels;
// per-query offsets into `distances` / `labels`
std::unique_ptr<size_t[]> lims;
};

RangeSearchResult
GetRangeSearchResult(const std::vector<std::vector<float>>& result_distances,
const std::vector<std::vector<int64_t>>& result_labels, const bool is_ip, const int64_t nq,
const float radius, const float range_filter, float*& distances, int64_t*& labels, size_t*& lims);
const float radius, const float range_filter);

} // namespace knowhere
12 changes: 5 additions & 7 deletions src/common/comp/brute_force.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
if (ret != Status::success) {
return expected<DataSetPtr>::Err(ret, "failed to brute force search");
}
auto res = GenResultDataSet(nq, cfg.k.value(), labels.release(), distances.release());
auto res = GenResultDataSet(nq, cfg.k.value(), std::move(labels), std::move(distances));

#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT)
if (cfg.trace_id.has_value()) {
Expand Down Expand Up @@ -432,11 +432,9 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
return expected<DataSetPtr>::Err(ret, "failed to brute force search");
}

int64_t* ids = nullptr;
float* distances = nullptr;
size_t* lims = nullptr;
GetRangeSearchResult(result_dist_array, result_id_array, is_ip, nq, radius, range_filter, distances, ids, lims);
auto res = GenResultDataSet(nq, ids, distances, lims);
auto range_search_result =
GetRangeSearchResult(result_dist_array, result_id_array, is_ip, nq, radius, range_filter);
auto res = GenResultDataSet(nq, std::move(range_search_result));

#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT)
if (cfg.trace_id.has_value()) {
Expand Down Expand Up @@ -550,7 +548,7 @@ BruteForce::SearchSparse(const DataSetPtr base_dataset, const DataSetPtr query_d
auto distances = std::make_unique<float[]>(nq * topk);

SearchSparseWithBuf(base_dataset, query_dataset, labels.get(), distances.get(), config, bitset);
return GenResultDataSet(nq, topk, labels.release(), distances.release());
return GenResultDataSet(nq, topk, std::move(labels), std::move(distances));
}

template <typename DataType>
Expand Down
20 changes: 13 additions & 7 deletions src/common/range_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@

#include <algorithm>
#include <cinttypes>
#include <cstddef>
#include <memory>
#include <tuple>
#include <vector>

#include "knowhere/log.h"
namespace knowhere {
Expand Down Expand Up @@ -41,16 +45,16 @@ FilterRangeSearchResultForOneNq(std::vector<float>& distances, std::vector<int64
}
}

void
RangeSearchResult
GetRangeSearchResult(const std::vector<std::vector<float>>& result_distances,
const std::vector<std::vector<int64_t>>& result_labels, const bool is_ip, const int64_t nq,
const float radius, const float range_filter, float*& distances, int64_t*& labels, size_t*& lims) {
const float radius, const float range_filter) {
KNOWHERE_THROW_IF_NOT_FMT(result_distances.size() == (size_t)nq, "result distances size %ld not equal to %" SCNd64,
result_distances.size(), nq);
KNOWHERE_THROW_IF_NOT_FMT(result_labels.size() == (size_t)nq, "result labels size %ld not equal to %" SCNd64,
result_labels.size(), nq);

lims = new size_t[nq + 1];
auto lims = std::make_unique<size_t[]>(nq + 1);
lims[0] = 0;
// all distances must be in range scope
for (int64_t i = 0; i < nq; i++) {
Expand All @@ -61,13 +65,15 @@ GetRangeSearchResult(const std::vector<std::vector<float>>& result_distances,
LOG_KNOWHERE_DEBUG_ << "Range search: is_ip " << (is_ip ? "True" : "False") << ", radius " << radius
<< ", range_filter " << range_filter << ", total result num " << total_valid;

distances = new float[total_valid];
labels = new int64_t[total_valid];
auto distances = std::make_unique<float[]>(total_valid);
auto labels = std::make_unique<int64_t[]>(total_valid);

for (auto i = 0; i < nq; i++) {
std::copy_n(result_distances[i].data(), lims[i + 1] - lims[i], distances + lims[i]);
std::copy_n(result_labels[i].data(), lims[i + 1] - lims[i], labels + lims[i]);
std::copy_n(result_distances[i].data(), lims[i + 1] - lims[i], distances.get() + lims[i]);
std::copy_n(result_labels[i].data(), lims[i + 1] - lims[i], labels.get() + lims[i]);
}

return RangeSearchResult{.distances = std::move(distances), .labels = std::move(labels), .lims = std::move(lims)};
}

} // namespace knowhere
12 changes: 4 additions & 8 deletions src/index/diskann/diskann.cc
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@ DiskANNIndexNode<DataType>::Search(const DataSet& dataset, const Config& cfg, co
return expected<DataSetPtr>::Err(Status::diskann_inner_error, "some search failed");
}

auto res = GenResultDataSet(nq, k, p_id.release(), p_dist.release());
auto res = GenResultDataSet(nq, k, std::move(p_id), std::move(p_dist));

// set visit_info json string into result dataset
if (feder_result != nullptr) {
Expand Down Expand Up @@ -600,10 +600,6 @@ DiskANNIndexNode<DataType>::RangeSearch(const DataSet& dataset, const Config& cf
auto nq = dataset.GetRows();
auto xq = static_cast<const DataType*>(dataset.GetTensor());

int64_t* p_id = nullptr;
DistType* p_dist = nullptr;
size_t* p_lims = nullptr;

std::vector<std::vector<int64_t>> result_id_array(nq);
std::vector<std::vector<DistType>> result_dist_array(nq);

Expand All @@ -628,9 +624,9 @@ DiskANNIndexNode<DataType>::RangeSearch(const DataSet& dataset, const Config& cf
return expected<DataSetPtr>::Err(Status::diskann_inner_error, "some search failed");
}

GetRangeSearchResult(result_dist_array, result_id_array, is_ip, nq, radius, search_conf.range_filter.value(),
p_dist, p_id, p_lims);
return GenResultDataSet(nq, p_id, p_dist, p_lims);
auto range_search_result =
GetRangeSearchResult(result_dist_array, result_id_array, is_ip, nq, radius, search_conf.range_filter.value());
return GenResultDataSet(nq, std::move(range_search_result));
}

/*
Expand Down
10 changes: 4 additions & 6 deletions src/index/flat/flat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,7 @@ class FlatIndexNode : public IndexNode {
float range_filter = f_cfg.range_filter.value();
bool is_ip = (index_->metric_type == faiss::METRIC_INNER_PRODUCT);

int64_t* ids = nullptr;
float* distances = nullptr;
size_t* lims = nullptr;
RangeSearchResult range_search_result;

std::vector<std::vector<int64_t>> result_id_array(nq);
std::vector<std::vector<float>> result_dist_array(nq);
Expand Down Expand Up @@ -208,14 +206,14 @@ class FlatIndexNode : public IndexNode {
}
// wait for the completion
WaitAllSuccess(futs);
GetRangeSearchResult(result_dist_array, result_id_array, is_ip, nq, radius, range_filter, distances, ids,
lims);
range_search_result =
GetRangeSearchResult(result_dist_array, result_id_array, is_ip, nq, radius, range_filter);
} catch (const std::exception& e) {
LOG_KNOWHERE_WARNING_ << "error inner faiss: " << e.what();
return expected<DataSetPtr>::Err(Status::faiss_inner_error, e.what());
}

return GenResultDataSet(nq, ids, distances, lims);
return GenResultDataSet(nq, std::move(range_search_result));
}

expected<DataSetPtr>
Expand Down
12 changes: 4 additions & 8 deletions src/index/hnsw/hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ class HnswIndexNode : public IndexNode {
}
WaitAllSuccess(futs);

auto res = GenResultDataSet(nq, k, p_id.release(), p_dist.release());
auto res = GenResultDataSet(nq, k, std::move(p_id), std::move(p_dist));

// set visit_info json string into result dataset
if (feder_result != nullptr) {
Expand Down Expand Up @@ -358,15 +358,11 @@ class HnswIndexNode : public IndexNode {
}
WaitAllSuccess(futs);

int64_t* ids = nullptr;
DistType* dis = nullptr;
size_t* lims = nullptr;

// filter range search result
GetRangeSearchResult(result_dist_array, result_id_array, is_ip, nq, radius_for_filter, range_filter, dis, ids,
lims);
auto range_search_result =
GetRangeSearchResult(result_dist_array, result_id_array, is_ip, nq, radius_for_filter, range_filter);

auto res = GenResultDataSet(nq, ids, dis, lims);
auto res = GenResultDataSet(nq, std::move(range_search_result));

// set visit_info json string into result dataset
if (feder_result != nullptr) {
Expand Down
Loading

0 comments on commit 613f86e

Please sign in to comment.