From ce18961d67916e5f780055289eed26e23837ceac Mon Sep 17 00:00:00 2001 From: yukang Date: Thu, 14 Nov 2024 23:36:15 +0800 Subject: [PATCH] code review comments feedback --- src/fiber/graph.rs | 57 +++++++++++++++++++------------------------- src/fiber/history.rs | 54 +++++++++++++++++++---------------------- src/fiber/network.rs | 38 +++++++++++++---------------- 3 files changed, 64 insertions(+), 85 deletions(-) diff --git a/src/fiber/graph.rs b/src/fiber/graph.rs index 6c1d7919..37da7e5a 100644 --- a/src/fiber/graph.rs +++ b/src/fiber/graph.rs @@ -491,11 +491,13 @@ where } } - pub(crate) fn record_payment_success(&mut self, payment_session: &PaymentSession) { + pub(crate) fn record_payment_success(&mut self, mut payment_session: PaymentSession) { let session_route = &payment_session.route.nodes; let mut result = InternalResult::default(); result.succeed_range_pairs(session_route, 0, session_route.len() - 1); self.history.apply_internal_result(result); + payment_session.set_success_status(); + self.store.insert_payment_session(payment_session); } pub(crate) fn record_payment_fail( @@ -507,7 +509,7 @@ where let nodes = &payment_session.route.nodes; let need_to_retry = internal_result.record_payment_fail(nodes, tlc_err); self.history.apply_internal_result(internal_result); - return need_to_retry; + return need_to_retry && payment_session.can_retry(); } #[cfg(test)] @@ -782,13 +784,13 @@ where } let mut current = source_node.node_id; - while let Some(elem) = distances.get(¤t) { - let next_hop = elem.next_hop.as_ref().expect("next_hop is none"); + while let Some(elem) = distances.remove(¤t) { + let (next_pubkey, next_out_point) = elem.next_hop.expect("next_hop is none"); result.push(PathEdge { - target: next_hop.0, - channel_outpoint: next_hop.1.clone(), + target: next_pubkey, + channel_outpoint: next_out_point, }); - current = next_hop.0; + current = next_pubkey; if current == target { break; } @@ -884,32 +886,21 @@ impl SessionRoute { // for a payment route A -> B -> C -> D // the `payment_hops` is [B, C, D], which is a convinent way for onion routing. // here we need to create a session route with source, which is A -> B -> C -> D - pub fn new(source: Pubkey, target: Pubkey, payment_hops: &Vec) -> Self { - let mut router = Self::default(); - let mut current = source; - for hop in payment_hops { - if let Some(key) = hop.next_hop { - router.add_node( - current, - hop.channel_outpoint - .clone() - .expect("expect channel outpoint"), - hop.amount, - ); - current = key; - } - } - assert_eq!(current, target); - router.add_node(target, OutPoint::default(), 0); - router - } - - fn add_node(&mut self, pubkey: Pubkey, channel_outpoint: OutPoint, amount: u128) { - self.nodes.push(SessionRouteNode { - pubkey, - channel_outpoint, - amount, - }); + pub fn new(source: Pubkey, target: Pubkey, payment_hops: &[PaymentHopData]) -> Self { + let nodes = std::iter::once(source) + .chain( + payment_hops + .iter() + .map(|hop| hop.next_hop.clone().unwrap_or(target)), + ) + .zip(payment_hops) + .map(|(pubkey, hop)| SessionRouteNode { + pubkey, + channel_outpoint: hop.channel_outpoint.clone().unwrap_or_default(), + amount: hop.amount, + }) + .collect(); + Self { nodes } } } diff --git a/src/fiber/history.rs b/src/fiber/history.rs index 49bf5a66..f6617dc5 100644 --- a/src/fiber/history.rs +++ b/src/fiber/history.rs @@ -67,7 +67,7 @@ impl InternalResult { self.add(from, target, current_time(), amount, false); } - pub fn fail_node(&mut self, nodes: &Vec, index: usize) { + pub fn fail_node(&mut self, nodes: &[SessionRouteNode], index: usize) { self.fail_node = Some(nodes[index].pubkey); if index > 0 { self.fail_pair(nodes, index); @@ -77,7 +77,7 @@ impl InternalResult { } } - pub fn fail_pair(&mut self, route: &Vec, index: usize) { + pub fn fail_pair(&mut self, route: &[SessionRouteNode], index: usize) { if index > 0 { let a = route[index - 1].pubkey; let b = route[index].pubkey; @@ -85,7 +85,7 @@ impl InternalResult { } } - pub fn fail_pair_balanced(&mut self, nodes: &Vec, index: usize) { + pub fn fail_pair_balanced(&mut self, nodes: &[SessionRouteNode], index: usize) { if index > 0 { let a = nodes[index - 1].pubkey; let b = nodes[index].pubkey; @@ -94,7 +94,7 @@ impl InternalResult { } } - pub fn succeed_range_pairs(&mut self, nodes: &Vec, start: usize, end: usize) { + pub fn succeed_range_pairs(&mut self, nodes: &[SessionRouteNode], start: usize, end: usize) { for i in start..end { self.add( nodes[i].pubkey, @@ -105,13 +105,13 @@ impl InternalResult { ); } } - pub fn fail_range_pairs(&mut self, nodes: &Vec, start: usize, end: usize) { + pub fn fail_range_pairs(&mut self, nodes: &[SessionRouteNode], start: usize, end: usize) { for index in start.max(1)..=end { self.fail_pair(nodes, index); } } - pub fn record_payment_fail(&mut self, nodes: &Vec, tlc_err: TlcErr) -> bool { + pub fn record_payment_fail(&mut self, nodes: &[SessionRouteNode], tlc_err: TlcErr) -> bool { let mut need_to_retry = true; let error_index = nodes.iter().position(|s| { @@ -271,10 +271,7 @@ where } pub(crate) fn add_result(&mut self, from: Pubkey, target: Pubkey, result: TimedResult) { - self.inner - .entry(from) - .or_insert_with(HashMap::new) - .insert(target, result); + self.inner.entry(from).or_default().insert(target, result); self.save_result(from, target, result); } @@ -285,11 +282,8 @@ where pub(crate) fn load_from_store(&mut self) { let results = self.store.get_payment_history_result(); - for (from, target, result) in results.iter() { - self.inner - .entry(from.clone()) - .or_insert_with(HashMap::new) - .insert(target.clone(), *result); + for (from, target, result) in results.into_iter() { + self.inner.entry(from).or_default().insert(target, result); } } @@ -338,27 +332,27 @@ where } pub(crate) fn apply_internal_result(&mut self, result: InternalResult) { - for ((from, target), pair_result) in result.pairs.iter() { + let InternalResult { pairs, fail_node } = result; + for ((from, target), pair_result) in pairs.into_iter() { self.apply_pair_result( - *from, - *target, + from, + target, pair_result.amount, pair_result.success, pair_result.time, ); } - - if let Some(fail_node) = result.fail_node { - let mut pairs = vec![]; - for (from, target) in self.inner.keys().flat_map(|from| { - self.inner[from] - .keys() - .map(move |target| (from.clone(), target.clone())) - }) { - if from == fail_node || target == fail_node { - pairs.push((from, target)); - } - } + if let Some(fail_node) = fail_node { + let pairs: Vec<(Pubkey, Pubkey)> = self + .inner + .iter() + .flat_map(|(from, targets)| { + targets.keys().filter_map(move |target| { + (*from == fail_node || *target == fail_node) + .then_some((from.clone(), target.clone())) + }) + }) + .collect(); for (from, target) in pairs { self.apply_pair_result(from, target, 0, false, current_time()); } diff --git a/src/fiber/network.rs b/src/fiber/network.rs index b08030a3..45cb9445 100644 --- a/src/fiber/network.rs +++ b/src/fiber/network.rs @@ -2223,9 +2223,7 @@ where self.network_graph .write() .await - .record_payment_success(&payment_session); - payment_session.set_success_status(); - self.store.insert_payment_session(payment_session); + .record_payment_success(payment_session); } RemoveTlcReason::RemoveTlcFail(reason) => { let error_detail = reason.decode().expect("decoded error"); @@ -2235,7 +2233,7 @@ where .write() .await .record_payment_fail(&payment_session, error_detail.clone()); - if payment_session.can_retry() && need_to_retry { + if need_to_retry { let res = self.try_payment_session(state, payment_session).await; if res.is_err() { debug!("Failed to retry payment session: {:?}", res); @@ -2333,37 +2331,33 @@ where hops: Vec, ) -> Result { let session_key = Privkey::from_slice(KeyPair::generate_random_key().as_ref()); - let peeled_packet = match PeeledPaymentOnionPacket::create( - session_key, - hops.clone(), - &Secp256k1::signing_only(), - ) { - Ok(packet) => packet, - Err(e) => { - let err = format!( - "Failed to create onion packet: {:?}, error: {:?}", - payment_data.payment_hash, e - ); - self.set_payment_fail_with_error(payment_session, &err); - return Err(Error::SendPaymentError(err)); - } - }; - let first_channel_outpoint = hops[0] .channel_outpoint .clone() .expect("first hop channel must exist"); - let session_route = + payment_session.route = SessionRoute::new(state.get_public_key(), payment_data.target_pubkey, &hops); let (send, recv) = oneshot::channel::>(); let rpc_reply = RpcReplyPort::from(send); + let peeled_packet = + match PeeledPaymentOnionPacket::create(session_key, hops, &Secp256k1::signing_only()) { + Ok(packet) => packet, + Err(e) => { + let err = format!( + "Failed to create onion packet: {:?}, error: {:?}", + payment_data.payment_hash, e + ); + self.set_payment_fail_with_error(payment_session, &err); + return Err(Error::SendPaymentError(err)); + } + }; let command = SendOnionPacketCommand { packet: peeled_packet.serialize(), previous_tlc: None, }; - payment_session.route = session_route.clone(); + self.handle_send_onion_packet_command(state, command, rpc_reply) .await; match recv.await.expect("msg recv error") {