Skip to content

Commit

Permalink
Merge branch 'next' of github.com:LuisaGroup/LuisaCompute into next
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Jul 27, 2023
2 parents a978321 + 0ecec5e commit f7ed9b1
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 57 deletions.
97 changes: 43 additions & 54 deletions include/luisa/vstl/arena_hash_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ class ArenaHashMap {
private:
struct PseudoPool {
Arena arena;
LinkNode *elements;
size_t mSize;
PseudoPool(Arena &&arena) : arena(std::forward<Arena>(arena)) {}
PseudoPool(PseudoPool const &) = delete;
PseudoPool(PseudoPool &&) = default;
Expand All @@ -26,7 +28,8 @@ class ArenaHashMap {
template<typename... Args>
requires(std::is_constructible_v<LinkNode, Args && ...>)
LinkNode *create(Args &&...args) {
auto ptr = allocate(sizeof(LinkNode));
auto ptr = elements + mSize;
mSize++;
return new (ptr) LinkNode(std::forward<Args>(args)...);
}
void destroy(LinkNode *ptr) {}
Expand All @@ -37,10 +40,10 @@ class ArenaHashMap {
friend class ArenaHashMap;

private:
LinkNode **ii;
LinkNode *ii;

public:
Iterator(LinkNode **ii) : ii(ii) {}
Iterator(LinkNode *ii) : ii(ii) {}
bool operator==(const Iterator &ite) const noexcept {
return ii == ite.ii;
}
Expand All @@ -51,20 +54,20 @@ class ArenaHashMap {
ii++;
}
NodePair *operator->() const noexcept {
return reinterpret_cast<NodePair *>(&(*ii)->data);
return reinterpret_cast<NodePair *>(&ii->data);
}
NodePair &operator*() const noexcept {
return reinterpret_cast<NodePair &>((*ii)->data);
return reinterpret_cast<NodePair &>(ii->data);
}
};
struct MoveIterator {
friend class ArenaHashMap;

private:
LinkNode **ii;
LinkNode *ii;

public:
MoveIterator(LinkNode **ii) : ii(ii) {}
MoveIterator(LinkNode *ii) : ii(ii) {}
bool operator==(const MoveIterator &ite) const noexcept {
return ii == ite.ii;
}
Expand All @@ -75,10 +78,10 @@ class ArenaHashMap {
ii++;
}
MoveNodePair *operator->() const noexcept {
return &(*ii)->data;
return &ii->data;
}
MoveNodePair &&operator*() const noexcept {
return std::move((*ii)->data);
return std::move(ii->data);
}
};

Expand Down Expand Up @@ -122,26 +125,14 @@ class ArenaHashMap {
}

private:
LinkNode **nodeArray;
PseudoPool pool;
size_t mSize;
Map *nodeVec;
size_t mCapacity;

inline static const Hash hsFunc;
LinkNode *GetNewLinkNode(size_t hashValue, LinkNode *newNode) {
void GetNewLinkNode(size_t hashValue, LinkNode *newNode) {
newNode->hashValue = hashValue;
newNode->arrayIndex = mSize;
nodeArray[mSize] = newNode;
mSize++;
return newNode;
}
void DeleteLinkNode(size_t arrayIndex) {
if (arrayIndex != (mSize - 1)) {
auto ite = nodeArray + (mSize - 1);
(*ite)->arrayIndex = arrayIndex;
nodeArray[arrayIndex] = *ite;
}
mSize--;
newNode->arrayIndex = pool.mSize - 1;
}
static size_t GetPow2Size(size_t capacity) noexcept {
size_t ssize = 1;
Expand All @@ -154,24 +145,24 @@ class ArenaHashMap {
}
void Resize(size_t newCapacity) noexcept {
if (mCapacity >= newCapacity) return;
LinkNode **newNode = reinterpret_cast<LinkNode **>(pool.allocate(sizeof(LinkNode *) * newCapacity * 2));
memcpy(newNode, nodeArray, sizeof(LinkNode *) * mSize);
auto nodeVec = newNode + newCapacity;
memset(nodeVec, 0, sizeof(LinkNode *) * newCapacity);
for (auto node : ptr_range(nodeArray, nodeArray + mSize)) {
size_t hashValue = node->hashValue;
LinkNode *newElements = reinterpret_cast<LinkNode *>(pool.allocate(sizeof(LinkNode) * newCapacity));
nodeVec = reinterpret_cast<Map *>(pool.allocate(sizeof(Map) * newCapacity));
memcpy(newElements, pool.elements, sizeof(LinkNode) * pool.mSize);
memset(nodeVec, 0, sizeof(Map) * newCapacity);
for (auto &node : ptr_range(newElements, pool.mSize)) {
size_t hashValue = node.hashValue;
hashValue = GetHash(hashValue, newCapacity);
Map *targetTree = reinterpret_cast<Map *>(&nodeVec[hashValue]);
targetTree->weak_insert(pool, node);
Map *targetTree = nodeVec + hashValue;
targetTree->weak_insert(pool, &node);
}
nodeArray = newNode;
pool.elements = newElements;
mCapacity = newCapacity;
}
static Index EmptyIndex() noexcept {
return Index(nullptr, nullptr);
}
void TryResize() {
size_t targetCapacity = (size_t)((mSize + 1));
size_t targetCapacity = (size_t)((pool.mSize + 1));
if (targetCapacity < 16) targetCapacity = 16;
if (targetCapacity > mCapacity) {
Resize(GetPow2Size(targetCapacity));
Expand All @@ -180,32 +171,33 @@ class ArenaHashMap {

public:
decltype(auto) begin() const & {
return Iterator(nodeArray);
return Iterator(pool.elements);
}
decltype(auto) begin() && {
return MoveIterator(nodeArray);
return MoveIterator(pool.elements);
}
decltype(auto) end() const & {
return Iterator(nodeArray + mSize);
return Iterator(pool.elements + pool.mSize);
}
decltype(auto) end() && {
return MoveIterator(nodeArray + mSize);
return MoveIterator(pool.elements + pool.mSize);
}
//////////////////Construct & Destruct
ArenaHashMap(size_t capacity, Arena &&arena) noexcept : pool(std::move(arena)) {
if (capacity < 2) capacity = 2;
capacity = GetPow2Size(capacity);
nodeArray = reinterpret_cast<LinkNode **>(this->pool.allocate(sizeof(LinkNode *) * capacity * 2));
memset(nodeArray + capacity, 0, capacity * sizeof(LinkNode *));
pool.elements = reinterpret_cast<LinkNode *>(this->pool.allocate(sizeof(LinkNode) * capacity));
nodeVec = reinterpret_cast<Map *>(this->pool.allocate(sizeof(Map) * capacity));
memset(nodeVec, 0, capacity * sizeof(Map));
mCapacity = capacity;
mSize = 0;
pool.mSize = 0;
}
ArenaHashMap(ArenaHashMap &&map)
: pool(std::move(map.pool)),
mSize(map.mSize),
mCapacity(map.mCapacity),
nodeArray(map.nodeArray) {
map.nodeArray = nullptr;
nodeVec(map.nodeVec),
mCapacity(map.mCapacity) {
map.elements = nullptr;
map.nodeVec = nullptr;
}
ArenaHashMap(ArenaHashMap const &map) = delete;

Expand All @@ -222,9 +214,8 @@ class ArenaHashMap {

size_t hashOriginValue = hsFunc(std::forward<Key>(key));
size_t hashValue = GetHash(hashOriginValue, mCapacity);
auto nodeVec = nodeArray + mCapacity;

Map *map = reinterpret_cast<Map *>(&nodeVec[hashValue]);
Map *map = nodeVec + hashValue;
auto insertResult = map->insert_or_assign(pool, std::forward<Key>(key), std::forward<ARGS>(args)...);
//Add create
if (insertResult.second) {
Expand All @@ -240,9 +231,8 @@ class ArenaHashMap {

size_t hashOriginValue = hsFunc(std::forward<Key>(key));
size_t hashValue = GetHash(hashOriginValue, mCapacity);
auto nodeVec = nodeArray + mCapacity;

Map *map = reinterpret_cast<Map *>(&nodeVec[hashValue]);
Map *map = nodeVec + hashValue;
auto insertResult = map->try_insert(pool, std::forward<Key>(key), std::forward<ARGS>(args)...);
//Add create
if (insertResult.second) {
Expand All @@ -258,9 +248,8 @@ class ArenaHashMap {

size_t hashOriginValue = hsFunc(std::forward<Key>(key));
size_t hashValue = GetHash(hashOriginValue, mCapacity);
auto nodeVec = nodeArray + mCapacity;

Map *map = reinterpret_cast<Map *>(&nodeVec[hashValue]);
Map *map = nodeVec + hashValue;
auto insertResult = map->try_insert(pool, std::forward<Key>(key), std::forward<ARGS>(args)...);
//Add create
if (insertResult.second) {
Expand All @@ -277,14 +266,14 @@ class ArenaHashMap {
Index find(Key &&key) const noexcept {
size_t hashOriginValue = hsFunc(std::forward<Key>(key));
size_t hashValue = GetHash(hashOriginValue, mCapacity);
Map *map = reinterpret_cast<Map *>(nodeArray + mCapacity + hashValue);
Map *map = nodeVec + hashValue;
auto node = map->find(std::forward<Key>(key));
if (node)
return {this, node};
return EmptyIndex();
}
[[nodiscard]] size_t size() const noexcept { return mSize; }
[[nodiscard]] bool empty() const noexcept { return mSize == 0; }
[[nodiscard]] size_t size() const noexcept { return pool.mSize; }
[[nodiscard]] bool empty() const noexcept { return pool.mSize == 0; }
[[nodiscard]] size_t capacity() const noexcept { return mCapacity; }
};
}// namespace vstd
4 changes: 1 addition & 3 deletions include/luisa/vstl/hash_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,6 @@ class SmallTreeMap {
}
}
Node *node = pool.create(std::forward<Key>(key), std::forward<Value>(value)...);
node->parent = nullptr;
node->left = nullptr;
node->right = nullptr;
node->color = true;// new node must be red
Expand Down Expand Up @@ -463,12 +462,11 @@ class HashMap {
size_t mCapacity;

inline static const Hash hsFunc;
LinkNode *GetNewLinkNode(size_t hashValue, LinkNode *newNode) {
void GetNewLinkNode(size_t hashValue, LinkNode *newNode) {
newNode->hashValue = hashValue;
newNode->arrayIndex = mSize;
nodeArray[mSize] = newNode;
mSize++;
return newNode;
}
void DeleteLinkNode(size_t arrayIndex) {
if (arrayIndex != (mSize - 1)) {
Expand Down

0 comments on commit f7ed9b1

Please sign in to comment.