-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #34 from Jaybro/kd_forest
kd_forest rework
- Loading branch information
Showing
30 changed files
with
731 additions
and
414 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
add_executable(kd_forest kd_forest.cpp) | ||
set_default_target_properties(kd_forest) | ||
target_link_libraries(kd_forest PUBLIC pico_toolshed pico_understory) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
#include <cstddef>
#include <exception>
#include <filesystem>
#include <functional>
#include <iostream>
#include <string>
#include <vector>

#include <pico_toolshed/format/format_bin.hpp>
#include <pico_toolshed/scoped_timer.hpp>
#include <pico_tree/array_traits.hpp>
#include <pico_tree/kd_tree.hpp>
#include <pico_tree/vector_traits.hpp>
#include <pico_understory/kd_forest.hpp>

#include "mnist.hpp"
#include "sift.hpp"
|
||
// A KdForest takes roughly forest_size times longer to build compared to | ||
// building a KdTree. However, the KdForest is usually a lot faster with queries | ||
// in high dimensions with the added trade-off that the exact nearest neighbor | ||
// may not be found. | ||
template <typename Dataset> | ||
void RunDataset( | ||
std::size_t tree_max_leaf_size, | ||
std::size_t forest_size, | ||
std::size_t forest_max_leaf_size, | ||
std::size_t forest_max_leaves_visited) { | ||
using Point = typename Dataset::PointType; | ||
using Space = std::reference_wrapper<std::vector<Point>>; | ||
using Scalar = typename Point::value_type; | ||
|
||
auto train = Dataset::ReadTrain(); | ||
auto test = Dataset::ReadTest(); | ||
std::size_t count = test.size(); | ||
std::vector<pico_tree::Neighbor<int, Scalar>> nns(count); | ||
std::string fn_nns_gt = Dataset::kDatasetName + "_nns_gt.bin"; | ||
|
||
if (!std::filesystem::exists(fn_nns_gt)) { | ||
std::cout << "Creating " << fn_nns_gt | ||
<< " using the KdTree. Be *very* patient." << std::endl; | ||
|
||
auto kd_tree = [&train, &tree_max_leaf_size]() { | ||
ScopedTimer t0("kd_tree build"); | ||
return pico_tree::KdTree<Space>(train, tree_max_leaf_size); | ||
}(); | ||
|
||
{ | ||
ScopedTimer t1("kd_tree query"); | ||
for (std::size_t i = 0; i < nns.size(); ++i) { | ||
kd_tree.SearchNn(test[i], nns[i]); | ||
} | ||
} | ||
|
||
pico_tree::WriteBin(fn_nns_gt, nns); | ||
} else { | ||
pico_tree::ReadBin(fn_nns_gt, nns); | ||
std::cout << "KdTree not created. Read " << fn_nns_gt << " instead." | ||
<< std::endl; | ||
} | ||
|
||
std::size_t equal = 0; | ||
{ | ||
auto rkd_tree = [&train, &forest_max_leaf_size, &forest_size]() { | ||
ScopedTimer t0("kd_forest build"); | ||
return pico_tree::KdForest<Space>( | ||
train, forest_max_leaf_size, forest_size); | ||
}(); | ||
|
||
ScopedTimer t1("kd_forest query"); | ||
pico_tree::Neighbor<int, Scalar> nn; | ||
for (std::size_t i = 0; i < nns.size(); ++i) { | ||
rkd_tree.SearchNn(test[i], forest_max_leaves_visited, nn); | ||
|
||
if (nns[i].index == nn.index) { | ||
++equal; | ||
} | ||
} | ||
} | ||
|
||
std::cout << "Precision: " | ||
<< (static_cast<float>(equal) / static_cast<float>(count)) | ||
<< std::endl; | ||
} | ||
|
||
// Runs the KdForest precision comparison on the MNIST and the SIFT datasets.
//
// Returns 0 when both runs complete and 1 when one of them throws, e.g. the
// std::runtime_error raised by the dataset readers when a file doesn't exist.
int main() {
  try {
    // forest_max_leaf_size = 16
    // forest_max_leaves_visited = 16
    // forest_size 8: a precision of around 0.915.
    // forest_size 16: a precision of around 0.976.
    RunDataset<Mnist>(16, 8, 16, 16);
    // forest_max_leaf_size = 32
    // forest_max_leaves_visited = 64
    // forest_size 8: a precision of around 0.884.
    // forest_size 16: a precision of around 0.940.
    // forest_size 128: out of memory :'(
    RunDataset<Sift>(16, 8, 32, 64);
  } catch (std::exception const& e) {
    // Report the reason instead of terminating on an uncaught exception.
    std::cerr << "Error: " << e.what() << std::endl;
    return 1;
  }
  return 0;
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
#pragma once | ||
|
||
#include <algorithm> | ||
#include <filesystem> | ||
#include <pico_toolshed/format/format_mnist.hpp> | ||
|
||
// Converts every element of a fixed-size array from T to U via static_cast.
//
// i: the array to convert.
// Returns a new array of the same length N holding the converted values.
template <typename U, typename T, std::size_t N>
std::array<U, N> Cast(std::array<T, N> const& i) {
  std::array<U, N> c;
  std::transform(i.begin(), i.end(), c.begin(), [](T a) -> U {
    return static_cast<U>(a);
  });
  return c;
}

// Converts a vector of fixed-size arrays element-wise from T to U.
//
// i: the arrays to convert.
// Returns a vector with one converted array per input array, in input order.
template <typename U, typename T, std::size_t N>
std::vector<std::array<U, N>> Cast(std::vector<std::array<T, N>> const& i) {
  std::vector<std::array<U, N>> c;
  // The output size is known up front: allocate once instead of growing.
  c.reserve(i.size());
  for (std::array<T, N> const& a : i) {
    c.push_back(Cast<U>(a));
  }
  return c;
}
|
||
class Mnist { | ||
private: | ||
using Scalar = float; | ||
using ImageByte = std::array<std::byte, 28 * 28>; | ||
using ImageFloat = std::array<Scalar, 28 * 28>; | ||
|
||
static std::vector<ImageFloat> ReadImages(std::string const& filename) { | ||
if (!std::filesystem::exists(filename)) { | ||
throw std::runtime_error(filename + " doesn't exist."); | ||
} | ||
|
||
std::vector<ImageByte> images_u8; | ||
pico_tree::ReadMnistImages(filename, images_u8); | ||
return Cast<Scalar>(images_u8); | ||
} | ||
|
||
public: | ||
using PointType = ImageFloat; | ||
|
||
static std::string const kDatasetName; | ||
|
||
static std::vector<PointType> ReadTrain() { | ||
std::string fn_images_train = "train-images.idx3-ubyte"; | ||
return ReadImages(fn_images_train); | ||
} | ||
|
||
static std::vector<PointType> ReadTest() { | ||
std::string fn_images_test = "t10k-images.idx3-ubyte"; | ||
return ReadImages(fn_images_test); | ||
} | ||
}; | ||
|
||
std::string const Mnist::kDatasetName = "mnist"; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
#pragma once | ||
|
||
#include <filesystem> | ||
#include <pico_toolshed/format/format_xvecs.hpp> | ||
|
||
class Sift { | ||
private: | ||
using VectorFloat = std::array<float, 128>; | ||
|
||
static std::vector<VectorFloat> ReadVectors(std::string const& filename) { | ||
if (!std::filesystem::exists(filename)) { | ||
throw std::runtime_error(filename + " doesn't exist."); | ||
} | ||
|
||
std::vector<VectorFloat> vectors; | ||
pico_tree::ReadXvecs(filename, vectors); | ||
return vectors; | ||
} | ||
|
||
public: | ||
using PointType = VectorFloat; | ||
|
||
static std::string const kDatasetName; | ||
|
||
static std::vector<PointType> ReadTrain() { | ||
std::string fn_images_train = "sift_base.fvecs"; | ||
return ReadVectors(fn_images_train); | ||
} | ||
|
||
static std::vector<PointType> ReadTest() { | ||
std::string fn_images_test = "sift_query.fvecs"; | ||
return ReadVectors(fn_images_test); | ||
} | ||
}; | ||
|
||
std::string const Sift::kDatasetName = "sift"; |
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.