Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support cluster framework #448

Merged
merged 2 commits into from
Apr 19, 2024
Merged

Conversation

chasingegg
Copy link
Collaborator

issue #444
/kind feature
/hold

namespace knowhere::kmeans {
namespace {

static constexpr int64_t MAX_TRAIN_SIZE = 10000000L * 700 * 4;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where does this magic number come from?
also, please use ULL, not L for uint64_t

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated

uint64_t npts, dim;
uint32_t npts_32, dim_32;
reader.read((char*)&npts_32, sizeof(uint32_t));
reader.read((char*)&dim_32, sizeof(uint32_t));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are the number of points and dim limited to 32 bits?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you are right, at least the number of points could be very large, change them to uint64_t


template <typename T>
inline bool
load_bin_file(const std::string& fname, std::unique_ptr<T[]>& data, uint64_t& offset) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do I get it correct that this function loads the whole file into a given preallocated buffer into a given offset? If so, then could you please add a comment on what this function does exactly, otherwise it is somewhat confusing. Thanks.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated


template <typename VecT>
void
KMeans<VecT>::elkan_L2(const VecT* x, const VecT* y, size_t d, size_t nx, size_t ny, uint32_t* ids, float* val) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

my concert with elkan that it was MUCH slower that a regular implementatoin in my past experiments, up to 10x slower if I am not mistaken. Faiss uses the following BLAS call: https://github.com/facebookresearch/faiss/blob/dafdff110489db7587b169a0afee8470f220d295/faiss/utils/distances.cpp#L263

Would you consider providing a plain BLAS-based implementation as well?

void
KMeans<VecT>::initRandom(const VecT* train_data, size_t n_train, uint32_t random_state) {
std::unordered_set<uint32_t> picked;
std::mt19937 rng(random_state);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

random state for std random number generators is uint64_t

size_t start_id = block_id * block_size;
size_t end_id = (std::min)((block_id + 1) * block_size, n_train);
for (size_t id = start_id; id < end_id; id++) {
dist[id] = faiss::fvec_L2sqr(train_data + id * dim_, train_data + init_id * dim_, dim_);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use faiss::fvec_L2sqr_ny()

size_t start_id = block_id * block_size;
size_t end_id = (std::min)((block_id + 1) * block_size, n_train);
for (size_t id = start_id; id < end_id; id++) {
dist[id] = faiss::fvec_L2sqr(train_data + id * dim_, train_data + init_id * dim_, dim_);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use faiss::fvec_L2sqr_ny()


namespace knowhere::kmeans {

template <typename VecT>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

faiss has faiss::Clustering class, just in case :)


offset = 0;
for (int i = sample_files; i < file_paths.size(); i++) {
uint64_t dumb = 0;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what a variable name :))) I'd use dummy in this case

@mergify mergify bot removed the ci-passed label Mar 11, 2024
@@ -110,7 +110,7 @@ if(__X86_64)
-Wno-unused-function
-Wno-strict-aliasing>)
target_link_libraries(
faiss PUBLIC OpenMP::OpenMP_CXX ${BLAS_LIBRARIES} ${LAPACK_LIBRARIES}
faiss PUBLIC OpenMP::OpenMP_CXX openblas ${BLAS_LIBRARIES} ${LAPACK_LIBRARIES}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure why we need openblas here, because there is already ${BLAS_LIBRARIES} O_o

void
KMeans<VecT>::exhaustive_L2sqr_blas(const VecT* x, const VecT* y, size_t d, size_t nx, size_t ny, uint32_t* ids,
float* val) {
static_assert(std::is_same_v<VecT, float>, "sgemm only support float now");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the point of not just calling faiss::exhaustive_L2sqr_blas() here?

Comment on lines 32 to 33
fit(const VecT* vecs, size_t n, size_t max_iter = 10, uint32_t random_state = 0, std::string_view init = "random",
std::string_view algorithm = "lloyd");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we turn all string_view into enum class?

Comment on lines 295 to 296
inline DataSetPtr
GenResultDataSet(const int64_t dim, const void* tensor, const int64_t rows, const void* centroid_id_mapping) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why we need to stick on returning a DataSetPtr for this Kmeans API?

for (size_t iter = 1; iter <= max_iter; ++iter) {
if (algorithm == "lloyd") {
auto loss = lloyds_iter(vecs, closest_docs, centroid_id_mapping_.get(), closest_centroid_distance.get(), n,
random_state, verbose_);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From the API the last param of this lloyds_iter is compute_residual, but why we pass verbose_?

Comment on lines 373 to 377
if (compute_residual) {
for (size_t i = 0; i < n_train; ++i) {
losses += closest_centroid_distance[i];
}
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why we do loss computation when compute_residual

template <typename VecT>
float
KMeans<VecT>::lloyds_iter(const VecT* train_data, std::vector<std::vector<uint32_t>>& closest_docs,
uint32_t* closest_centroid, float* closest_centroid_distance, size_t n_train,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why we need to pass closest_centroid_distance in instead of creating it in this function

}
old_loss = loss;
} else {
throw std::runtime_error(std::string("Algorithm: ") + std::string(algorithm) + " not supported yet.");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't throw exceptions. Use errorcode instead

@chasingegg
Copy link
Collaborator Author

Make clustering (currently only one kmeans implmentation)the same level as index, so refactor some index-related .h and .cc to the index folder, the same as clustering.

Copy link

codecov bot commented Apr 17, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 71.21%. Comparing base (3c46f4c) to head (c39d17a).
Report is 12 commits behind head on main.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff            @@
##           main     #448       +/-   ##
=========================================
+ Coverage      0   71.21%   +71.21%     
=========================================
  Files         0       67       +67     
  Lines         0     4387     +4387     
=========================================
+ Hits          0     3124     +3124     
- Misses        0     1263     +1263     

see 67 files with indirect coverage changes

@mergify mergify bot added the ci-passed label Apr 17, 2024

namespace knowhere {

class KmeansConfig : public BaseConfig {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Too many KNN related param in the BaseConfig

Comment on lines 295 to 303
inline DataSetPtr
GenResultDataSet(const int64_t rows, const void* centroid_id_mapping) {
auto ret_ds = std::make_shared<DataSet>();
ret_ds->SetRows(rows);
ret_ds->SetCentroidIdMapping(centroid_id_mapping);
ret_ds->SetIsOwner(true);
return ret_ds;
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's put this Genxxx function to the user side.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we could reuse current dataset actually

@@ -162,6 +168,17 @@ class DataSet : public std::enable_shared_from_this<const DataSet> {
return nullptr;
}

const void*
GetCentroidIdMapping() const {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And these Getter. We can have a index_dataset_util & another cluster_dataset_util to gather them.

#define CLUSTERING_H

#include "knowhere/binaryset.h"
#include "knowhere/clustering/clustering_node.h"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a little bit weird to make clustering = index. A better name is needed, some thing like cluster and cluster_operator. These just an example for reference.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lets go with cluster

Assign(const DataSet& dataset, const Config& cfg) = 0;

// return centroids, must be called after trained
virtual expected<DataSetPtr>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we use a general object like DataSet as input/output. I will suggest to add more comments to declare what we need in side this Object. Or we can directly avoid using this.


template <typename DataType>
expected<DataSetPtr>
KmeansClusteringNode<DataType>::Assign(const DataSet& dataset, const Config& cfg) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cfg is not used. Let's either mark it as /* unused */ in the signature or get rid of this param

}
auto rows = dataset.GetRows();
auto vecs = dataset.GetTensor();
knowhere::TimeRecorder build_time("Kmeans assign cost", 2);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's do a dim check here

Comment on lines 449 to 451
elkan_L2(vecs + start * dim_, centroids, dim_, end - start, num_clusters_, closest_centroid + start,
closest_centroid_distance + start);
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like this is for float. If we don't support other data types, let's simply return error.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove elkan

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cluster factory has limited data type to be float, already have a static assert

Comment on lines 454 to 456
for (auto& future : futures) {
future.wait();
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Async call need to handle the error

Comment on lines 421 to 423
for (auto& future : futures) {
future.wait();
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, let's handle the error

@mergify mergify bot added ci-passed and removed ci-passed labels Apr 17, 2024
@liliu-z liliu-z self-requested a review April 17, 2024 12:53
@mergify mergify bot removed the ci-passed label Apr 18, 2024
@mergify mergify bot added the ci-passed label Apr 18, 2024
@chasingegg chasingegg changed the title Support kmeans api Support cluster framework Apr 18, 2024
@mergify mergify bot added ci-passed and removed ci-passed labels Apr 18, 2024
@mergify mergify bot added ci-passed and removed ci-passed labels Apr 18, 2024
@mergify mergify bot added ci-passed and removed ci-passed labels Apr 19, 2024
@chasingegg
Copy link
Collaborator Author

/unhold

@@ -61,6 +65,7 @@ constexpr const char* INDEX_ENGINE_VERSION = "index_engine_version";
constexpr const char* RETRIEVE_FRIENDLY = "retrieve_friendly";
constexpr const char* DIM = "dim";
constexpr const char* TENSOR = "tensor";
constexpr const char* CENTROID_ID_MAPPING = "centroid_id_mapping";
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove redundant

@@ -82,6 +87,7 @@ constexpr const char* TRACE_FLAGS = "trace_flags";
constexpr const char* MATERIALIZED_VIEW_SEARCH_INFO = "materialized_view_search_info";
constexpr const char* MATERIALIZED_VIEW_OPT_FIELDS_PATH = "opt_fields_path";
constexpr const char* MAX_EMPTY_RESULT_BUCKETS = "max_empty_result_buckets";
constexpr const char* NUM_CLUSTERS = "num_clusters";
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only used by UT

Signed-off-by: chasingegg <[email protected]>
Signed-off-by: chasingegg <[email protected]>
@mergify mergify bot added ci-passed and removed ci-passed labels Apr 19, 2024
@liliu-z
Copy link
Collaborator

liliu-z commented Apr 19, 2024

/lgtm
Thanks for addressing all comments!

@liliu-z
Copy link
Collaborator

liliu-z commented Apr 19, 2024

/approve

@sre-ci-robot
Copy link
Collaborator

[APPROVALNOTIFIER] This PR is APPROVED

This pull-request has been approved by: chasingegg, liliu-z

The full list of commands accepted by this bot can be found here.

The pull request process is described here

Needs approval from an approver in each of these files:

Approvers can indicate their approval by writing /approve in a comment
Approvers can cancel approval by writing /approve cancel in a comment

@sre-ci-robot sre-ci-robot merged commit ba2cdc5 into zilliztech:main Apr 19, 2024
11 checks passed
@chasingegg chasingegg deleted the serverless-kmeans branch April 19, 2024 08:40
@chasingegg chasingegg mentioned this pull request May 6, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants