Skip to content

Commit

Permalink
Bugfix: Fix a memleak with pollard
Browse files Browse the repository at this point in the history
Before this commit we use Rc cycles to represent the forest's trees, but
this creates some floating Rc's that never gets dropped, yelding a
memory leak. This commit fixes this by using Weak references everywhere
except for a node's children, since the node is meant to be owned by
it's ancestor.

It is ok to upgrade those Weak references because they will never
outlive the node itself.
  • Loading branch information
Davidson-Souza committed Feb 15, 2024
1 parent fdabb1e commit 005107e
Showing 1 changed file with 65 additions and 32 deletions.
97 changes: 65 additions & 32 deletions src/accumulator/pollard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ use std::fmt::Formatter;
use std::io::Read;
use std::io::Write;
use std::rc::Rc;
use std::rc::Weak;

use super::node_hash::NodeHash;
use super::proof::Proof;
Expand All @@ -50,6 +51,7 @@ enum NodeType {
Branch,
Leaf,
}

/// A forest node that can either be a leaf or a branch.
#[derive(Clone)]
pub struct Node {
Expand All @@ -58,7 +60,7 @@ pub struct Node {
/// The hash of the stored in this node.
data: Cell<NodeHash>,
/// The parent of this node, if any.
parent: RefCell<Option<Rc<Node>>>,
parent: RefCell<Option<Weak<Node>>>,
/// The left and right children of this node, if any.
left: RefCell<Option<Rc<Node>>>,
/// The left and right children of this node, if any.
Expand All @@ -67,15 +69,18 @@ pub struct Node {
impl Node {
/// Recomputes the hash of all nodes, up to the root.
fn recompute_hashes(&self) {
let left = self.left.borrow().clone();
let right = self.right.borrow().clone();
let left = self.left.borrow();
let right = self.right.borrow();

if let (Some(left), Some(right)) = (left, right) {
if let (Some(left), Some(right)) = (left.as_deref(), right.as_deref()) {
self.data
.replace(NodeHash::parent_hash(&left.data.get(), &right.data.get()));
}
if let Some(ref mut parent) = *self.parent.borrow_mut() {
parent.recompute_hashes();
if let Some(ref parent) = *self.parent.borrow() {
parent.upgrade().and_then(|p| {
p.recompute_hashes();
Some(())
});
}
}
/// Writes one node to the writer, this method will recursively write all children.
Expand Down Expand Up @@ -107,11 +112,11 @@ impl Node {
#[allow(clippy::type_complexity)]
pub fn read_one<R: std::io::Read>(
reader: &mut R,
) -> std::io::Result<(Rc<Node>, HashMap<NodeHash, Rc<Node>>)> {
) -> std::io::Result<(Rc<Node>, HashMap<NodeHash, Weak<Node>>)> {
fn _read_one<R: std::io::Read>(
ancestor: Option<Rc<Node>>,
reader: &mut R,
index: &mut HashMap<NodeHash, Rc<Node>>,
index: &mut HashMap<NodeHash, Weak<Node>>,
) -> std::io::Result<Rc<Node>> {
let mut data = [0u8; 32];
let mut ty = [0u8; 8];
Expand All @@ -127,17 +132,17 @@ impl Node {
let leaf = Rc::new(Node {
ty,
data: Cell::new(data.into()),
parent: RefCell::new(ancestor),
parent: RefCell::new(ancestor.map(|a| Rc::downgrade(&a))),
left: RefCell::new(None),
right: RefCell::new(None),
});
index.insert(leaf.data.get(), leaf.clone());
index.insert(leaf.data.get(), Rc::downgrade(&leaf));
return Ok(leaf);
}
let node = Rc::new(Node {
ty: NodeType::Branch,
data: Cell::new(data.into()),
parent: RefCell::new(ancestor),
parent: RefCell::new(ancestor.map(|a| Rc::downgrade(&a))),
left: RefCell::new(None),
right: RefCell::new(None),
});
Expand All @@ -149,11 +154,11 @@ impl Node {
node.left
.borrow()
.as_ref()
.map(|l| l.parent.replace(Some(node.clone())));
.map(|l| l.parent.replace(Some(Rc::downgrade(&node))));
node.right
.borrow()
.as_ref()
.map(|r| r.parent.replace(Some(node.clone())));
.map(|r| r.parent.replace(Some(Rc::downgrade(&node))));

Ok(node)
}
Expand Down Expand Up @@ -184,7 +189,7 @@ pub struct Pollard {
pub leaves: u64,
/// A map of all nodes in the forest, indexed by their hash, this is used to lookup
/// leaves when proving membership.
map: HashMap<NodeHash, Rc<Node>>,
map: HashMap<NodeHash, Weak<Node>>,
}
impl Pollard {
/// Creates a new empty [Pollard].
Expand Down Expand Up @@ -363,13 +368,23 @@ impl Pollard {
fn del(&mut self, targets: &[NodeHash]) -> Result<(), String> {
let mut pos = targets
.iter()
.map(|target| (self.get_pos(self.map.get(target).unwrap()), target))
.flat_map(|target| self.map.get(target))
.flat_map(|target| target.upgrade())
.map(|target| {
(
self.get_pos(self.map.get(&target.data.get()).unwrap()),
target.data.get(),
)
})
.collect::<Vec<_>>();

pos.sort();
let (_, targets): (Vec<u64>, Vec<NodeHash>) = pos.into_iter().unzip();
for target in targets {
match self.map.remove(&target) {
Some(target) => self.del_single(&target),
Some(target) => {
self.del_single(&target.upgrade().unwrap());
}
None => {
return Err(format!("node {} not in the forest", target));
}
Expand All @@ -385,15 +400,21 @@ impl Pollard {
.collect::<Vec<_>>();
proof.verify(del_hashes, &roots, self.leaves)
}
fn get_pos(&self, node: &Rc<Node>) -> u64 {
fn get_pos(&self, node: &Weak<Node>) -> u64 {
// This indicates whether the node is a left or right child at each level
// When we go down the tree, we can use the indicator to know which
// child to take.
let mut left_child_indicator = 0_u64;
let mut rows_to_top = 0;
let mut node = node.clone();
let mut node = node.upgrade().unwrap();
while let Some(parent) = node.parent.clone().into_inner() {
let parent_left = parent.left.borrow().as_ref().unwrap().clone();
let parent_left = parent
.upgrade()
.map(|parent| parent.left.clone().into_inner())
.flatten()
.unwrap()
.clone();

// If the current node is a left child, we left-shift the indicator
// and leave the LSB as 0
if parent_left.get_data() == node.get_data() {
Expand All @@ -405,7 +426,7 @@ impl Pollard {
left_child_indicator |= 1;
}
rows_to_top += 1;
node = parent.clone();
node = parent.upgrade().unwrap();
}
let mut root_idx = self.roots.len() - 1;
let forest_rows = tree_rows(self.leaves);
Expand Down Expand Up @@ -437,11 +458,11 @@ impl Pollard {
}
pos
}
fn del_single(&mut self, node: &Node) {
fn del_single(&mut self, node: &Node) -> Option<()> {
let parent = node.parent.borrow();
// Deleting a root
let parent = match *parent {
Some(ref node) => node,
Some(ref node) => node.upgrade()?,
None => {
let pos = self.roots.iter().position(|x| x.data == node.data).unwrap();
self.roots[pos] = Rc::new(Node {
Expand All @@ -451,11 +472,13 @@ impl Pollard {
left: RefCell::new(None),
right: RefCell::new(None),
});
return;
return None;
}
};

let me = parent.left.borrow();
// Can unwrap because we know the sibling exists
let sibling = if parent.left.borrow().as_ref().unwrap().data == node.data {
let sibling = if me.as_deref()?.data == node.data {
parent.right.borrow().clone()
} else {
parent.left.borrow().clone()
Expand All @@ -464,7 +487,7 @@ impl Pollard {
let grandparent = parent.parent.borrow().clone();
sibling.parent.replace(grandparent.clone());

if let Some(ref grandparent) = grandparent {
if let Some(ref grandparent) = grandparent.and_then(|g| g.upgrade()) {
if grandparent.left.borrow().clone().as_ref().unwrap().data == parent.data {
grandparent.left.replace(Some(sibling.clone()));
} else {
Expand All @@ -480,6 +503,8 @@ impl Pollard {
self.roots[pos] = sibling.clone();
}
};

Some(())
}
fn add_single(&mut self, value: NodeHash) {
let mut node: Rc<Node> = Rc::new(Node {
Expand All @@ -489,7 +514,7 @@ impl Pollard {
left: RefCell::new(None),
right: RefCell::new(None),
});
self.map.insert(value, node.clone());
self.map.insert(value, Rc::downgrade(&node));
let mut leaves = self.leaves;
while leaves & 1 != 0 {
let root = self.roots.pop().unwrap();
Expand All @@ -504,8 +529,8 @@ impl Pollard {
left: RefCell::new(Some(root.clone())),
right: RefCell::new(Some(node.clone())),
});
root.parent.replace(Some(new_node.clone()));
node.parent.replace(Some(new_node.clone()));
root.parent.replace(Some(Rc::downgrade(&new_node)));
node.parent.replace(Some(Rc::downgrade(&new_node)));

node = new_node;
leaves >>= 1;
Expand Down Expand Up @@ -595,6 +620,7 @@ impl Display for Pollard {
#[cfg(test)]
mod test {
use std::convert::TryFrom;
use std::rc::Rc;
use std::str::FromStr;
use std::vec;

Expand Down Expand Up @@ -847,7 +873,10 @@ mod test {
fn test_get_pos() {
macro_rules! test_get_pos {
($p:ident, $pos:literal) => {
assert_eq!($p.get_pos(&$p.grab_node($pos).unwrap().0), $pos);
assert_eq!(
$p.get_pos(&Rc::downgrade(&$p.grab_node($pos).unwrap().0)),
$pos
);
};
}
let hashes = get_hash_vec_of(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]);
Expand All @@ -867,13 +896,17 @@ mod test {
test_get_pos!(p, 11);
test_get_pos!(p, 12);

assert_eq!(p.get_pos(&p.get_roots()[0].clone()), 28);
assert_eq!(p.get_pos(&Rc::downgrade(&p.get_roots()[0])), 28);
assert_eq!(
p.get_pos(&p.get_roots()[0].left.borrow().clone().unwrap()),
p.get_pos(&Rc::downgrade(
p.get_roots()[0].left.borrow().as_ref().unwrap()
)),
24
);
assert_eq!(
p.get_pos(&p.get_roots()[0].right.borrow().clone().unwrap()),
p.get_pos(&Rc::downgrade(
p.get_roots()[0].right.borrow().as_ref().unwrap()
)),
25
);
}
Expand Down

0 comments on commit 005107e

Please sign in to comment.