From 5f34bad092994bdf4d81cdb49656ee6a643011c2 Mon Sep 17 00:00:00 2001 From: usamoi Date: Mon, 2 Sep 2024 11:59:22 +0800 Subject: [PATCH 1/2] refactor: merge IvfNaive and IvfResidual Signed-off-by: usamoi --- crates/ivf/src/ivf_naive.rs | 191 --------------------- crates/ivf/src/ivf_residual.rs | 196 --------------------- crates/ivf/src/lib.rs | 229 +++++++++++++++++++------ crates/quantization/src/lib.rs | 8 + crates/quantization/src/product/mod.rs | 5 + crates/quantization/src/scalar/mod.rs | 4 + crates/quantization/src/trivial/mod.rs | 5 + crates/rabitq/src/lib.rs | 22 ++- 8 files changed, 207 insertions(+), 453 deletions(-) delete mode 100644 crates/ivf/src/ivf_naive.rs delete mode 100644 crates/ivf/src/ivf_residual.rs diff --git a/crates/ivf/src/ivf_naive.rs b/crates/ivf/src/ivf_naive.rs deleted file mode 100644 index a4b2e0cba..000000000 --- a/crates/ivf/src/ivf_naive.rs +++ /dev/null @@ -1,191 +0,0 @@ -use super::OperatorIvf as Op; -use base::always_equal::AlwaysEqual; -use base::index::*; -use base::operator::*; -use base::search::*; -use base::vector::VectorBorrowed; -use common::json::Json; -use common::mmap_array::MmapArray; -use common::remap::RemappedCollection; -use common::vec2::Vec2; -use k_means::k_means; -use k_means::k_means_lookup; -use k_means::k_means_lookup_many; -use quantization::Quantization; -use rayon::iter::IntoParallelIterator; -use rayon::iter::ParallelIterator; -use std::fs::create_dir; -use std::path::Path; -use stoppable_rayon as rayon; -use storage::Storage; - -pub struct IvfNaive { - storage: O::Storage, - quantization: Quantization, - payloads: MmapArray, - offsets: Json>, - centroids: Json::Scalar>>, -} - -impl IvfNaive { - pub fn create( - path: impl AsRef, - options: IndexOptions, - source: &(impl Vectors> + Collection + Source + Sync), - ) -> Self { - let remapped = RemappedCollection::from_source(source); - from_nothing(path, options, &remapped) - } - - pub fn open(path: impl AsRef) -> Self { - open(path) - } - - pub fn dims(&self) -> u32 { - self.storage.dims() - } - - pub fn len(&self) -> u32 { - self.storage.len() - } - - pub fn vector(&self, i: u32) -> Borrowed<'_, O> { - self.storage.vector(i) - } - - pub fn payload(&self, i: u32) -> Payload { - self.payloads[i as usize] - } - - pub fn vbase<'a>( - &'a self, - vector: Borrowed<'a, O>, - opts: &'a SearchOptions, - ) -> Box + 'a> { - let lists = select( - k_means_lookup_many(O::interpret(vector), &self.centroids), - opts.ivf_nprobe as usize, - ); - let mut heap = Vec::new(); - let preprocessed = self.quantization.preprocess(vector); - for i in lists.iter().map(|(_, i)| *i) { - let start = self.offsets[i]; - let end = self.offsets[i + 1]; - self.quantization.push_batch( - &preprocessed, - start..end, - &mut heap, - opts.ivf_sq_fast_scan, - opts.ivf_pq_fast_scan, - ); - } - let mut reranker = self.quantization.flat_rerank( - heap, - move |u| (O::distance(vector, self.storage.vector(u)), ()), - opts.ivf_sq_rerank_size, - opts.ivf_pq_rerank_size, - ); - Box::new(std::iter::from_fn(move || { - reranker.pop().map(|(dis_u, u, ())| Element { - distance: dis_u, - payload: AlwaysEqual(self.payload(u)), - }) - })) - } -} - -fn from_nothing( - path: impl AsRef, - options: IndexOptions, - collection: &(impl Vectors> + Collection + Sync), -) -> IvfNaive { - create_dir(path.as_ref()).unwrap(); - let IvfIndexingOptions { - nlist, - spherical_centroids, - residual_quantization: _, - quantization: quantization_options, - } = options.indexing.clone().unwrap_ivf(); - let samples = O::sample(collection, nlist); - rayon::check(); - let centroids = k_means(nlist as usize, samples, true, spherical_centroids, false); - rayon::check(); - let ls = (0..collection.len()) - .into_par_iter() - .fold( - || vec![Vec::new(); nlist as usize], - |mut state, i| { - state[k_means_lookup(O::interpret(collection.vector(i)), ¢roids)].push(i); - state - }, - ) - .reduce( - || vec![Vec::new(); nlist as usize], - |lhs, rhs| { - std::iter::zip(lhs, rhs) - .map(|(lhs, rhs)| { - let mut x = lhs; - x.extend(rhs); - x - }) - .collect() - }, - ); - let mut offsets = vec![0u32; nlist as usize + 1]; - for i in 0..nlist { - offsets[i as usize + 1] = offsets[i as usize] + ls[i as usize].len() as u32; - } - let remap = ls - .into_iter() - .flat_map(|x| x.into_iter()) - .collect::>(); - let collection = RemappedCollection::from_collection(collection, remap); - rayon::check(); - let storage = O::Storage::create(path.as_ref().join("storage"), &collection); - let quantization = Quantization::::create( - path.as_ref().join("quantization"), - options.vector, - quantization_options, - &collection, - |vector| vector.own(), - ); - let payloads = MmapArray::create( - path.as_ref().join("payloads"), - (0..collection.len()).map(|i| collection.payload(i)), - ); - let offsets = Json::create(path.as_ref().join("offsets"), offsets); - let centroids = Json::create(path.as_ref().join("centroids"), centroids); - IvfNaive { - storage, - quantization, - payloads, - offsets, - centroids, - } -} - -fn open(path: impl AsRef) -> IvfNaive { - let storage = O::Storage::open(path.as_ref().join("storage")); - let quantization = Quantization::open(path.as_ref().join("quantization")); - let payloads = MmapArray::open(path.as_ref().join("payloads")); - let offsets = Json::open(path.as_ref().join("offsets")); - let centroids = Json::open(path.as_ref().join("centroids")); - IvfNaive { - storage, - quantization, - payloads, - offsets, - centroids, - } -} - -fn select(mut lists: Vec<(f32, usize)>, n: usize) -> Vec<(f32, usize)> { - if lists.is_empty() || n == 0 { - return Vec::new(); - } - let n = n.min(lists.len()); - lists.select_nth_unstable_by(n - 1, |(x, _), (y, _)| f32::total_cmp(x, y)); - lists.truncate(n); - lists.sort_by(|(x, _), (y, _)| f32::total_cmp(x, y)); - lists -} diff --git a/crates/ivf/src/ivf_residual.rs b/crates/ivf/src/ivf_residual.rs deleted file mode 100644 index 9a6caa0d4..000000000 --- a/crates/ivf/src/ivf_residual.rs +++ /dev/null @@ -1,196 +0,0 @@ -use super::OperatorIvf as Op; -use base::always_equal::AlwaysEqual; -use base::index::*; -use base::operator::*; -use base::search::*; -use base::vector::*; -use common::json::Json; -use common::mmap_array::MmapArray; -use common::remap::RemappedCollection; -use common::vec2::Vec2; -use k_means::k_means; -use k_means::k_means_lookup; -use k_means::k_means_lookup_many; -use quantization::Quantization; -use rayon::iter::IntoParallelIterator; -use rayon::iter::ParallelIterator; -use std::fs::create_dir; -use std::path::Path; -use stoppable_rayon as rayon; -use storage::Storage; - -pub struct IvfResidual { - storage: O::Storage, - quantization: Quantization, - payloads: MmapArray, - offsets: Json>, - centroids: Json::Scalar>>, -} - -impl IvfResidual { - pub fn create( - path: impl AsRef, - options: IndexOptions, - source: &(impl Vectors> + Collection + Source + Sync), - ) -> Self { - let remapped = RemappedCollection::from_source(source); - from_nothing(path, options, &remapped) - } - - pub fn open(path: impl AsRef) -> Self { - open(path) - } - - pub fn dims(&self) -> u32 { - self.storage.dims() - } - - pub fn len(&self) -> u32 { - self.storage.len() - } - - pub fn vector(&self, i: u32) -> Borrowed<'_, O> { - self.storage.vector(i) - } - - pub fn payload(&self, i: u32) -> Payload { - self.payloads[i as usize] - } - - pub fn vbase<'a>( - &'a self, - vector: Borrowed<'a, O>, - opts: &'a SearchOptions, - ) -> Box + 'a> { - let lists = select( - k_means_lookup_many(O::interpret(vector), &self.centroids), - opts.ivf_nprobe as usize, - ); - let mut heap = Vec::new(); - for i in lists.iter().map(|(_, i)| *i) { - let preprocessed = self - .quantization - .preprocess(O::residual(vector, &self.centroids[(i,)]).as_borrowed()); - let start = self.offsets[i]; - let end = self.offsets[i + 1]; - self.quantization.push_batch( - &preprocessed, - start..end, - &mut heap, - opts.ivf_sq_fast_scan, - opts.ivf_pq_fast_scan, - ); - } - let mut reranker = self.quantization.flat_rerank( - heap, - move |u| (O::distance(vector, self.storage.vector(u)), ()), - opts.ivf_sq_rerank_size, - opts.ivf_pq_rerank_size, - ); - Box::new(std::iter::from_fn(move || { - reranker.pop().map(|(dis_u, u, ())| Element { - distance: dis_u, - payload: AlwaysEqual(self.payload(u)), - }) - })) - } -} - -fn from_nothing( - path: impl AsRef, - options: IndexOptions, - collection: &(impl Vectors> + Collection + Sync), -) -> IvfResidual { - create_dir(path.as_ref()).unwrap(); - let IvfIndexingOptions { - nlist, - spherical_centroids, - residual_quantization: _, - quantization: quantization_options, - } = options.indexing.clone().unwrap_ivf(); - let samples = O::sample(collection, nlist); - rayon::check(); - let centroids = k_means(nlist as usize, samples, true, spherical_centroids, false); - rayon::check(); - let ls = (0..collection.len()) - .into_par_iter() - .fold( - || vec![Vec::new(); nlist as usize], - |mut state, i| { - state[k_means_lookup(O::interpret(collection.vector(i)), ¢roids)].push(i); - state - }, - ) - .reduce( - || vec![Vec::new(); nlist as usize], - |lhs, rhs| { - std::iter::zip(lhs, rhs) - .map(|(lhs, rhs)| { - let mut x = lhs; - x.extend(rhs); - x - }) - .collect() - }, - ); - let mut offsets = vec![0u32; nlist as usize + 1]; - for i in 0..nlist { - offsets[i as usize + 1] = offsets[i as usize] + ls[i as usize].len() as u32; - } - let remap = ls - .into_iter() - .flat_map(|x| x.into_iter()) - .collect::>(); - let collection = RemappedCollection::from_collection(collection, remap); - rayon::check(); - let storage = O::Storage::create(path.as_ref().join("storage"), &collection); - let quantization = Quantization::::create( - path.as_ref().join("quantization"), - options.vector, - quantization_options, - &collection, - |vector| { - let target = k_means_lookup(O::interpret(vector), ¢roids); - O::residual(vector, ¢roids[(target,)]) - }, - ); - let payloads = MmapArray::create( - path.as_ref().join("payloads"), - (0..collection.len()).map(|i| collection.payload(i)), - ); - let offsets = Json::create(path.as_ref().join("offsets"), offsets); - let centroids = Json::create(path.as_ref().join("centroids"), centroids); - IvfResidual { - storage, - payloads, - offsets, - centroids, - quantization, - } -} - -fn open(path: impl AsRef) -> IvfResidual { - let storage = O::Storage::open(path.as_ref().join("storage")); - let quantization = Quantization::open(path.as_ref().join("quantization")); - let payloads = MmapArray::open(path.as_ref().join("payloads")); - let offsets = Json::open(path.as_ref().join("offsets")); - let centroids = Json::open(path.as_ref().join("centroids")); - IvfResidual { - storage, - quantization, - payloads, - offsets, - centroids, - } -} - -fn select(mut lists: Vec<(f32, usize)>, n: usize) -> Vec<(f32, usize)> { - if lists.is_empty() || n == 0 { - return Vec::new(); - } - let n = n.min(lists.len()); - lists.select_nth_unstable_by(n - 1, |(x, _), (y, _)| f32::total_cmp(x, y)); - lists.truncate(n); - lists.sort_by(|(x, _), (y, _)| f32::total_cmp(x, y)); - lists -} diff --git a/crates/ivf/src/lib.rs b/crates/ivf/src/lib.rs index f04322373..fff65c756 100644 --- a/crates/ivf/src/lib.rs +++ b/crates/ivf/src/lib.rs @@ -1,89 +1,67 @@ #![allow(clippy::len_without_is_empty)] #![allow(clippy::needless_range_loop)] -pub mod ivf_naive; -pub mod ivf_residual; pub mod operator; -use self::ivf_naive::IvfNaive; -use crate::operator::OperatorIvf; +use base::always_equal::AlwaysEqual; use base::index::*; use base::operator::*; use base::search::*; -use common::variants::variants; -use ivf_residual::IvfResidual; +use base::vector::VectorBorrowed; +use base::vector::VectorOwned; +use common::json::Json; +use common::mmap_array::MmapArray; +use common::remap::RemappedCollection; +use common::vec2::Vec2; +use k_means::k_means; +use k_means::k_means_lookup; +use k_means::k_means_lookup_many; +use operator::OperatorIvf as Op; +use quantization::Quantization; +use rayon::iter::IntoParallelIterator; +use rayon::iter::ParallelIterator; +use std::fs::create_dir; use std::path::Path; +use stoppable_rayon as rayon; +use storage::Storage; -pub enum Ivf { - Naive(IvfNaive), - Residual(IvfResidual), +pub struct Ivf { + storage: O::Storage, + quantization: Quantization, + payloads: MmapArray, + offsets: Json>, + centroids: Json::Scalar>>, + is_residual: Json, } -impl Ivf { +impl Ivf { pub fn create( path: impl AsRef, options: IndexOptions, source: &(impl Vectors> + Collection + Source + Sync), ) -> Self { - let IvfIndexingOptions { - quantization: quantization_options, - residual_quantization, - .. - } = options.indexing.clone().unwrap_ivf(); - std::fs::create_dir(path.as_ref()).unwrap(); - let this = if !residual_quantization - || matches!(quantization_options, QuantizationOptions::Trivial(_)) - || !O::SUPPORT_RESIDUAL - { - Self::Naive(IvfNaive::create( - path.as_ref().join("ivf_naive"), - options, - source, - )) - } else { - Self::Residual(IvfResidual::create( - path.as_ref().join("ivf_residual"), - options, - source, - )) - }; - this + let remapped = RemappedCollection::from_source(source); + from_nothing(path, options, &remapped) } pub fn open(path: impl AsRef) -> Self { - match variants(path.as_ref(), ["ivf_naive", "ivf_residual"]) { - "ivf_naive" => Self::Naive(IvfNaive::open(path.as_ref().join("ivf_naive"))), - "ivf_residual" => Self::Residual(IvfResidual::open(path.as_ref().join("ivf_residual"))), - _ => unreachable!(), - } + open(path) } pub fn dims(&self) -> u32 { - match self { - Ivf::Naive(x) => x.dims(), - Ivf::Residual(x) => x.dims(), - } + self.storage.dims() } pub fn len(&self) -> u32 { - match self { - Ivf::Naive(x) => x.len(), - Ivf::Residual(x) => x.len(), - } + self.storage.len() } pub fn vector(&self, i: u32) -> Borrowed<'_, O> { - match self { - Ivf::Naive(x) => x.vector(i), - Ivf::Residual(x) => x.vector(i), - } + self.storage.vector(i) } pub fn payload(&self, i: u32) -> Payload { - match self { - Ivf::Naive(x) => x.payload(i), - Ivf::Residual(x) => x.payload(i), - } + self.payloads[i as usize] } pub fn vbase<'a>( @@ -91,9 +69,146 @@ impl Ivf { vector: Borrowed<'a, O>, opts: &'a SearchOptions, ) -> Box + 'a> { - match self { - Ivf::Naive(x) => x.vbase(vector, opts), - Ivf::Residual(x) => x.vbase(vector, opts), + let lists = select( + k_means_lookup_many(O::interpret(vector), &self.centroids), + opts.ivf_nprobe as usize, + ); + let mut heap = Vec::new(); + let mut preprocessed = self.quantization.preprocess(vector); + for i in lists.iter().map(|(_, i)| *i) { + if *self.is_residual { + let vector = O::residual(vector, &self.centroids[(i,)]); + preprocessed = self.quantization.preprocess(vector.as_borrowed()); + } + let start = self.offsets[i]; + let end = self.offsets[i + 1]; + self.quantization.push_batch( + &preprocessed, + start..end, + &mut heap, + opts.ivf_sq_fast_scan, + opts.ivf_pq_fast_scan, + ); } + let mut reranker = self.quantization.flat_rerank( + heap, + move |u| (O::distance(vector, self.storage.vector(u)), ()), + opts.ivf_sq_rerank_size, + opts.ivf_pq_rerank_size, + ); + Box::new(std::iter::from_fn(move || { + reranker.pop().map(|(dis_u, u, ())| Element { + distance: dis_u, + payload: AlwaysEqual(self.payload(u)), + }) + })) + } +} + +fn from_nothing( + path: impl AsRef, + options: IndexOptions, + collection: &(impl Vectors> + Collection + Sync), +) -> Ivf { + create_dir(path.as_ref()).unwrap(); + let IvfIndexingOptions { + nlist, + spherical_centroids, + residual_quantization, + quantization: quantization_options, + } = options.indexing.clone().unwrap_ivf(); + let samples = O::sample(collection, nlist); + rayon::check(); + let centroids = k_means(nlist as usize, samples, true, spherical_centroids, false); + rayon::check(); + let ls = (0..collection.len()) + .into_par_iter() + .fold( + || vec![Vec::new(); nlist as usize], + |mut state, i| { + state[k_means_lookup(O::interpret(collection.vector(i)), ¢roids)].push(i); + state + }, + ) + .reduce( + || vec![Vec::new(); nlist as usize], + |lhs, rhs| { + std::iter::zip(lhs, rhs) + .map(|(lhs, rhs)| { + let mut x = lhs; + x.extend(rhs); + x + }) + .collect() + }, + ); + let mut offsets = vec![0u32; nlist as usize + 1]; + for i in 0..nlist { + offsets[i as usize + 1] = offsets[i as usize] + ls[i as usize].len() as u32; + } + let remap = ls + .into_iter() + .flat_map(|x| x.into_iter()) + .collect::>(); + let collection = RemappedCollection::from_collection(collection, remap); + let is_residual = residual_quantization && O::SUPPORT_RESIDUAL; + rayon::check(); + let storage = O::Storage::create(path.as_ref().join("storage"), &collection); + let quantization = Quantization::::create( + path.as_ref().join("quantization"), + options.vector, + quantization_options, + &collection, + |vector| { + if is_residual { + let target = k_means_lookup(O::interpret(vector), ¢roids); + O::residual(vector, ¢roids[(target,)]) + } else { + vector.own() + } + }, + ); + let payloads = MmapArray::create( + path.as_ref().join("payloads"), + (0..collection.len()).map(|i| collection.payload(i)), + ); + let offsets = Json::create(path.as_ref().join("offsets"), offsets); + let centroids = Json::create(path.as_ref().join("centroids"), centroids); + let is_residual = Json::create(path.as_ref().join("is_residual"), is_residual); + Ivf { + storage, + quantization, + payloads, + offsets, + centroids, + is_residual, + } +} + +fn open(path: impl AsRef) -> Ivf { + let storage = O::Storage::open(path.as_ref().join("storage")); + let quantization = Quantization::open(path.as_ref().join("quantization")); + let payloads = MmapArray::open(path.as_ref().join("payloads")); + let offsets = Json::open(path.as_ref().join("offsets")); + let centroids = Json::open(path.as_ref().join("centroids")); + let is_residual = Json::open(path.as_ref().join("is_residual")); + Ivf { + storage, + quantization, + payloads, + offsets, + centroids, + is_residual, + } +} + +fn select(mut lists: Vec<(f32, usize)>, n: usize) -> Vec<(f32, usize)> { + if lists.is_empty() || n == 0 { + return Vec::new(); } + let n = n.min(lists.len()); + lists.select_nth_unstable_by(n - 1, |(x, _), (y, _)| f32::total_cmp(x, y)); + lists.truncate(n); + lists.sort_by(|(x, _), (y, _)| f32::total_cmp(x, y)); + lists } diff --git a/crates/quantization/src/lib.rs b/crates/quantization/src/lib.rs index b3c1bce03..480e6e184 100644 --- a/crates/quantization/src/lib.rs +++ b/crates/quantization/src/lib.rs @@ -230,6 +230,14 @@ impl Quantization { } } + pub fn project(&self, vector: Borrowed<'_, O>) -> Owned { + match &*self.train { + Quantizer::Trivial(x) => x.project(vector), + Quantizer::Scalar(x) => x.project(vector), + Quantizer::Product(x) => x.project(vector), + } + } + pub fn preprocess(&self, lhs: Borrowed<'_, O>) -> QuantizationPreprocessed { match &*self.train { Quantizer::Trivial(x) => QuantizationPreprocessed::Trivial(x.preprocess(lhs)), diff --git a/crates/quantization/src/product/mod.rs b/crates/quantization/src/product/mod.rs index 63804af51..fcb0cffb9 100644 --- a/crates/quantization/src/product/mod.rs +++ b/crates/quantization/src/product/mod.rs @@ -8,6 +8,7 @@ use base::distance::Distance; use base::index::*; use base::operator::*; use base::search::*; +use base::vector::VectorBorrowed; use base::vector::VectorOwned; use common::sample::sample; use common::vec2::Vec2; @@ -100,6 +101,10 @@ impl ProductQuantizer { codes } + pub fn project(&self, vector: Borrowed<'_, O>) -> Owned { + vector.own() + } + pub fn preprocess(&self, lhs: Borrowed<'_, O>) -> O::QuantizationPreprocessed { O::product_quantization_preprocess( self.dims, diff --git a/crates/quantization/src/scalar/mod.rs b/crates/quantization/src/scalar/mod.rs index 1dad72f0e..dcf722beb 100644 --- a/crates/quantization/src/scalar/mod.rs +++ b/crates/quantization/src/scalar/mod.rs @@ -93,6 +93,10 @@ impl ScalarQuantizer { codes } + pub fn project(&self, vector: Borrowed<'_, O>) -> Owned { + vector.own() + } + pub fn preprocess(&self, lhs: Borrowed<'_, O>) -> O::QuantizationPreprocessed { O::scalar_quantization_preprocess(self.dims, self.bits, &self.max, &self.min, lhs) } diff --git a/crates/quantization/src/trivial/mod.rs b/crates/quantization/src/trivial/mod.rs index a3ee1fb7b..b18f8ab19 100644 --- a/crates/quantization/src/trivial/mod.rs +++ b/crates/quantization/src/trivial/mod.rs @@ -8,6 +8,7 @@ use base::distance::Distance; use base::index::*; use base::operator::*; use base::search::*; +use base::vector::VectorBorrowed; use serde::Deserialize; use serde::Serialize; use std::cmp::Reverse; @@ -34,6 +35,10 @@ impl TrivialQuantizer { } } + pub fn project(&self, vector: Borrowed<'_, O>) -> Owned { + vector.own() + } + pub fn preprocess(&self, lhs: Borrowed<'_, O>) -> O::TrivialQuantizationPreprocessed { O::trivial_quantization_preprocess(lhs) } diff --git a/crates/rabitq/src/lib.rs b/crates/rabitq/src/lib.rs index ac39f44f9..67c03fb2a 100644 --- a/crates/rabitq/src/lib.rs +++ b/crates/rabitq/src/lib.rs @@ -29,7 +29,7 @@ pub struct Rabitq { quantization: Quantization, payloads: MmapArray, offsets: Json>, - centroids: Json>, + projected_centroids: Json>, projection: Json>>, } @@ -70,14 +70,15 @@ impl Rabitq { ) -> Box + 'a> { let projected_query = O::proj(&self.projection, O::cast(vector)); let lists = select( - k_means_lookup_many(&projected_query, &self.centroids), + k_means_lookup_many(&projected_query, &self.projected_centroids), opts.rabitq_nprobe as usize, ); let mut heap = Vec::new(); for &(_, i) in lists.iter() { - let preprocessed = self - .quantization - .preprocess(&O::residual(&projected_query, &self.centroids[(i,)])); + let preprocessed = self.quantization.preprocess(&O::residual( + &projected_query, + &self.projected_centroids[(i,)], + )); let start = self.offsets[i]; let end = self.offsets[i + 1]; self.quantization.push_batch( @@ -188,13 +189,16 @@ fn from_nothing( (0..collection.len()).map(|i| collection.payload(i)), ); let offsets = Json::create(path.as_ref().join("offsets"), offsets); - let centroids = Json::create(path.as_ref().join("centroids"), projected_centroids); + let projected_centroids = Json::create( + path.as_ref().join("projected_centroids"), + projected_centroids, + ); let projection = Json::create(path.as_ref().join("projection"), projection); Rabitq { storage, payloads, offsets, - centroids, + projected_centroids, quantization, projection, } @@ -205,14 +209,14 @@ fn open(path: impl AsRef) -> Rabitq { let quantization = Quantization::open(path.as_ref().join("quantization")); let payloads = MmapArray::open(path.as_ref().join("payloads")); let offsets = Json::open(path.as_ref().join("offsets")); - let centroids = Json::open(path.as_ref().join("centroids")); + let projected_centroids = Json::open(path.as_ref().join("projected_centroids")); let projection = Json::open(path.as_ref().join("projection")); Rabitq { storage, quantization, payloads, offsets, - centroids, + projected_centroids, projection, } } From 6a9fcf904d599943fd9d1dff57de3ffb465a38b8 Mon Sep 17 00:00:00 2001 From: usamoi Date: Mon, 2 Sep 2024 12:21:20 +0800 Subject: [PATCH 2/2] fix: use residual vector for encoding Signed-off-by: usamoi --- crates/ivf/src/lib.rs | 6 +++++- crates/quantization/src/lib.rs | 19 +++++++++++++------ 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/crates/ivf/src/lib.rs b/crates/ivf/src/lib.rs index fff65c756..75ad9e0e6 100644 --- a/crates/ivf/src/lib.rs +++ b/crates/ivf/src/lib.rs @@ -121,12 +121,16 @@ fn from_nothing( rayon::check(); let centroids = k_means(nlist as usize, samples, true, spherical_centroids, false); rayon::check(); + let fa = (0..collection.len()) + .into_par_iter() + .map(|i| k_means_lookup(O::interpret(collection.vector(i)), ¢roids)) + .collect::>(); let ls = (0..collection.len()) .into_par_iter() .fold( || vec![Vec::new(); nlist as usize], |mut state, i| { - state[k_means_lookup(O::interpret(collection.vector(i)), ¢roids)].push(i); + state[fa[i as usize]].push(i); state }, ) diff --git a/crates/quantization/src/lib.rs b/crates/quantization/src/lib.rs index 480e6e184..a44f3cc26 100644 --- a/crates/quantization/src/lib.rs +++ b/crates/quantization/src/lib.rs @@ -22,6 +22,7 @@ use base::distance::Distance; use base::index::*; use base::operator::*; use base::search::*; +use base::vector::VectorOwned; use common::json::Json; use common::mmap_array::MmapArray; use reranker::graph::GraphReranker; @@ -112,8 +113,8 @@ impl Quantization { Box::new(std::iter::empty()) as Box> } Quantizer::Scalar(x) => Box::new((0..vectors.len()).flat_map(|i| { - let vector = vectors.vector(i); - let codes = x.encode(vector); + let vector = transform(vectors.vector(i)); + let codes = x.encode(vector.as_borrowed()); let bytes = x.bytes(); match x.bits() { 1 => InfiniteByteChunks::new(codes.into_iter()) @@ -133,8 +134,8 @@ impl Quantization { } })), Quantizer::Product(x) => Box::new((0..vectors.len()).flat_map(|i| { - let vector = vectors.vector(i); - let codes = x.encode(vector); + let vector = transform(vectors.vector(i)); + let codes = x.encode(vector.as_borrowed()); let bytes = x.bytes(); match x.bits() { 1 => InfiniteByteChunks::new(codes.into_iter()) @@ -170,7 +171,10 @@ impl Quantization { let n = vectors.len(); let raw = std::array::from_fn::<_, { BLOCK_SIZE as _ }, _>(|i| { let id = BLOCK_SIZE * block + i as u32; - x.encode(vectors.vector(std::cmp::min(id, n - 1))) + x.encode( + transform(vectors.vector(std::cmp::min(id, n - 1))) + .as_borrowed(), + ) }); pack(width, raw) })) as Box> @@ -186,7 +190,10 @@ impl Quantization { let n = vectors.len(); let raw = std::array::from_fn::<_, { BLOCK_SIZE as _ }, _>(|i| { let id = BLOCK_SIZE * block + i as u32; - x.encode(vectors.vector(std::cmp::min(id, n - 1))) + x.encode( + transform(vectors.vector(std::cmp::min(id, n - 1))) + .as_borrowed(), + ) }); pack(width, raw) })) as Box>