Skip to content

Commit

Permalink
Update.
Browse files Browse the repository at this point in the history
  • Loading branch information
REDMOND\ninchen committed Apr 22, 2024
1 parent 1459936 commit b1f4260
Show file tree
Hide file tree
Showing 8 changed files with 74 additions and 39 deletions.
6 changes: 4 additions & 2 deletions apps/build_memory_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace po = boost::program_options;

int main(int argc, char **argv)
{
std::string data_type, dist_fn, data_path, index_path_prefix, label_file, universal_label, label_type;
std::string data_type, dist_fn, data_path, index_path_prefix, label_file, codebook_path, universal_label, label_type;
uint32_t num_threads, R, L, Lf, build_PQ_bytes;
float alpha;
bool use_pq_build, use_opq;
Expand Down Expand Up @@ -59,13 +59,14 @@ int main(int argc, char **argv)
program_options_utils::GRAPH_BUILD_ALPHA);
optional_configs.add_options()("build_PQ_bytes", po::value<uint32_t>(&build_PQ_bytes)->default_value(0),
program_options_utils::BUIlD_GRAPH_PQ_BYTES);
optional_configs.add_options()("codebook_path", po::value<std::string>(&codebook_path)->default_value(""),
program_options_utils::CODEBOOK_PATH);
optional_configs.add_options()("use_opq", po::bool_switch()->default_value(false),
program_options_utils::USE_OPQ);
optional_configs.add_options()("label_file", po::value<std::string>(&label_file)->default_value(""),
program_options_utils::LABEL_FILE);
optional_configs.add_options()("universal_label", po::value<std::string>(&universal_label)->default_value(""),
program_options_utils::UNIVERSAL_LABEL);

optional_configs.add_options()("FilteredLbuild", po::value<uint32_t>(&Lf)->default_value(0),
program_options_utils::FILTERED_LBUILD);
optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
Expand Down Expand Up @@ -146,6 +147,7 @@ int main(int argc, char **argv)
.is_use_opq(use_opq)
.is_pq_dist_build(use_pq_build)
.with_num_pq_chunks(build_PQ_bytes)
.with_pq_codebook_path(codebook_path)
.build();

auto index_factory = diskann::IndexFactory(config);
Expand Down
35 changes: 26 additions & 9 deletions apps/search_memory_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ namespace po = boost::program_options;

template <typename T, typename LabelT = uint32_t>
int search_memory_index(diskann::Metric &metric, const std::string &index_path, const std::string &result_path_prefix,
const std::string &query_file, const std::string &truthset_file, const uint32_t num_threads,
const std::string &query_file, const std::string &truthset_file,
const std::string &codebook_file, const bool use_pq_build, const bool use_opq,
const uint32_t pq_num_chunks, const uint32_t num_threads,
const uint32_t recall_at, const bool print_all_recalls, const std::vector<uint32_t> &Lvec,
const bool dynamic, const bool tags, const bool show_qps_per_thread,
const std::vector<std::string> &query_filters, const float fail_if_recall_below)
Expand Down Expand Up @@ -82,10 +84,11 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path,
.is_dynamic_index(dynamic)
.is_enable_tags(tags)
.is_concurrent_consolidate(false)
.is_pq_dist_build(false)
.is_use_opq(false)
.with_num_pq_chunks(0)
.is_pq_dist_build(use_pq_build)
.is_use_opq(use_pq_build)
.with_num_pq_chunks(pq_num_chunks)
.with_num_frozen_pts(num_frozen_pts)
.with_pq_codebook_path(codebook_file)
.build();

auto index_factory = diskann::IndexFactory(config);
Expand Down Expand Up @@ -278,10 +281,10 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path,
int main(int argc, char **argv)
{
std::string data_type, dist_fn, index_path_prefix, result_path, query_file, gt_file, filter_label, label_type,
query_filters_file;
uint32_t num_threads, K;
query_filters_file, codebook_path;
uint32_t num_threads, K, build_PQ_bytes;
std::vector<uint32_t> Lvec;
bool print_all_recalls, dynamic, tags, show_qps_per_thread;
bool print_all_recalls, dynamic, tags, show_qps_per_thread, use_pq_build, use_opq;
float fail_if_recall_below = 0.0f;

po::options_description desc{
Expand Down Expand Up @@ -331,6 +334,12 @@ int main(int argc, char **argv)
optional_configs.add_options()("fail_if_recall_below",
po::value<float>(&fail_if_recall_below)->default_value(0.0f),
program_options_utils::FAIL_IF_RECALL_BELOW);
optional_configs.add_options()("build_PQ_bytes", po::value<uint32_t>(&build_PQ_bytes)->default_value(0),
program_options_utils::BUIlD_GRAPH_PQ_BYTES);
optional_configs.add_options()("codebook_path", po::value<std::string>(&codebook_path)->default_value(""),
program_options_utils::CODEBOOK_PATH);
optional_configs.add_options()("use_opq", po::bool_switch()->default_value(false),
program_options_utils::USE_OPQ);

// Output controls
po::options_description output_controls("Output controls");
Expand All @@ -352,6 +361,8 @@ int main(int argc, char **argv)
return 0;
}
po::notify(vm);
use_pq_build = (build_PQ_bytes > 0);
use_opq = vm["use_opq"].as<bool>();
}
catch (const std::exception &ex)
{
Expand Down Expand Up @@ -420,18 +431,21 @@ int main(int argc, char **argv)
if (data_type == std::string("int8"))
{
return search_memory_index<int8_t, uint16_t>(
metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls,
metric, index_path_prefix, result_path, query_file, gt_file, codebook_path, use_pq_build, use_opq,
build_PQ_bytes, num_threads, K, print_all_recalls,
Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below);
}
else if (data_type == std::string("uint8"))
{
return search_memory_index<uint8_t, uint16_t>(
metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls,
metric, index_path_prefix, result_path, query_file, gt_file, codebook_path, use_pq_build, use_opq,
build_PQ_bytes, num_threads, K, print_all_recalls,
Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below);
}
else if (data_type == std::string("float"))
{
return search_memory_index<float, uint16_t>(metric, index_path_prefix, result_path, query_file, gt_file,
codebook_path, use_pq_build, use_opq, build_PQ_bytes,
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
show_qps_per_thread, query_filters, fail_if_recall_below);
}
Expand All @@ -446,18 +460,21 @@ int main(int argc, char **argv)
if (data_type == std::string("int8"))
{
return search_memory_index<int8_t>(metric, index_path_prefix, result_path, query_file, gt_file,
codebook_path, use_pq_build, use_opq, build_PQ_bytes,
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
show_qps_per_thread, query_filters, fail_if_recall_below);
}
else if (data_type == std::string("uint8"))
{
return search_memory_index<uint8_t>(metric, index_path_prefix, result_path, query_file, gt_file,
codebook_path, use_pq_build, use_opq, build_PQ_bytes,
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
show_qps_per_thread, query_filters, fail_if_recall_below);
}
else if (data_type == std::string("float"))
{
return search_memory_index<float>(metric, index_path_prefix, result_path, query_file, gt_file,
codebook_path, use_pq_build, use_opq, build_PQ_bytes,
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
show_qps_per_thread, query_filters, fail_if_recall_below);
}
Expand Down
1 change: 1 addition & 0 deletions include/pq_scratch.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ template <typename T> class PQScratch
uint8_t *aligned_pq_coord_scratch = nullptr; // AT LEAST [N_CHUNKS * MAX_DEGREE]
float *rotated_query = nullptr;
float *aligned_query_float = nullptr;
bool preprocessed = false;

PQScratch(size_t graph_degree, size_t aligned_dim);
void initialize(size_t dim, const T *query, const float norm = 1.0f);
Expand Down
1 change: 1 addition & 0 deletions include/program_options_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ const char *BUIlD_GRAPH_PQ_BYTES = "Number of PQ bytes to build the index; 0 for
const char *USE_OPQ = "Use Optimized Product Quantization (OPQ).";
const char *LABEL_FILE = "Input label file in txt format for Filtered Index build. The file should contain comma "
"separated filters for each node with each line corresponding to a graph node";
const char *CODEBOOK_PATH = "Path for Codebook/piviot file to use when building PQ";
const char *UNIVERSAL_LABEL =
"Universal label, Use only in conjunction with label file for filtered index build. If a "
"graph node has all the labels against it, we can assign a special universal filter to the "
Expand Down
1 change: 1 addition & 0 deletions include/quantized_distance.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ template <typename data_t> class QuantizedDistance
// Has to be < ndim
virtual uint32_t get_num_chunks() const = 0;

// Return the pq_table used for quantized distance calculation.
virtual const FixedChunkPQTable &get_pq_table() const = 0;

// Preprocess the query by computing chunk distances from the query vector to
Expand Down
45 changes: 29 additions & 16 deletions src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -522,9 +522,6 @@ size_t Index<T, TagT, LabelT>::load_data(std::string filename)
copy_aligned_data_from_file<T>(reader, _data, file_num_points, file_dim, _data_store->get_aligned_dim());
#else
_data_store->load(filename); // offset == 0.
// ninchen: _pq_data_store->load or generate pq vector after loading.

/* if(_pq_dist) {convert data_store vector to pq and store in _pq_data_store} */
#endif
return file_num_points;
}
Expand Down Expand Up @@ -848,9 +845,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
assert(id_scratch.size() == 0);

T *aligned_query = scratch->aligned_query();

float *pq_dists = nullptr;

_pq_data_store->preprocess_query(aligned_query, scratch);

if (expanded_nodes.size() > 0 || id_scratch.size() > 0)
Expand Down Expand Up @@ -1031,8 +1026,7 @@ void Index<T, TagT, LabelT>::search_for_point_and_prune(int location, uint32_t L

if (!use_filter)
{
// ninchen: Use pq vector for searching.
_pq_data_store->get_vector(location, scratch->aligned_query());
_data_store->get_vector(location, scratch->aligned_query());
iterate_to_fixed_point(scratch, Lindex, init_ids, false, unused_filter_label, false);
}
else
Expand All @@ -1047,7 +1041,7 @@ void Index<T, TagT, LabelT>::search_for_point_and_prune(int location, uint32_t L
if (_dynamic_index)
tl.unlock();

_pq_data_store->get_vector(location, scratch->aligned_query());
_data_store->get_vector(location, scratch->aligned_query());
iterate_to_fixed_point(scratch, filteredLindex, filter_specific_start_nodes, true,
_location_to_labels[location], false);

Expand All @@ -1061,7 +1055,7 @@ void Index<T, TagT, LabelT>::search_for_point_and_prune(int location, uint32_t L
// clear scratch for finding unfiltered candidates
scratch->clear();

_pq_data_store->get_vector(location, scratch->aligned_query());
_data_store->get_vector(location, scratch->aligned_query());
iterate_to_fixed_point(scratch, Lindex, init_ids, false, unused_filter_label, false);

for (auto unfiltered_neighbour : scratch->pool())
Expand Down Expand Up @@ -2210,10 +2204,8 @@ size_t Index<T, TagT, LabelT>::search_with_tags(const T *query, const uint64_t K
std::shared_lock<std::shared_timed_mutex> ul(_update_lock);

const std::vector<uint32_t> init_ids = get_init_ids();
_pq_data_store->preprocess_query(query, scratch);

//_distance->preprocess_query(query, _data_store->get_dims(),
// scratch->aligned_query());
_data_store->preprocess_query(query, scratch);
if (!use_filters)
{
const std::vector<LabelT> unused_filter_label;
Expand Down Expand Up @@ -2244,7 +2236,14 @@ size_t Index<T, TagT, LabelT>::search_with_tags(const T *query, const uint64_t K

if (res_vectors.size() > 0)
{
_data_store->get_vector(node.id, res_vectors[pos]);
if (_pq_dist)
{
_pq_data_store->get_vector(node.id, res_vectors[pos]);
}
else
{
_data_store->get_vector(node.id, res_vectors[pos]);
}
}

if (distances != nullptr)
Expand Down Expand Up @@ -2838,7 +2837,12 @@ template <typename T, typename TagT, typename LabelT> void Index<T, TagT, LabelT
assert(_empty_slots.size() == 0); // should not resize if there are empty slots.

_data_store->resize((location_t)new_internal_points);
_pq_data_store->resize((location_t)new_internal_points);

if (_pq_dist)
{
_pq_data_store->resize((location_t)new_internal_points);
}

_graph_store->resize_graph(new_internal_points);
_locks = std::vector<non_recursive_mutex>(new_internal_points);

Expand Down Expand Up @@ -2949,7 +2953,12 @@ int Index<T, TagT, LabelT>::insert_point(const T *point, const TagT tag, const s
_label_to_start_id[label] = (uint32_t)fz_location;
_location_to_labels[fz_location] = {label};
_data_store->set_vector((location_t)fz_location, point);
_pq_data_store->set_vector((location_t)fz_location, point);

if (_pq_dist)
{
_pq_data_store->set_vector((location_t)fz_location, point);
}

_frozen_pts_used++;
}
}
Expand Down Expand Up @@ -3011,7 +3020,11 @@ int Index<T, TagT, LabelT>::insert_point(const T *point, const TagT tag, const s
tl.unlock();

_data_store->set_vector(location, point); // update datastore
_pq_data_store->set_vector(location, point); // Update PQDataStore

if (_pq_dist)
{
_pq_data_store->set_vector(location, point); // Update PQDataStore
}

// Find and add appropriate graph edges
ScratchStoreManager<InMemQueryScratch<T>> manager(_query_scratch);
Expand Down
17 changes: 5 additions & 12 deletions src/pq_data_store.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,6 @@ template <typename data_t> void PQDataStore<data_t>::populate_data(const data_t
{
std::memmove(_quantized_data + i * _aligned_dim, vectors + i * this->_dim, this->_dim * sizeof(data_t));
}

if (_distance_fn->preprocessing_required())
{
_distance_fn->preprocess_base_points(reinterpret_cast<data_t *>(_quantized_data), this->_aligned_dim, num_pts);
}
}

template <typename data_t> void PQDataStore<data_t>::populate_data(const std::string &filename, const size_t offset)
Expand Down Expand Up @@ -164,7 +159,7 @@ template <typename data_t> void PQDataStore<data_t>::set_vector(const location_t
std::vector<float> vector_float(full_dimension);

diskann::convert_types<data_t, float>(vector, vector_float.data(), 1, full_dimension);
std::vector<uint8_t> compressed_vector(num_chunks * sizeof(uint32_t));
std::vector<uint8_t> compressed_vector(num_chunks * sizeof(data_t));
std::vector<data_t> compressed_vector_T(num_chunks);

generate_pq_data_from_pivots_simplified(vector_float.data(), 1, pq_table.tables, 256 * full_dimension,
Expand All @@ -176,11 +171,6 @@ template <typename data_t> void PQDataStore<data_t>::set_vector(const location_t
size_t offset_in_data = loc * _aligned_dim;
memset(_quantized_data + offset_in_data, 0, _aligned_dim * sizeof(data_t));
memcpy(_quantized_data + offset_in_data, compressed_vector_T.data(), this->_dim * sizeof(data_t));

if (_distance_fn->preprocessing_required())
{
_distance_fn->preprocess_base_points(reinterpret_cast<data_t *>(_quantized_data) + offset_in_data, _aligned_dim, 1);
}
}

template <typename data_t> void PQDataStore<data_t>::prefetch_vector(const location_t loc)
Expand Down Expand Up @@ -254,7 +244,10 @@ void PQDataStore<data_t>::preprocess_query(const data_t *aligned_query, Abstract
throw diskann::ANNException("PQScratch space has not been set in the scratch object.", -1);
}

_pq_distance_fn->preprocess_query(aligned_query, (location_t)this->get_dims(), *pq_scratch);
if (!pq_scratch->preprocessed)
{
_pq_distance_fn->preprocess_query(aligned_query, (location_t)this->get_dims(), *pq_scratch);
}
}

template <typename data_t> float PQDataStore<data_t>::get_distance(const data_t *query, const location_t loc) const
Expand Down
7 changes: 7 additions & 0 deletions src/pq_l2_distance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ const FixedChunkPQTable & PQL2Distance<data_t>::get_pq_table() const
template <typename data_t>
void PQL2Distance<data_t>::preprocess_query(const data_t *aligned_query, uint32_t dim, PQScratch<data_t> &scratch)
{
if (scratch.preprocessed)
{
// This query has already been processed.
return;
}

// Copy query vector to float and then to "rotated" query
for (size_t d = 0; d < dim; d++)
{
Expand All @@ -78,6 +84,7 @@ void PQL2Distance<data_t>::preprocess_query(const data_t *aligned_query, uint32_
std::memcpy(scratch.rotated_query, tmp.data(), _pq_table.ndims * sizeof(float));
}
this->prepopulate_chunkwise_distances(scratch.rotated_query, scratch.aligned_pqtable_dist_scratch);
scratch.preprocessed = true;
}

template <typename data_t>
Expand Down

0 comments on commit b1f4260

Please sign in to comment.