Skip to content

Commit

Permalink
Merge pull request #7722 from yyforyongyu/fix-payment-stream
Browse files Browse the repository at this point in the history
routing+lnrpc: subscribe payment stream before sending it
  • Loading branch information
guggero authored May 24, 2023
2 parents b9b20ac + 9ae4511 commit bbbf7d3
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 34 deletions.
5 changes: 5 additions & 0 deletions docs/release-notes/release-notes-0.17.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ unlock or create.

* Added ability to use [ENV variables to override `lncli` global flags](https://github.com/lightningnetwork/lnd/pull/7693). Flags will have preference over ENVs.

## Bug Fix

* Make sure payment stream returns all the events by [subscribing it before
sending](https://github.com/lightningnetwork/lnd/pull/7722).

# Contributors (Alphabetical Order)

* Carla Kirk-Cohen
Expand Down
3 changes: 3 additions & 0 deletions itest/lnd_multi-hop_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1371,6 +1371,9 @@ func runMultiHopHtlcAggregation(ht *lntest.HarnessTest,
resp := carol.RPC.AddInvoice(invoice)
ht.CompletePaymentRequests(alice, []string{resp.PaymentRequest})

// Make sure Carol has settled the invoice.
ht.AssertInvoiceSettled(carol, resp.PaymentAddr)

// With the network active, we'll now add a new hodl invoices at both
// Alice's and Carol's end. Make sure the cltv expiry delta is large
// enough, otherwise Bob won't send out the outgoing htlc.
Expand Down
83 changes: 58 additions & 25 deletions lnrpc/routerrpc/router_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,27 +318,46 @@ func (s *Server) SendPaymentV2(req *SendPaymentRequest,
return err
}

err = s.cfg.Router.SendPaymentAsync(payment)
// Get the payment hash.
payHash := payment.Identifier()

// Init the payment in db.
paySession, shardTracker, err := s.cfg.Router.PreparePayment(payment)
if err != nil {
// Transform user errors to grpc code.
if err == channeldb.ErrPaymentInFlight ||
err == channeldb.ErrAlreadyPaid {
return err
}

log.Debugf("SendPayment async result for payment %x: %v",
payment.Identifier(), err)
// Subscribe to the payment before sending it to make sure we won't
// miss events.
sub, err := s.subscribePayment(payHash)
if err != nil {
return err
}

return status.Error(
codes.AlreadyExists, err.Error(),
)
}
// Send the payment.
err = s.cfg.Router.SendPaymentAsync(payment, paySession, shardTracker)
if err == nil {
// If the payment was sent successfully, we can start tracking
// the events.
return s.trackPayment(
sub, payHash, stream, req.NoInflightUpdates,
)
}

// Otherwise, transform user errors to grpc code.
if errors.Is(err, channeldb.ErrPaymentInFlight) ||
errors.Is(err, channeldb.ErrAlreadyPaid) {

log.Errorf("SendPayment async error for payment %x: %v",
log.Debugf("SendPayment async result for payment %x: %v",
payment.Identifier(), err)

return err
return status.Error(codes.AlreadyExists, err.Error())
}

return s.trackPayment(payment.Identifier(), stream, req.NoInflightUpdates)
log.Errorf("SendPayment async error for payment %x: %v",
payment.Identifier(), err)

return err
}

// EstimateRouteFee allows callers to obtain a lower bound w.r.t how much it
Expand Down Expand Up @@ -800,34 +819,48 @@ func getMsatPairValue(msatValue lnwire.MilliSatoshi,
func (s *Server) TrackPaymentV2(request *TrackPaymentRequest,
stream Router_TrackPaymentV2Server) error {

paymentHash, err := lntypes.MakeHash(request.PaymentHash)
payHash, err := lntypes.MakeHash(request.PaymentHash)
if err != nil {
return err
}

log.Debugf("TrackPayment called for payment %v", paymentHash)
log.Debugf("TrackPayment called for payment %v", payHash)

return s.trackPayment(paymentHash, stream, request.NoInflightUpdates)
// Make the subscription.
sub, err := s.subscribePayment(payHash)
if err != nil {
return err
}

return s.trackPayment(sub, payHash, stream, request.NoInflightUpdates)
}

// trackPayment writes payment status updates to the provided stream.
func (s *Server) trackPayment(identifier lntypes.Hash,
stream Router_TrackPaymentV2Server, noInflightUpdates bool) error {
// subscribePayment subscribes to the payment updates for the given payment
// hash.
func (s *Server) subscribePayment(identifier lntypes.Hash) (
routing.ControlTowerSubscriber, error) {

// Make the subscription.
router := s.cfg.RouterBackend

// Subscribe to the outcome of this payment.
subscription, err := router.Tower.SubscribePayment(identifier)
sub, err := router.Tower.SubscribePayment(identifier)

switch {
case err == channeldb.ErrPaymentNotInitiated:
return status.Error(codes.NotFound, err.Error())
return nil, status.Error(codes.NotFound, err.Error())
case err != nil:
return err
return nil, err
}

return sub, nil
}

// trackPayment writes payment status updates to the provided stream.
func (s *Server) trackPayment(subscription routing.ControlTowerSubscriber,
identifier lntypes.Hash, stream Router_TrackPaymentV2Server,
noInflightUpdates bool) error {

// Stream updates to the client.
err = s.trackPaymentStream(
err := s.trackPaymentStream(
stream.Context(), subscription, noInflightUpdates, stream.Send,
)

Expand Down
15 changes: 6 additions & 9 deletions routing/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -2044,7 +2044,7 @@ func (l *LightningPayment) Identifier() [32]byte {
func (r *ChannelRouter) SendPayment(payment *LightningPayment) ([32]byte,
*route.Route, error) {

paySession, shardTracker, err := r.preparePayment(payment)
paySession, shardTracker, err := r.PreparePayment(payment)
if err != nil {
return [32]byte{}, nil, err
}
Expand All @@ -2062,11 +2062,8 @@ func (r *ChannelRouter) SendPayment(payment *LightningPayment) ([32]byte,

// SendPaymentAsync is the non-blocking version of SendPayment. The payment
// result needs to be retrieved via the control tower.
func (r *ChannelRouter) SendPaymentAsync(payment *LightningPayment) error {
paySession, shardTracker, err := r.preparePayment(payment)
if err != nil {
return err
}
func (r *ChannelRouter) SendPaymentAsync(payment *LightningPayment,
ps PaymentSession, st shards.ShardTracker) error {

// Since this is the first time this payment is being made, we pass nil
// for the existing attempt.
Expand All @@ -2079,7 +2076,7 @@ func (r *ChannelRouter) SendPaymentAsync(payment *LightningPayment) error {

_, _, err := r.sendPayment(
payment.FeeLimit, payment.Identifier(),
payment.PayAttemptTimeout, paySession, shardTracker,
payment.PayAttemptTimeout, ps, st,
)
if err != nil {
log.Errorf("Payment %x failed: %v",
Expand Down Expand Up @@ -2111,9 +2108,9 @@ func spewPayment(payment *LightningPayment) logClosure {
})
}

// preparePayment creates the payment session and registers the payment with the
// PreparePayment creates the payment session and registers the payment with the
// control tower.
func (r *ChannelRouter) preparePayment(payment *LightningPayment) (
func (r *ChannelRouter) PreparePayment(payment *LightningPayment) (
PaymentSession, shards.ShardTracker, error) {

// Before starting the HTLC routing attempt, we'll create a fresh
Expand Down

0 comments on commit bbbf7d3

Please sign in to comment.