diff --git a/.github/workflows/systest.yml b/.github/workflows/systest.yml index 12554074..aa81357e 100644 --- a/.github/workflows/systest.yml +++ b/.github/workflows/systest.yml @@ -92,6 +92,13 @@ jobs: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_PASSWORD }} + - uses: extractions/netrc@v2 + with: + machine: github.com + username: ${{ secrets.GH_ACTION_TOKEN_USER }} + password: ${{ secrets.GH_ACTION_TOKEN }} + if: vars.GOPRIVATE + - name: Push go-spacemesh build to docker hub run: make dockerpush diff --git a/Dockerfile b/Dockerfile index f297d93e..bd1223a3 100644 --- a/Dockerfile +++ b/Dockerfile @@ -44,7 +44,7 @@ RUN make get-libs COPY go.mod . COPY go.sum . -RUN go mod download +RUN --mount=type=secret,id=mynetrc,dst=/root/.netrc go mod download # Here we copy the rest of the source code COPY . . diff --git a/Makefile b/Makefile index a82ae6e9..eb6e44a8 100644 --- a/Makefile +++ b/Makefile @@ -154,6 +154,7 @@ list-versions: dockerbuild-go: DOCKER_BUILDKIT=1 docker build \ + --secret id=mynetrc,src=$(HOME)/.netrc \ --build-arg VERSION=${VERSION} \ -t go-spacemesh:$(SHA) \ -t $(DOCKER_HUB)/$(DOCKER_IMAGE_REPO):$(DOCKER_IMAGE_VERSION) . @@ -171,7 +172,11 @@ endif .PHONY: dockerpush-only dockerbuild-bs: - DOCKER_BUILDKIT=1 docker build -t go-spacemesh-bs:$(SHA) -t $(DOCKER_HUB)/$(DOCKER_IMAGE_REPO)-bs:$(DOCKER_IMAGE_VERSION) -f ./bootstrap.Dockerfile . + DOCKER_BUILDKIT=1 docker build \ + --secret id=mynetrc,src=$(HOME)/.netrc \ + -t go-spacemesh-bs:$(SHA) \ + -t $(DOCKER_HUB)/$(DOCKER_IMAGE_REPO)-bs:$(DOCKER_IMAGE_VERSION) \ + -f ./bootstrap.Dockerfile . .PHONY: dockerbuild-bs dockerpush-bs: dockerbuild-bs dockerpush-bs-only diff --git a/activation/handler.go b/activation/handler.go index aac0b274..aa653c79 100644 --- a/activation/handler.go +++ b/activation/handler.go @@ -213,16 +213,21 @@ func (h *Handler) SyntacticallyValidateDeps( ctx context.Context, atx *types.ActivationTx, ) (*types.VerifiedActivationTx, *mwire.MalfeasanceProof, error) { - var ( - commitmentATX *types.ATXID - err error - ) + var commitmentATX *types.ATXID if atx.PrevATXID == types.EmptyATXID { if err := h.validateInitialAtx(ctx, atx); err != nil { return nil, nil, err } - commitmentATX = atx.CommitmentATX + commitmentATX = atx.CommitmentATX // checked to be non-nil in syntactic validation } else { + prev, err := atxs.Get(h.cdb, atx.PrevATXID) + if err != nil { + return nil, nil, fmt.Errorf("prev atx for %s not found: %w", atx.PrevATXID, err) + } + if prev.SmesherID != atx.SmesherID { + return nil, nil, fmt.Errorf("prev atx smesher id mismatch: %s != %s", prev.SmesherID, atx.SmesherID) + } + commitmentATX, err = h.getCommitmentAtx(atx) if err != nil { return nil, nil, fmt.Errorf("commitment atx for %s not found: %w", atx.SmesherID, err) @@ -403,76 +408,179 @@ func (h *Handler) cacheAtx(ctx context.Context, atx *types.ActivationTxHeader, n return nil } -// storeAtx stores an ATX and notifies subscribers of the ATXID. -func (h *Handler) storeAtx(ctx context.Context, atx *types.VerifiedActivationTx) (*mwire.MalfeasanceProof, error) { - var nonce *types.VRFPostIndex - malicious, err := h.cdb.IsMalicious(atx.SmesherID) +// checkDoublePublish verifies if a node has already published an ATX in the same epoch. +func (h *Handler) checkDoublePublish( + ctx context.Context, + tx sql.Executor, + atx *types.VerifiedActivationTx, +) (*mwire.MalfeasanceProof, error) { + prev, err := atxs.GetByEpochAndNodeID(tx, atx.PublishEpoch, atx.SmesherID) + if err != nil && !errors.Is(err, sql.ErrNotFound) { + return nil, err + } + + // do ID check to be absolutely sure. + if prev == nil || prev.ID() == atx.ID() { + return nil, nil + } + if _, ok := h.signers[atx.SmesherID]; ok { + // if we land here we tried to publish 2 ATXs in the same epoch + // don't punish ourselves but fail validation and thereby the handling of the incoming ATX + return nil, fmt.Errorf("%s already published an ATX in epoch %d", atx.SmesherID.ShortString(), atx.PublishEpoch) + } + + var atxProof mwire.AtxProof + for i, a := range []*types.VerifiedActivationTx{prev, atx} { + atxProof.Messages[i] = mwire.AtxProofMsg{ + InnerMsg: types.ATXMetadata{ + PublishEpoch: a.PublishEpoch, + MsgHash: wire.ActivationTxToWireV1(a.ActivationTx).HashInnerBytes(), + }, + SmesherID: a.SmesherID, + Signature: a.Signature, + } + } + proof := &mwire.MalfeasanceProof{ + Layer: atx.PublishEpoch.FirstLayer(), + Proof: mwire.Proof{ + Type: mwire.MultipleATXs, + Data: &atxProof, + }, + } + encoded, err := codec.Encode(proof) if err != nil { - return nil, fmt.Errorf("checking if node is malicious: %w", err) + h.log.With().Panic("failed to encode malfeasance proof", log.Err(err)) } - var proof *mwire.MalfeasanceProof - if err := h.cdb.WithTx(ctx, func(tx *sql.Tx) error { - if malicious { - if err := atxs.Add(tx, atx); err != nil && !errors.Is(err, sql.ErrObjectExists) { - return fmt.Errorf("add atx to db: %w", err) - } - return nil + if err := identities.SetMalicious(tx, atx.SmesherID, encoded, time.Now()); err != nil { + return nil, fmt.Errorf("add malfeasance proof: %w", err) + } + + h.log.WithContext(ctx).With().Warning("smesher produced more than one atx in the same epoch", + log.Stringer("smesher", atx.SmesherID), + log.Object("prev", prev), + log.Object("curr", atx), + ) + + return proof, nil +} + +// checkWrongPrevAtx verifies if the previous ATX referenced in the ATX is correct. +func (h *Handler) checkWrongPrevAtx( + ctx context.Context, + tx sql.Executor, + atx *types.VerifiedActivationTx, +) (*mwire.MalfeasanceProof, error) { + prevID, err := atxs.PrevIDByNodeID(tx, atx.SmesherID, atx.PublishEpoch) + if err != nil && !errors.Is(err, sql.ErrNotFound) { + return nil, fmt.Errorf("get last atx by node id: %w", err) + } + if prevID == atx.PrevATXID { + return nil, nil + } + + if _, ok := h.signers[atx.SmesherID]; ok { + // if we land here we tried to publish an ATX with a wrong prevATX + h.log.WithContext(ctx).With().Warning( + "Node produced an ATX with a wrong prevATX. This can happened when the node wasn't synced when "+ + "registering at PoET", + log.Stringer("smesher", atx.SmesherID), + log.ShortStringer("expected", prevID), + log.ShortStringer("actual", atx.PrevATXID), + ) + return nil, fmt.Errorf("%s referenced incorrect previous ATX", atx.SmesherID.ShortString()) + } + + // check if atx.PrevATXID is actually the last published ATX by the same node + prev, err := atxs.Get(tx, prevID) + if err != nil { + return nil, fmt.Errorf("get prev atx: %w", err) + } + + // if atx references a previous ATX that is not the last ATX by the same node, there must be at least one + // atx published between prevATX and the current epoch + var atx2 *types.VerifiedActivationTx + pubEpoch := h.clock.CurrentLayer().GetEpoch() + for pubEpoch > prev.PublishEpoch { + id, err := atxs.PrevIDByNodeID(tx, atx.SmesherID, pubEpoch) + if err != nil { + return nil, fmt.Errorf("get prev atx id by node id: %w", err) } - prev, err := atxs.GetByEpochAndNodeID(tx, atx.PublishEpoch, atx.SmesherID) - if err != nil && !errors.Is(err, sql.ErrNotFound) { - return err + atx2, err = atxs.Get(tx, id) + if err != nil { + return nil, fmt.Errorf("get prev atx: %w", err) } - // do ID check to be absolutely sure. - if prev != nil && prev.ID() != atx.ID() { - if _, ok := h.signers[atx.SmesherID]; ok { - // if we land here we tried to publish 2 ATXs in the same epoch - // don't punish ourselves but fail validation and thereby the handling of the incoming ATX - return fmt.Errorf("%s already published an ATX in epoch %d", atx.SmesherID.ShortString(), - atx.PublishEpoch, - ) - } - - var atxProof mwire.AtxProof - for i, a := range []*types.VerifiedActivationTx{prev, atx} { - atxProof.Messages[i] = mwire.AtxProofMsg{ - InnerMsg: types.ATXMetadata{ - PublishEpoch: a.PublishEpoch, - MsgHash: wire.ActivationTxToWireV1(a.ActivationTx).HashInnerBytes(), - }, - SmesherID: a.SmesherID, - Signature: a.Signature, - } - } - proof = &mwire.MalfeasanceProof{ - Layer: atx.PublishEpoch.FirstLayer(), - Proof: mwire.Proof{ - Type: mwire.MultipleATXs, - Data: &atxProof, - }, - } - encoded, err := codec.Encode(proof) - if err != nil { - h.log.With().Panic("failed to encode malfeasance proof", log.Err(err)) - } - if err := identities.SetMalicious(tx, atx.SmesherID, encoded, time.Now()); err != nil { - return fmt.Errorf("add malfeasance proof: %w", err) - } - - h.log.WithContext(ctx).With().Warning("smesher produced more than one atx in the same epoch", - log.Stringer("smesher", atx.SmesherID), - log.Object("prev", prev), - log.Object("curr", atx), - ) + if atx.ID() != atx2.ID() && atx.PrevATXID == atx2.PrevATXID { + // found an ATX that points to the same previous ATX + break } + pubEpoch = atx2.PublishEpoch + } - nonce, err = atxs.AddGettingNonce(tx, atx) - if err != nil && !errors.Is(err, sql.ErrObjectExists) { - return fmt.Errorf("add atx to db: %w", err) + if atx2 == nil || atx2.PrevATXID != atx.PrevATXID { + // something went wrong, we couldn't find an ATX that points to the same previous ATX + // this should never happen since we are checking in other places that all ATXs from the same node + // form a chain + return nil, errors.New("failed double previous check: could not find an ATX with same previous ATX") + } + + proof := &mwire.MalfeasanceProof{ + Layer: atx.PublishEpoch.FirstLayer(), + Proof: mwire.Proof{ + Type: mwire.InvalidPrevATX, + Data: &mwire.InvalidPrevATXProof{ + Atx1: *wire.ActivationTxToWireV1(atx.ActivationTx), + Atx2: *wire.ActivationTxToWireV1(atx2.ActivationTx), + }, + }, + } + + if err := identities.SetMalicious(tx, atx.SmesherID, codec.MustEncode(proof), time.Now()); err != nil { + return nil, fmt.Errorf("add malfeasance proof: %w", err) + } + + h.log.WithContext(ctx).With().Warning("smesher referenced the wrong previous in published ATX", + log.Stringer("smesher", atx.SmesherID), + log.ShortStringer("expected", prevID), + log.ShortStringer("actual", atx.PrevATXID), + ) + return proof, nil +} + +func (h *Handler) checkMalicious( + ctx context.Context, + tx *sql.Tx, + atx *types.VerifiedActivationTx, +) (*mwire.MalfeasanceProof, error) { + malicious, err := identities.IsMalicious(tx, atx.SmesherID) + if err != nil { + return nil, fmt.Errorf("checking if node is malicious: %w", err) + } + if malicious { + return nil, nil + } + proof, err := h.checkDoublePublish(ctx, tx, atx) + if proof != nil || err != nil { + return proof, err + } + return h.checkWrongPrevAtx(ctx, tx, atx) +} + +// storeAtx stores an ATX and notifies subscribers of the ATXID. +func (h *Handler) storeAtx(ctx context.Context, atx *types.VerifiedActivationTx) (*mwire.MalfeasanceProof, error) { + var nonce *types.VRFPostIndex + var proof *mwire.MalfeasanceProof + err := h.cdb.WithTx(ctx, func(tx *sql.Tx) error { + var err1, err2 error + proof, err1 = h.checkMalicious(ctx, tx, atx) + nonce, err2 = atxs.AddGettingNonce(tx, atx) + if err2 != nil && !errors.Is(err2, sql.ErrObjectExists) { + err2 = fmt.Errorf("add atx to db: %w", err2) } - return nil - }); err != nil { + return errors.Join(err1, err2) + }) + if err != nil { return nil, fmt.Errorf("store atx: %w", err) } if nonce == nil { @@ -512,7 +620,7 @@ func (h *Handler) HandleSyncedAtx(ctx context.Context, expHash types.Hash32, pee // HandleGossipAtx handles the atx gossip data channel. func (h *Handler) HandleGossipAtx(ctx context.Context, peer p2p.Peer, msg []byte) error { - proof, err := h.handleAtx(ctx, types.Hash32{}, peer, msg) + proof, err := h.handleAtx(ctx, types.EmptyHash32, peer, msg) if err != nil && !errors.Is(err, errMalformedData) && !errors.Is(err, errKnownAtx) { h.log.WithContext(ctx).With().Warning("failed to process atx gossip", log.Stringer("sender", peer), @@ -621,7 +729,7 @@ func (h *Handler) processATX( return proof, err } - if expHash != (types.Hash32{}) && vAtx.ID().Hash32() != expHash { + if expHash != types.EmptyHash32 && vAtx.ID().Hash32() != expHash { return nil, fmt.Errorf( "%w: atx want %s, got %s", errWrongHash, @@ -637,7 +745,8 @@ func (h *Handler) processATX( events.ReportNewActivation(vAtx) h.log.WithContext(ctx).With().Info( "new atx", log.Inline(vAtx), - log.Bool("malicious", proof != nil)) + log.Bool("malicious", proof != nil), + ) return proof, err } diff --git a/activation/handler_test.go b/activation/handler_test.go index 70fcc007..02135c82 100644 --- a/activation/handler_test.go +++ b/activation/handler_test.go @@ -708,6 +708,32 @@ func TestHandler_SyntacticallyValidateAtx(t *testing.T) { err := atxHdlr.SyntacticallyValidate(context.Background(), atx) require.EqualError(t, err, "prev atx declared, but node id is included") }) + + t.Run("prevAtx by different NodeID", func(t *testing.T) { + t.Parallel() + + atxHdlr := newTestHandler(t, goldenATXID) + require.NoError(t, atxs.Add(atxHdlr.cdb, posAtx)) + require.NoError(t, atxs.Add(atxHdlr.cdb, prevAtx)) + + challenge := types.NIPostChallenge{ + Sequence: posAtx.Sequence + 1, + PrevATXID: posAtx.ID(), + PublishEpoch: currentLayer.GetEpoch(), + PositioningATX: posAtx.ID(), + CommitmentATX: nil, + } + nipost := types.NIPost{PostMetadata: &types.PostMetadata{}} + atx := newAtx(challenge, &nipost, 100, types.GenerateAddress([]byte("aaaa"))) + atx.NIPost = newNIPostWithPoet(t, poetRef).NIPost + require.NoError(t, SignAndFinalizeAtx(sig, atx)) + + atxHdlr.mclock.EXPECT().CurrentLayer().Return(currentLayer) + require.NoError(t, atxHdlr.SyntacticallyValidate(context.Background(), atx)) + _, proof, err := atxHdlr.SyntacticallyValidateDeps(context.Background(), atx) + require.ErrorContains(t, err, "prev atx smesher id mismatch") + require.Nil(t, proof) + }) } func TestHandler_ContextuallyValidateAtx(t *testing.T) { @@ -931,7 +957,7 @@ func TestHandler_ProcessAtx(t *testing.T) { types.EmptyATXID, types.EmptyATXID, nil, - types.LayerID(layersPerEpoch).GetEpoch(), + types.EpochID(2), 0, 100, coinbase, @@ -949,6 +975,73 @@ func TestHandler_ProcessAtx(t *testing.T) { proof, err = atxHdlr.processVerifiedATX(context.Background(), atx1) require.NoError(t, err) require.Nil(t, proof) +} + +func TestHandler_ProcessAtx_maliciousIdentity(t *testing.T) { + // Arrange + goldenATXID := types.ATXID{2, 3, 4} + atxHdlr := newTestHandler(t, goldenATXID) + + sig, err := signing.NewEdSigner() + require.NoError(t, err) + require.NoError(t, identities.SetMalicious(atxHdlr.cdb, sig.NodeID(), types.RandomBytes(10), time.Now())) + + coinbase := types.GenerateAddress([]byte("aaaa")) + + // Act & Assert + atx1 := newActivationTx( + t, + sig, + 0, + types.EmptyATXID, + types.EmptyATXID, + nil, + types.EpochID(2), + 0, + 100, + coinbase, + 100, + &types.NIPost{PostMetadata: &types.PostMetadata{}}, + withVrfNonce(7), + ) + atxHdlr.mbeacon.EXPECT().OnAtx(gomock.Any()) + atxHdlr.mtortoise.EXPECT().OnAtx(gomock.Any(), gomock.Any(), gomock.Any()) + proof, err := atxHdlr.processVerifiedATX(context.Background(), atx1) + require.NoError(t, err) + require.Nil(t, proof) +} + +func TestHandler_ProcessAtx_SamePubEpoch(t *testing.T) { + // Arrange + goldenATXID := types.ATXID{2, 3, 4} + atxHdlr := newTestHandler(t, goldenATXID) + + sig, err := signing.NewEdSigner() + require.NoError(t, err) + + coinbase := types.GenerateAddress([]byte("aaaa")) + + // Act & Assert + atx1 := newActivationTx( + t, + sig, + 0, + types.EmptyATXID, + types.EmptyATXID, + nil, + types.EpochID(2), + 0, + 100, + coinbase, + 100, + &types.NIPost{PostMetadata: &types.PostMetadata{}}, + withVrfNonce(7), + ) + atxHdlr.mbeacon.EXPECT().OnAtx(gomock.Any()) + atxHdlr.mtortoise.EXPECT().OnAtx(gomock.Any(), gomock.Any(), gomock.Any()) + proof, err := atxHdlr.processVerifiedATX(context.Background(), atx1) + require.NoError(t, err) + require.Nil(t, proof) // another atx for the same epoch is considered malicious atx2 := newActivationTx( @@ -958,7 +1051,7 @@ func TestHandler_ProcessAtx(t *testing.T) { atx1.ID(), atx1.ID(), nil, - types.LayerID(layersPerEpoch+1).GetEpoch(), + types.EpochID(2), 0, 100, coinbase, @@ -986,7 +1079,7 @@ func TestHandler_ProcessAtx(t *testing.T) { require.Equal(t, sig.NodeID(), nodeID) } -func TestHandler_ProcessAtx_OwnNotMalicious(t *testing.T) { +func TestHandler_ProcessAtx_SamePubEpoch_NoSelfIncrimination(t *testing.T) { // Arrange goldenATXID := types.ATXID{2, 3, 4} atxHdlr := newTestHandler(t, goldenATXID) @@ -1005,7 +1098,7 @@ func TestHandler_ProcessAtx_OwnNotMalicious(t *testing.T) { types.EmptyATXID, types.EmptyATXID, nil, - types.LayerID(layersPerEpoch).GetEpoch(), + types.EpochID(2), 0, 100, coinbase, @@ -1032,7 +1125,7 @@ func TestHandler_ProcessAtx_OwnNotMalicious(t *testing.T) { atx1.ID(), atx1.ID(), nil, - types.LayerID(layersPerEpoch+1).GetEpoch(), + types.EpochID(2), 0, 100, coinbase, @@ -1044,7 +1137,173 @@ func TestHandler_ProcessAtx_OwnNotMalicious(t *testing.T) { err, fmt.Sprintf("%s already published an ATX", sig.NodeID().ShortString()), ) + require.Nil(t, proof) // no proof against oneself +} + +func TestHandler_ProcessAtx_SamePrevATX(t *testing.T) { + // Arrange + goldenATXID := types.ATXID{2, 3, 4} + atxHdlr := newTestHandler(t, goldenATXID) + + sig, err := signing.NewEdSigner() + require.NoError(t, err) + + coinbase := types.GenerateAddress([]byte("aaaa")) + + // Act & Assert + prevATX := newActivationTx( + t, + sig, + 0, + types.EmptyATXID, + types.EmptyATXID, + nil, + types.EpochID(2), + 0, + 100, + coinbase, + 100, + &types.NIPost{PostMetadata: &types.PostMetadata{}}, + withVrfNonce(7), + ) + atxHdlr.mbeacon.EXPECT().OnAtx(gomock.Any()) + atxHdlr.mtortoise.EXPECT().OnAtx(gomock.Any(), gomock.Any(), gomock.Any()) + proof, err := atxHdlr.processVerifiedATX(context.Background(), prevATX) + require.NoError(t, err) + require.Nil(t, proof) + + // valid first non-initial ATX + atx1 := newActivationTx( + t, + sig, + 1, + prevATX.ID(), + prevATX.ID(), + nil, + types.EpochID(3), + 0, + 100, + coinbase, + 100, + &types.NIPost{PostMetadata: &types.PostMetadata{}}, + ) + atxHdlr.mbeacon.EXPECT().OnAtx(gomock.Any()) + atxHdlr.mtortoise.EXPECT().OnAtx(gomock.Any(), gomock.Any(), gomock.Any()) + proof, err = atxHdlr.processVerifiedATX(context.Background(), atx1) + require.NoError(t, err) require.Nil(t, proof) + + // second non-initial ATX references prevATX as prevATX + atx2 := newActivationTx( + t, + sig, + 2, + prevATX.ID(), + atx1.ID(), + nil, + types.EpochID(4), + 0, + 100, + coinbase, + 100, + &types.NIPost{PostMetadata: &types.PostMetadata{}}, + ) + atxHdlr.mbeacon.EXPECT().OnAtx(gomock.Any()) + atxHdlr.mtortoise.EXPECT().OnAtx(gomock.Any(), gomock.Any(), gomock.Any()) + atxHdlr.mtortoise.EXPECT().OnMalfeasance(gomock.Any()) + atxHdlr.mclock.EXPECT().CurrentLayer().Return(types.EpochID(4).FirstLayer()) + proof, err = atxHdlr.processVerifiedATX(context.Background(), atx2) + require.NoError(t, err) + proof.SetReceived(time.Time{}) + nodeID, err := malfeasance.Validate( + context.Background(), + atxHdlr.log, + atxHdlr.cdb, + atxHdlr.edVerifier, + nil, + &mwire.MalfeasanceGossip{ + MalfeasanceProof: *proof, + }, + ) + require.NoError(t, err) + require.Equal(t, sig.NodeID(), nodeID) +} + +func TestHandler_ProcessAtx_SamePrevATX_NoSelfIncrimination(t *testing.T) { + // Arrange + goldenATXID := types.ATXID{2, 3, 4} + atxHdlr := newTestHandler(t, goldenATXID) + + sig, err := signing.NewEdSigner() + require.NoError(t, err) + atxHdlr.Register(sig) + + coinbase := types.GenerateAddress([]byte("aaaa")) + + // Act & Assert + prevATX := newActivationTx( + t, + sig, + 0, + types.EmptyATXID, + types.EmptyATXID, + nil, + types.EpochID(2), + 0, + 100, + coinbase, + 100, + &types.NIPost{PostMetadata: &types.PostMetadata{}}, + withVrfNonce(7), + ) + atxHdlr.mbeacon.EXPECT().OnAtx(gomock.Any()) + atxHdlr.mtortoise.EXPECT().OnAtx(gomock.Any(), gomock.Any(), gomock.Any()) + proof, err := atxHdlr.processVerifiedATX(context.Background(), prevATX) + require.NoError(t, err) + require.Nil(t, proof) + + // valid first non-initial ATX + atx1 := newActivationTx( + t, + sig, + 1, + prevATX.ID(), + prevATX.ID(), + nil, + types.EpochID(3), + 0, + 100, + coinbase, + 100, + &types.NIPost{PostMetadata: &types.PostMetadata{}}, + ) + atxHdlr.mbeacon.EXPECT().OnAtx(gomock.Any()) + atxHdlr.mtortoise.EXPECT().OnAtx(gomock.Any(), gomock.Any(), gomock.Any()) + proof, err = atxHdlr.processVerifiedATX(context.Background(), atx1) + require.NoError(t, err) + require.Nil(t, proof) + + // second non-initial ATX references prevATX as prevATX + atx2 := newActivationTx( + t, + sig, + 2, + prevATX.ID(), + atx1.ID(), + nil, + types.EpochID(4), + 0, + 100, + coinbase, + 100, + &types.NIPost{PostMetadata: &types.PostMetadata{}}, + ) + proof, err = atxHdlr.processVerifiedATX(context.Background(), atx2) + require.ErrorContains(t, + err, + fmt.Sprintf("%s referenced incorrect previous ATX", sig.NodeID().ShortString()), + ) + require.Nil(t, proof) // no proof against oneself } func testHandler_PostMalfeasanceProofs(t *testing.T, synced bool) { @@ -1165,7 +1424,7 @@ func TestHandler_ProcessAtxStoresNewVRFNonce(t *testing.T) { types.EmptyATXID, types.EmptyATXID, nil, - types.LayerID(layersPerEpoch).GetEpoch(), + types.EpochID(2), 0, 100, coinbase, @@ -1192,7 +1451,7 @@ func TestHandler_ProcessAtxStoresNewVRFNonce(t *testing.T) { atx1.ID(), atx1.ID(), nil, - types.LayerID(2*layersPerEpoch).GetEpoch(), + types.EpochID(3), 0, 100, coinbase, diff --git a/activation/verify_state.go b/activation/verify_state.go new file mode 100644 index 00000000..3c9db64b --- /dev/null +++ b/activation/verify_state.go @@ -0,0 +1,88 @@ +package activation + +import ( + "context" + "fmt" + "time" + + "go.uber.org/zap" + + awire "github.com/spacemeshos/go-spacemesh/activation/wire" + "github.com/spacemeshos/go-spacemesh/codec" + "github.com/spacemeshos/go-spacemesh/log" + "github.com/spacemeshos/go-spacemesh/malfeasance/wire" + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/identities" +) + +func CheckPrevATXs(ctx context.Context, logger *zap.Logger, db sql.Executor) error { + collisions, err := atxs.PrevATXCollisions(db) + if err != nil { + return fmt.Errorf("get prev ATX collisions: %w", err) + } + + logger.Info("found ATX collisions", zap.Int("count", len(collisions))) + count := 0 + for _, collision := range collisions { + select { + case <-ctx.Done(): + // stop on context cancellation + return ctx.Err() + default: + } + + if collision.NodeID1 != collision.NodeID2 { + logger.Panic( + "unexpected collision", + log.ZShortStringer("NodeID1", collision.NodeID1), + log.ZShortStringer("NodeID2", collision.NodeID2), + log.ZShortStringer("ATX1", collision.ATX1), + log.ZShortStringer("ATX2", collision.ATX2), + ) + } + + malicious, err := identities.IsMalicious(db, collision.NodeID1) + if err != nil { + return fmt.Errorf("get malicious status: %w", err) + } + + if malicious { + // already malicious no need to generate proof + continue + } + + var blob sql.Blob + var atx1 awire.ActivationTxV1 + if err := atxs.LoadBlob(ctx, db, collision.ATX1.Bytes(), &blob); err != nil { + return fmt.Errorf("get blob %s: %w", collision.ATX1.ShortString(), err) + } + codec.MustDecode(blob.Bytes, &atx1) + + var atx2 awire.ActivationTxV1 + if err := atxs.LoadBlob(ctx, db, collision.ATX2.Bytes(), &blob); err != nil { + return fmt.Errorf("get blob %s: %w", collision.ATX2.ShortString(), err) + } + codec.MustDecode(blob.Bytes, &atx2) + + proof := &wire.MalfeasanceProof{ + Layer: atx1.Publish.FirstLayer(), + Proof: wire.Proof{ + Type: wire.InvalidPrevATX, + Data: &wire.InvalidPrevATXProof{ + Atx1: atx1, + Atx2: atx2, + }, + }, + } + + encodedProof := codec.MustEncode(proof) + if err := identities.SetMalicious(db, collision.NodeID1, encodedProof, time.Now()); err != nil { + return fmt.Errorf("add malfeasance proof: %w", err) + } + + count++ + } + logger.Info("created malfeasance proofs", zap.Int("count", count)) + return nil +} diff --git a/activation/verify_state_test.go b/activation/verify_state_test.go new file mode 100644 index 00000000..801cd3c6 --- /dev/null +++ b/activation/verify_state_test.go @@ -0,0 +1,91 @@ +package activation + +import ( + "context" + "math/rand/v2" + "testing" + + "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" + + "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/signing" + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/identities" +) + +func Test_CheckPrevATXs(t *testing.T) { + db := sql.InMemory() + logger := zaptest.NewLogger(t) + + // Arrange + sig, err := signing.NewEdSigner() + require.NoError(t, err) + + // create two ATXs with the same PrevATXID + prevATXID := types.RandomATXID() + goldenATXID := types.RandomATXID() + + atx1 := newActivationTx( + t, + sig, + 0, + prevATXID, + goldenATXID, + &goldenATXID, + types.EpochID(2), + 0, + 100, + types.GenerateAddress([]byte("aaaa")), + 100, + nil, + ) + require.NoError(t, atxs.Add(db, atx1)) + + atx2 := newActivationTx( + t, + sig, + 1, + prevATXID, + goldenATXID, + &goldenATXID, + types.EpochID(3), + 0, + 100, + types.GenerateAddress([]byte("aaaa")), + 100, + nil, + ) + require.NoError(t, atxs.Add(db, atx2)) + + // create 100 random ATXs that are not malicious + for i := 0; i < 100; i++ { + otherSig, err := signing.NewEdSigner() + require.NoError(t, err) + atx := newActivationTx( + t, + otherSig, + rand.Uint64(), + types.RandomATXID(), + types.RandomATXID(), + nil, + rand.N[types.EpochID](100), + 0, + 100, + types.GenerateAddress([]byte("aaaa")), + rand.Uint32(), + nil, + ) + require.NoError(t, atxs.Add(db, atx)) + } + + // Act + err = CheckPrevATXs(context.Background(), logger, db) + require.NoError(t, err) + + // Assert + malicious, err := identities.IsMalicious(db, sig.NodeID()) + require.NoError(t, err) + require.True(t, malicious) +} diff --git a/activation/wire/wire_v1.go b/activation/wire/wire_v1.go index a0947a83..7d57350f 100644 --- a/activation/wire/wire_v1.go +++ b/activation/wire/wire_v1.go @@ -15,6 +15,8 @@ type ActivationTxV1 struct { SmesherID types.NodeID Signature types.EdSignature + + id types.ATXID } // InnerActivationTxV1 is a set of all of an ATX's fields, except the signature. To generate the ATX signature, this @@ -92,6 +94,18 @@ type ATXMetadataV1 struct { MsgHash types.Hash32 } +func (atx *ActivationTxV1) ID() types.ATXID { + if atx.id == types.EmptyATXID { + atx.id = types.ATXID(atx.HashInnerBytes()) + } + return atx.id +} + +// TODO(mafa): this can be inlined. +func (atx *ActivationTxV1) Smesher() types.NodeID { + return atx.SmesherID +} + func (atx *ActivationTxV1) SignedBytes() []byte { data := codec.MustEncode(&ATXMetadataV1{ Publish: atx.Publish, diff --git a/bootstrap.Dockerfile b/bootstrap.Dockerfile index d5b4bb8f..25a4524d 100644 --- a/bootstrap.Dockerfile +++ b/bootstrap.Dockerfile @@ -6,7 +6,7 @@ COPY Makefile* . COPY go.mod . COPY go.sum . -RUN go mod download +RUN --mount=type=secret,id=mynetrc,dst=/root/.netrc go mod download # copy the rest of the source code COPY . . diff --git a/cmd/root.go b/cmd/root.go index aa0db3e7..4fcd2862 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -79,6 +79,9 @@ func AddFlags(flagSet *pflag.FlagSet, cfg *config.Config) (configPath *string) { flagSet.DurationVar(&cfg.DatabasePruneInterval, "db-prune-interval", cfg.DatabasePruneInterval, "configure interval for database pruning") + flagSet.BoolVar(&cfg.ScanMalfeasantATXs, "scan-malfeasant-atxs", cfg.ScanMalfeasantATXs, + "scan for malfeasant ATXs") + flagSet.BoolVar(&cfg.NoMainOverride, "no-main-override", cfg.NoMainOverride, "force 'nomain' builds to run on the mainnet") diff --git a/common/types/hashes.go b/common/types/hashes.go index c01c748d..5590541d 100644 --- a/common/types/hashes.go +++ b/common/types/hashes.go @@ -20,6 +20,9 @@ const ( var ( hash20T = reflect.TypeOf(Hash20{}) hash32T = reflect.TypeOf(Hash32{}) + + // EmptyHash32 is the zero hash. + EmptyHash32 = Hash32{} ) // Hash32 represents the 32-byte blake3 hash of arbitrary data. diff --git a/common/types/layer.go b/common/types/layer.go index d88dcb59..f760227f 100644 --- a/common/types/layer.go +++ b/common/types/layer.go @@ -18,7 +18,7 @@ var ( effectiveGenesis uint32 // EmptyLayerHash is the layer hash for an empty layer. - EmptyLayerHash = Hash32{} + EmptyLayerHash = EmptyHash32 ) // SetLayersPerEpoch sets global parameter of layers per epoch, all conversions from layer to epoch use this param. diff --git a/config/config.go b/config/config.go index ca176869..27cb0adf 100644 --- a/config/config.go +++ b/config/config.go @@ -124,6 +124,9 @@ type BaseConfig struct { PruneActivesetsFrom types.EpochID `mapstructure:"prune-activesets-from"` + // ScanMalfeasantATXs is a flag to enable scanning for malfeasant ATXs. + ScanMalfeasantATXs bool `mapstructure:"scan-malfeasant-atxs"` + NetworkHRP string `mapstructure:"network-hrp"` // MinerGoodAtxsPercent is a threshold to decide if tortoise activeset should be diff --git a/config/mainnet.go b/config/mainnet.go index b3ebe45f..fa326a8f 100644 --- a/config/mainnet.go +++ b/config/mainnet.go @@ -73,7 +73,8 @@ func MainnetConfig() Config { DatabaseConnections: 16, DatabasePruneInterval: 30 * time.Minute, DatabaseVacuumState: 15, - PruneActivesetsFrom: 12, // starting from epoch 13 activesets below 12 will be pruned + PruneActivesetsFrom: 12, // starting from epoch 13 activesets below 12 will be pruned + ScanMalfeasantATXs: false, // opt-in NetworkHRP: "sm", LayerDuration: 5 * time.Minute, diff --git a/events/events.go b/events/events.go index 7501bf73..30107a73 100644 --- a/events/events.go +++ b/events/events.go @@ -298,6 +298,8 @@ func ToMalfeasancePB(nodeID types.NodeID, mp *wire.MalfeasanceProof, includeProo kind = pb.MalfeasanceProof_MALFEASANCE_HARE case wire.InvalidPostIndex: kind = pb.MalfeasanceProof_MALFEASANCE_POST_INDEX + case wire.InvalidPrevATX: + kind = pb.MalfeasanceProof_MALFEASANCE_INCORRECT_PREV_ATX } result := &pb.MalfeasanceProof{ SmesherId: &pb.SmesherId{Id: nodeID.Bytes()}, diff --git a/go.mod b/go.mod index 9a76abd9..774cc38f 100644 --- a/go.mod +++ b/go.mod @@ -232,3 +232,5 @@ require ( sigs.k8s.io/structured-merge-diff/v4 v4.4.1 // indirect sigs.k8s.io/yaml v1.4.0 // indirect ) + +replace github.com/spacemeshos/api/release/go => github.com/spacemeshos/api-cve-fix/release/go v1.36.1-0.20240429173440-42be53a006d3 diff --git a/go.sum b/go.sum index c98cfbe8..83009309 100644 --- a/go.sum +++ b/go.sum @@ -555,8 +555,8 @@ github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:Udh github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod h1:HuIsMU8RRBOtsCgI77wP899iHVBQpCmg4ErYMZB+2IA= -github.com/spacemeshos/api/release/go v1.37.0 h1:bN6AhSMVSmAShGxUYKwFBfzY3U1XtHezpDjt20dHjBM= -github.com/spacemeshos/api/release/go v1.37.0/go.mod h1:Ed7SdL2YgqNg2SeShEAonW3GTPuuaGzsY5i4bgziCRo= +github.com/spacemeshos/api-cve-fix/release/go v1.36.1-0.20240429173440-42be53a006d3 h1:AXTfy9764T4zye7fk3V2X4K4xVuyrX0X/n8nXZ4YEMg= +github.com/spacemeshos/api-cve-fix/release/go v1.36.1-0.20240429173440-42be53a006d3/go.mod h1:Ed7SdL2YgqNg2SeShEAonW3GTPuuaGzsY5i4bgziCRo= github.com/spacemeshos/economics v0.1.3 h1:ACkq3mTebIky4Zwbs9SeSSRZrUCjU/Zk0wq9Z0BTh2A= github.com/spacemeshos/economics v0.1.3/go.mod h1:FH7u0FzTIm6Kpk+X5HOZDvpkgNYBKclmH86rVwYaDAo= github.com/spacemeshos/fixed v0.1.1 h1:N1y4SUpq1EV+IdJrWJwUCt1oBFzeru/VKVcBsvPc2Fk= diff --git a/malfeasance/handler.go b/malfeasance/handler.go index c48ec8e9..23421773 100644 --- a/malfeasance/handler.go +++ b/malfeasance/handler.go @@ -86,7 +86,7 @@ func (h *Handler) HandleSyncedMalfeasanceProof( nodeID, err := h.validateAndSave(ctx, &wire.MalfeasanceGossip{MalfeasanceProof: p}) if err == nil && types.Hash32(nodeID) != expHash { return fmt.Errorf( - "%w: malfesance proof want %s, got %s", + "%w: malfeasance proof want %s, got %s", errWrongHash, expHash.ShortString(), nodeID.ShortString(), @@ -187,6 +187,9 @@ func Validate( case wire.InvalidPostIndex: proof := p.MalfeasanceProof.Proof.Data.(*wire.InvalidPostIndexProof) // guaranteed to work by scale func nodeID, err = validateInvalidPostIndex(ctx, logger, cdb, edVerifier, postVerifier, proof) + case wire.InvalidPrevATX: + proof := p.MalfeasanceProof.Proof.Data.(*wire.InvalidPrevATXProof) // guaranteed to work by scale func + nodeID, err = validateInvalidPrevATX(ctx, cdb, edVerifier, proof) default: return nodeID, fmt.Errorf("%w: unknown malfeasance type", errInvalidProof) } @@ -211,6 +214,8 @@ func updateMetrics(tp wire.Proof) { numProofsBallot.Inc() case wire.InvalidPostIndex: numProofsPostIndex.Inc() + case wire.InvalidPrevATX: + numProofsPrevATX.Inc() } } @@ -377,7 +382,8 @@ func validateMultipleBallots( return types.EmptyNodeID, errors.New("invalid ballot malfeasance proof") } -func validateInvalidPostIndex(ctx context.Context, +func validateInvalidPostIndex( + ctx context.Context, logger log.Log, db sql.Executor, edVerifier SigVerifier, @@ -415,3 +421,30 @@ func validateInvalidPostIndex(ctx context.Context, numInvalidProofsPostIndex.Inc() return types.EmptyNodeID, errors.New("invalid post index malfeasance proof - POST is valid") } + +func validateInvalidPrevATX( + ctx context.Context, + db sql.Executor, + edVerifier SigVerifier, + proof *wire.InvalidPrevATXProof, +) (types.NodeID, error) { + atx1 := proof.Atx1 + if !edVerifier.Verify(signing.ATX, atx1.SmesherID, atx1.SignedBytes(), atx1.Signature) { + return types.EmptyNodeID, errors.New("atx1: invalid signature") + } + + atx2 := proof.Atx2 + if !edVerifier.Verify(signing.ATX, atx2.SmesherID, atx2.SignedBytes(), atx2.Signature) { + return types.EmptyNodeID, errors.New("atx2: invalid signature") + } + + if atx1.ID() == atx2.ID() { + numInvalidProofsPrevATX.Inc() + return types.EmptyNodeID, errors.New("invalid old prev ATX malfeasance proof: ATX IDs are the same") + } + if atx1.PrevATXID != atx2.PrevATXID { + numInvalidProofsPrevATX.Inc() + return types.EmptyNodeID, errors.New("invalid old prev ATX malfeasance proof: prev ATX IDs are different") + } + return atx1.SmesherID, nil +} diff --git a/malfeasance/handler_test.go b/malfeasance/handler_test.go index 636a3128..2cb8898d 100644 --- a/malfeasance/handler_test.go +++ b/malfeasance/handler_test.go @@ -1068,7 +1068,7 @@ func TestHandler_HandleSyncedMalfeasanceProof_wrongHash(t *testing.T) { require.True(t, malicious) } -func TestHandler_HandleMalfeasanceProof_InvalidPostIndex(t *testing.T) { +func TestHandler_HandleSyncedMalfeasanceProof_InvalidPostIndex(t *testing.T) { sig, err := signing.NewEdSigner() require.NoError(t, err) nodeIdH32 := types.Hash32(sig.NodeID()) @@ -1090,8 +1090,9 @@ func TestHandler_HandleMalfeasanceProof_InvalidPostIndex(t *testing.T) { t.Run("valid malfeasance proof", func(t *testing.T) { db := sql.InMemory() lg := logtest.New(t) - trt := malfeasance.NewMocktortoise(gomock.NewController(t)) - postVerifier := malfeasance.NewMockpostVerifier(gomock.NewController(t)) + ctrl := gomock.NewController(t) + trt := malfeasance.NewMocktortoise(ctrl) + postVerifier := malfeasance.NewMockpostVerifier(ctrl) h := malfeasance.NewHandler( datastore.NewCachedDB(db, lg), @@ -1128,8 +1129,9 @@ func TestHandler_HandleMalfeasanceProof_InvalidPostIndex(t *testing.T) { t.Run("invalid malfeasance proof (POST valid)", func(t *testing.T) { db := sql.InMemory() lg := logtest.New(t) - trt := malfeasance.NewMocktortoise(gomock.NewController(t)) - postVerifier := malfeasance.NewMockpostVerifier(gomock.NewController(t)) + ctrl := gomock.NewController(t) + trt := malfeasance.NewMocktortoise(ctrl) + postVerifier := malfeasance.NewMockpostVerifier(ctrl) h := malfeasance.NewHandler( datastore.NewCachedDB(db, lg), @@ -1164,8 +1166,9 @@ func TestHandler_HandleMalfeasanceProof_InvalidPostIndex(t *testing.T) { t.Run("invalid malfeasance proof (ATX signature invalid)", func(t *testing.T) { db := sql.InMemory() lg := logtest.New(t) - trt := malfeasance.NewMocktortoise(gomock.NewController(t)) - postVerifier := malfeasance.NewMockpostVerifier(gomock.NewController(t)) + ctrl := gomock.NewController(t) + trt := malfeasance.NewMocktortoise(ctrl) + postVerifier := malfeasance.NewMockpostVerifier(ctrl) h := malfeasance.NewHandler( datastore.NewCachedDB(db, lg), @@ -1200,3 +1203,215 @@ func TestHandler_HandleMalfeasanceProof_InvalidPostIndex(t *testing.T) { require.False(t, malicious) }) } + +func TestHandler_HandleSyncedMalfeasanceProof_InvalidPrevATX(t *testing.T) { + sig, err := signing.NewEdSigner() + require.NoError(t, err) + nodeIdH32 := types.Hash32(sig.NodeID()) + + prevATX := *types.NewActivationTx( + types.NIPostChallenge{ + PublishEpoch: types.EpochID(1), + CommitmentATX: &types.ATXID{1, 2, 3}, + }, + types.Address{}, + nil, + 1, + nil, + ) + require.NoError(t, activation.SignAndFinalizeAtx(sig, &prevATX)) + + atx1 := *types.NewActivationTx( + types.NIPostChallenge{ + PublishEpoch: types.EpochID(2), + PrevATXID: prevATX.ID(), + }, + types.Address{}, + nil, + 1, + nil, + ) + require.NoError(t, activation.SignAndFinalizeAtx(sig, &atx1)) + + atx2 := *types.NewActivationTx( + types.NIPostChallenge{ + PublishEpoch: types.EpochID(3), + PrevATXID: prevATX.ID(), + }, + types.Address{}, + nil, + 1, + nil, + ) + require.NoError(t, activation.SignAndFinalizeAtx(sig, &atx2)) + + t.Run("valid malfeasance proof", func(t *testing.T) { + db := sql.InMemory() + lg := logtest.New(t) + ctrl := gomock.NewController(t) + trt := malfeasance.NewMocktortoise(ctrl) + postVerifier := malfeasance.NewMockpostVerifier(ctrl) + + h := malfeasance.NewHandler( + datastore.NewCachedDB(db, lg), + lg, + "self", + []types.NodeID{types.RandomNodeID()}, + signing.NewEdVerifier(), + trt, + postVerifier, + ) + + proof := wire.MalfeasanceProof{ + Layer: types.LayerID(11), + Proof: wire.Proof{ + Type: wire.InvalidPrevATX, + Data: &wire.InvalidPrevATXProof{ + Atx1: *awire.ActivationTxToWireV1(&atx1), + Atx2: *awire.ActivationTxToWireV1(&atx2), + }, + }, + } + + trt.EXPECT().OnMalfeasance(sig.NodeID()) + err := h.HandleSyncedMalfeasanceProof(context.Background(), nodeIdH32, "peer", codec.MustEncode(&proof)) + require.NoError(t, err) + + malicious, err := identities.IsMalicious(db, sig.NodeID()) + require.NoError(t, err) + require.True(t, malicious) + }) + + t.Run("invalid malfeasance proof (same ATX)", func(t *testing.T) { + db := sql.InMemory() + lg := logtest.New(t) + ctrl := gomock.NewController(t) + trt := malfeasance.NewMocktortoise(ctrl) + postVerifier := malfeasance.NewMockpostVerifier(ctrl) + + h := malfeasance.NewHandler( + datastore.NewCachedDB(db, lg), + lg, + "self", + []types.NodeID{types.RandomNodeID()}, + signing.NewEdVerifier(), + trt, + postVerifier, + ) + + proof := wire.MalfeasanceProof{ + Layer: types.LayerID(11), + Proof: wire.Proof{ + Type: wire.InvalidPrevATX, + Data: &wire.InvalidPrevATXProof{ + Atx1: *awire.ActivationTxToWireV1(&atx1), + Atx2: *awire.ActivationTxToWireV1(&atx1), + }, + }, + } + + err := h.HandleSyncedMalfeasanceProof(context.Background(), nodeIdH32, "peer", codec.MustEncode(&proof)) + require.ErrorContains(t, err, "ATX IDs are the same") + + malicious, err := identities.IsMalicious(db, sig.NodeID()) + require.NoError(t, err) + require.False(t, malicious) + }) + + t.Run("invalid malfeasance proof (prev ATXs differ)", func(t *testing.T) { + db := sql.InMemory() + lg := logtest.New(t) + ctrl := gomock.NewController(t) + trt := malfeasance.NewMocktortoise(ctrl) + postVerifier := malfeasance.NewMockpostVerifier(ctrl) + + h := malfeasance.NewHandler( + datastore.NewCachedDB(db, lg), + lg, + "self", + []types.NodeID{types.RandomNodeID()}, + signing.NewEdVerifier(), + trt, + postVerifier, + ) + + atx3 := *types.NewActivationTx( + types.NIPostChallenge{ + PublishEpoch: types.EpochID(3), + PrevATXID: atx1.ID(), + }, + types.Address{}, + nil, + 1, + nil, + ) + require.NoError(t, activation.SignAndFinalizeAtx(sig, &atx3)) + + proof := wire.MalfeasanceProof{ + Layer: types.LayerID(11), + Proof: wire.Proof{ + Type: wire.InvalidPrevATX, + Data: &wire.InvalidPrevATXProof{ + Atx1: *awire.ActivationTxToWireV1(&atx1), + Atx2: *awire.ActivationTxToWireV1(&atx3), + }, + }, + } + + err := h.HandleSyncedMalfeasanceProof(context.Background(), nodeIdH32, "peer", codec.MustEncode(&proof)) + require.ErrorContains(t, err, "prev ATX IDs are different") + + malicious, err := identities.IsMalicious(db, sig.NodeID()) + require.NoError(t, err) + require.False(t, malicious) + }) + + t.Run("invalid malfeasance proof (ATX signature invalid)", func(t *testing.T) { + db := sql.InMemory() + lg := logtest.New(t) + ctrl := gomock.NewController(t) + trt := malfeasance.NewMocktortoise(ctrl) + postVerifier := malfeasance.NewMockpostVerifier(ctrl) + + h := malfeasance.NewHandler( + datastore.NewCachedDB(db, lg), + lg, + "self", + []types.NodeID{types.RandomNodeID()}, + signing.NewEdVerifier(), + trt, + postVerifier, + ) + + atx3 := *types.NewActivationTx( + types.NIPostChallenge{ + PublishEpoch: types.EpochID(3), + PrevATXID: atx1.ID(), + }, + types.Address{}, + nil, + 1, + nil, + ) + require.NoError(t, activation.SignAndFinalizeAtx(sig, &atx3)) + atx3.PrevATXID = prevATX.ID() // invalidate signature by changing content + + proof := wire.MalfeasanceProof{ + Layer: types.LayerID(11), + Proof: wire.Proof{ + Type: wire.InvalidPrevATX, + Data: &wire.InvalidPrevATXProof{ + Atx1: *awire.ActivationTxToWireV1(&atx1), + Atx2: *awire.ActivationTxToWireV1(&atx3), + }, + }, + } + + err := h.HandleSyncedMalfeasanceProof(context.Background(), nodeIdH32, "peer", codec.MustEncode(&proof)) + require.ErrorContains(t, err, "invalid signature") + + malicious, err := identities.IsMalicious(db, sig.NodeID()) + require.NoError(t, err) + require.False(t, malicious) + }) +} diff --git a/malfeasance/metrics.go b/malfeasance/metrics.go index 8e2eccfa..7592c33c 100644 --- a/malfeasance/metrics.go +++ b/malfeasance/metrics.go @@ -13,6 +13,7 @@ const ( multiBallots = "ballot" hareEquivocate = "hare_eq" invalidPostIndex = "invalid_post_index" + invalidPrevATX = "invalid_prev_atx" ) var ( @@ -29,6 +30,7 @@ var ( numProofsBallot = numProofs.WithLabelValues(multiBallots) numProofsHare = numProofs.WithLabelValues(hareEquivocate) numProofsPostIndex = numProofs.WithLabelValues(invalidPostIndex) + numProofsPrevATX = numProofs.WithLabelValues(invalidPrevATX) numInvalidProofs = metrics.NewCounter( "num_invalid_proofs", @@ -39,9 +41,10 @@ var ( }, ) + numMalformed = numInvalidProofs.WithLabelValues("mal") numInvalidProofsATX = numInvalidProofs.WithLabelValues(multiATXs) numInvalidProofsBallot = numInvalidProofs.WithLabelValues(multiBallots) numInvalidProofsHare = numInvalidProofs.WithLabelValues(hareEquivocate) numInvalidProofsPostIndex = numInvalidProofs.WithLabelValues(invalidPostIndex) - numMalformed = numInvalidProofs.WithLabelValues("mal") + numInvalidProofsPrevATX = numInvalidProofs.WithLabelValues(invalidPrevATX) ) diff --git a/malfeasance/wire/malfeasance.go b/malfeasance/wire/malfeasance.go index c50480cd..0f71c207 100644 --- a/malfeasance/wire/malfeasance.go +++ b/malfeasance/wire/malfeasance.go @@ -15,13 +15,14 @@ import ( "github.com/spacemeshos/go-spacemesh/log" ) -//go:generate scalegen -types MalfeasanceProof,MalfeasanceGossip,AtxProof,BallotProof,HareProof,AtxProofMsg,BallotProofMsg,HareProofMsg,HareMetadata,InvalidPostIndexProof +//go:generate scalegen -types MalfeasanceProof,MalfeasanceGossip,AtxProof,BallotProof,HareProof,AtxProofMsg,BallotProofMsg,HareProofMsg,HareMetadata,InvalidPostIndexProof,InvalidPrevATXProof const ( MultipleATXs byte = iota + 1 MultipleBallots HareEquivocation InvalidPostIndex + InvalidPrevATX ) type MalfeasanceProof struct { @@ -71,11 +72,19 @@ func (mp *MalfeasanceProof) MarshalLogObject(encoder log.ObjectEncoder) error { encoder.AddString("type", "invalid post index") p, ok := mp.Proof.Data.(*InvalidPostIndexProof) if ok { - atx := wire.ActivationTxFromWireV1(&p.Atx) - encoder.AddString("atx_id", atx.ID().String()) + encoder.AddString("atx_id", p.Atx.ID().String()) encoder.AddString("smesher", p.Atx.SmesherID.String()) encoder.AddUint32("invalid index", p.InvalidIdx) } + case InvalidPrevATX: + encoder.AddString("type", "invalid prev atx") + p, ok := mp.Proof.Data.(*InvalidPrevATXProof) + if ok { + encoder.AddString("atx1_id", p.Atx2.ID().String()) + encoder.AddString("atx2_id", p.Atx2.ID().String()) + encoder.AddString("smesher", p.Atx1.SmesherID.String()) + encoder.AddString("prev_atx", p.Atx1.PrevATXID.String()) + } default: encoder.AddString("type", "unknown") } @@ -153,6 +162,14 @@ func (e *Proof) DecodeScale(dec *scale.Decoder) (int, error) { } e.Data = &proof total += n + case InvalidPrevATX: + var proof InvalidPrevATXProof + n, err := proof.DecodeScale(dec) + if err != nil { + return total, err + } + e.Data = &proof + total += n default: return total, errors.New("unknown malfeasance proof type") } @@ -292,6 +309,13 @@ func (m *HareProofMsg) SignedBytes() []byte { return m.InnerMsg.ToBytes() } +// InvalidPrevAtxProof is a proof that a smesher published an ATX with an old previous ATX ID. +// The proof contains two ATXs that reference the same previous ATX. +type InvalidPrevATXProof struct { + Atx1 wire.ActivationTxV1 + Atx2 wire.ActivationTxV1 +} + func MalfeasanceInfo(smesher types.NodeID, mp *MalfeasanceProof) string { var b strings.Builder b.WriteString(fmt.Sprintf("generate layer: %v\n", mp.Layer)) @@ -359,15 +383,25 @@ func MalfeasanceInfo(smesher types.NodeID, mp *MalfeasanceProof) string { case InvalidPostIndex: p, ok := mp.Proof.Data.(*InvalidPostIndexProof) if ok { - atx := wire.ActivationTxFromWireV1(&p.Atx) b.WriteString( fmt.Sprintf( "cause: smesher published ATX %s with invalid post index %d in epoch %d\n", - atx.ID().ShortString(), + p.Atx.ID().ShortString(), p.InvalidIdx, p.Atx.Publish, )) } + case InvalidPrevATX: + p, ok := mp.Proof.Data.(*InvalidPrevATXProof) + if ok { + b.WriteString( + fmt.Sprintf( + "cause: smesher published ATX %s with invalid previous ATX %s in epoch %d\n", + p.Atx1.ID().ShortString(), + p.Atx2.ID().ShortString(), + p.Atx1.Publish, + )) + } } return b.String() } diff --git a/malfeasance/wire/malfeasance_scale.go b/malfeasance/wire/malfeasance_scale.go index 625292f0..3ec88a1a 100644 --- a/malfeasance/wire/malfeasance_scale.go +++ b/malfeasance/wire/malfeasance_scale.go @@ -386,3 +386,39 @@ func (t *InvalidPostIndexProof) DecodeScale(dec *scale.Decoder) (total int, err } return total, nil } + +func (t *InvalidPrevATXProof) EncodeScale(enc *scale.Encoder) (total int, err error) { + { + n, err := t.Atx1.EncodeScale(enc) + if err != nil { + return total, err + } + total += n + } + { + n, err := t.Atx2.EncodeScale(enc) + if err != nil { + return total, err + } + total += n + } + return total, nil +} + +func (t *InvalidPrevATXProof) DecodeScale(dec *scale.Decoder) (total int, err error) { + { + n, err := t.Atx1.DecodeScale(dec) + if err != nil { + return total, err + } + total += n + } + { + n, err := t.Atx2.DecodeScale(dec) + if err != nil { + return total, err + } + total += n + } + return total, nil +} diff --git a/malfeasance/wire/malfeasance_test.go b/malfeasance/wire/malfeasance_test.go index 70cc9426..8a478f94 100644 --- a/malfeasance/wire/malfeasance_test.go +++ b/malfeasance/wire/malfeasance_test.go @@ -8,9 +8,12 @@ import ( "github.com/spacemeshos/go-scale/tester" "github.com/stretchr/testify/require" + "github.com/spacemeshos/go-spacemesh/activation" + awire "github.com/spacemeshos/go-spacemesh/activation/wire" "github.com/spacemeshos/go-spacemesh/codec" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/malfeasance/wire" + "github.com/spacemeshos/go-spacemesh/signing" ) func TestMain(m *testing.M) { @@ -182,6 +185,87 @@ func Test_HareMetadata_Equivocation(t *testing.T) { require.False(t, hm1.Equivocation(&hm2)) } +func TestCodec_InvalidPostIndex(t *testing.T) { + lid := types.LayerID(11) + atx := types.NewActivationTx( + types.NIPostChallenge{PublishEpoch: lid.GetEpoch()}, + types.Address{1, 2, 3}, + nil, 10, nil, + ) + + proof := &wire.MalfeasanceProof{ + Layer: lid, + Proof: wire.Proof{ + Type: wire.InvalidPostIndex, + Data: &wire.InvalidPostIndexProof{ + Atx: *awire.ActivationTxToWireV1(atx), + InvalidIdx: 5, + }, + }, + } + encoded, err := codec.Encode(proof) + require.NoError(t, err) + + var decoded wire.MalfeasanceProof + require.NoError(t, codec.Decode(encoded, &decoded)) + require.Equal(t, *proof, decoded) +} + +func TestCodec_InvalidPrevATX(t *testing.T) { + lid := types.LayerID(45) + + sig, err := signing.NewEdSigner() + require.NoError(t, err) + + prev := types.NewActivationTx( + types.NIPostChallenge{ + PublishEpoch: lid.GetEpoch() - 2, + }, + types.Address{1, 2, 3}, + nil, 10, nil, + ) + require.NoError(t, activation.SignAndFinalizeAtx(sig, prev)) + + atx1 := types.NewActivationTx( + types.NIPostChallenge{ + PublishEpoch: lid.GetEpoch() - 1, + PrevATXID: prev.ID(), + }, + types.Address{1, 2, 3}, + nil, 10, nil, + ) + require.NoError(t, activation.SignAndFinalizeAtx(sig, atx1)) + + atx2 := types.NewActivationTx( + types.NIPostChallenge{ + PublishEpoch: lid.GetEpoch(), + PrevATXID: prev.ID(), + }, + types.Address{1, 2, 3}, + nil, 10, nil, + ) + require.NoError(t, activation.SignAndFinalizeAtx(sig, atx2)) + + proof := &wire.MalfeasanceProof{ + Layer: lid, + Proof: wire.Proof{ + Type: wire.InvalidPrevATX, + Data: &wire.InvalidPrevATXProof{ + Atx1: *awire.ActivationTxToWireV1(atx1), + Atx2: *awire.ActivationTxToWireV1(atx2), + }, + }, + } + encoded, err := codec.Encode(proof) + require.NoError(t, err) + + var decoded wire.MalfeasanceProof + require.NoError(t, codec.Decode(encoded, &decoded)) + // require.NoError(t, decoded.Proof.Data.(*wire.InvalidPrevATXProof).Atx1.Initialize()) + // require.NoError(t, decoded.Proof.Data.(*wire.InvalidPrevATXProof).Atx2.Initialize()) + require.Equal(t, *proof, decoded) +} + func FuzzProofConsistency(f *testing.F) { tester.FuzzConsistency[wire.Proof](f, func(p *wire.Proof, c fuzz.Continue) { switch c.Intn(3) { @@ -200,6 +284,16 @@ func FuzzProofConsistency(f *testing.F) { data := wire.HareProof{} c.Fuzz(&data) p.Data = &data + case 3: + p.Type = wire.InvalidPostIndex + data := wire.InvalidPostIndexProof{} + c.Fuzz(&data) + p.Data = &data + case 4: + p.Type = wire.InvalidPrevATX + data := wire.InvalidPrevATXProof{} + c.Fuzz(&data) + p.Data = &data } }) } diff --git a/node/node.go b/node/node.go index 356323c2..d143e661 100644 --- a/node/node.go +++ b/node/node.go @@ -36,7 +36,6 @@ import ( "google.golang.org/grpc/keepalive" "github.com/spacemeshos/go-spacemesh/activation" - "github.com/spacemeshos/go-spacemesh/activation/wire" "github.com/spacemeshos/go-spacemesh/api/grpcserver" "github.com/spacemeshos/go-spacemesh/api/grpcserver/v2alpha1" "github.com/spacemeshos/go-spacemesh/atxsdata" @@ -75,7 +74,6 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/activesets" "github.com/spacemeshos/go-spacemesh/sql/atxs" - "github.com/spacemeshos/go-spacemesh/sql/builder" "github.com/spacemeshos/go-spacemesh/sql/layers" "github.com/spacemeshos/go-spacemesh/sql/localsql" dbmetrics "github.com/spacemeshos/go-spacemesh/sql/metrics" @@ -1894,6 +1892,15 @@ func (app *App) setupDBs(ctx context.Context, lg log.Log) error { datastore.WithConfig(app.Config.Cache), ) + if app.Config.ScanMalfeasantATXs { + app.log.With().Info("checking DB for malicious ATXs") + start = time.Now() + if err := activation.CheckPrevATXs(ctx, app.log.Zap(), app.db); err != nil { + return fmt.Errorf("malicious ATX check: %w", err) + } + app.log.With().Info("malicious ATX check completed", log.Duration("duration", time.Since(start))) + } + migrations, err = sql.LocalMigrations() if err != nil { return fmt.Errorf("load local migrations: %w", err) @@ -1933,9 +1940,6 @@ func (app *App) Start(ctx context.Context) error { }) } - // uncomment to verify ATXs signatures - // app.verifyDB(ctx) - // app blocks until it receives a signal to exit // this signal may come from the node or from sig-abort (ctrl-c) select { @@ -1946,48 +1950,6 @@ func (app *App) Start(ctx context.Context) error { } } -// verifyDB performs a verification of ATX signatures in the database. -// -//lint:ignore U1000 This function is currently unused but is left here for future use. -func (app *App) verifyDB(ctx context.Context) { - app.eg.Go(func() error { - app.log.Info("checking ATX signatures") - count := 0 - - // check ATX signatures - atxs.IterateAtxsOps(app.cachedDB, builder.Operations{}, func(atx *types.VerifiedActivationTx) bool { - select { - case <-ctx.Done(): - // stop on context cancellation - return false - default: - } - - // verify atx signature - // TODO: use atx handler to verify signature - if !app.edVerifier.Verify( - signing.ATX, - atx.SmesherID, wire.ActivationTxToWireV1(atx.ActivationTx).SignedBytes(), - atx.Signature, - ) { - app.log.With().Error("ATX signature verification failed", - log.Stringer("atx_id", atx.ID()), - log.Stringer("smesher", atx.SmesherID), - ) - } - - count++ - if count%1000 == 0 { - app.log.With().Info("verifying ATX signatures", log.Int("count", count)) - } - return true - }) - - app.log.With().Info("ATX signatures verified", log.Int("count", count)) - return nil - }) -} - func (app *App) startSynchronous(ctx context.Context) (err error) { // notify anyone who might be listening that the app has finished starting. // this can be used by, e.g., app tests. diff --git a/sql/atxs/atxs.go b/sql/atxs/atxs.go index d4650193..5f747afa 100644 --- a/sql/atxs/atxs.go +++ b/sql/atxs/atxs.go @@ -205,11 +205,12 @@ func GetLastIDByNodeID(db sql.Executor, nodeID types.NodeID) (id types.ATXID, er return id, err } -// GetIDByEpochAndNodeID gets an ATX ID for a given epoch and node ID. -func GetIDByEpochAndNodeID(db sql.Executor, epoch types.EpochID, nodeID types.NodeID) (id types.ATXID, err error) { +// PrevIDByNodeID returns the previous ATX ID for a given node ID and public epoch. +// It returns the newest ATX ID that was published before the given public epoch. +func PrevIDByNodeID(db sql.Executor, nodeID types.NodeID, pubEpoch types.EpochID) (id types.ATXID, err error) { enc := func(stmt *sql.Statement) { - stmt.BindInt64(1, int64(epoch)) - stmt.BindBytes(2, nodeID.Bytes()) + stmt.BindBytes(1, nodeID.Bytes()) + stmt.BindInt64(2, int64(pubEpoch)) } dec := func(stmt *sql.Statement) bool { stmt.ColumnBytes(0, id[:]) @@ -218,60 +219,38 @@ func GetIDByEpochAndNodeID(db sql.Executor, epoch types.EpochID, nodeID types.No if rows, err := db.Exec(` select id from atxs - where epoch = ?1 and pubkey = ?2 + where pubkey = ?1 and epoch < ?2 + order by epoch desc limit 1;`, enc, dec); err != nil { - return types.ATXID{}, fmt.Errorf("exec nodeID %v: %w", nodeID, err) + return types.EmptyATXID, fmt.Errorf("exec nodeID %v, epoch %d: %w", nodeID, pubEpoch, err) } else if rows == 0 { - return types.ATXID{}, fmt.Errorf("exec nodeID %s: %w", nodeID, sql.ErrNotFound) + return types.EmptyATXID, fmt.Errorf("exec nodeID %s, epoch %d: %w", nodeID, pubEpoch, sql.ErrNotFound) } return id, err } -// IterateIDsByEpoch invokes the specified callback for each ATX ID in a given epoch. -// It stops if the callback returns an error. -func IterateIDsByEpoch( - db sql.Executor, - epoch types.EpochID, - callback func(total int, id types.ATXID) error, -) error { - if sql.IsCached(db) { - // If the slices are cached, let's not do more SELECTs - ids, err := GetIDsByEpoch(context.Background(), db, epoch) - if err != nil { - return err - } - for _, id := range ids { - if err := callback(len(ids), id); err != nil { - return err - } - } - return nil - } - - var callbackErr error +// GetIDByEpochAndNodeID gets an ATX ID for a given epoch and node ID. +func GetIDByEpochAndNodeID(db sql.Executor, epoch types.EpochID, nodeID types.NodeID) (id types.ATXID, err error) { enc := func(stmt *sql.Statement) { stmt.BindInt64(1, int64(epoch)) + stmt.BindBytes(2, nodeID.Bytes()) } dec := func(stmt *sql.Statement) bool { - var id types.ATXID - total := stmt.ColumnInt(0) - stmt.ColumnBytes(1, id[:]) - if callbackErr = callback(total, id); callbackErr != nil { - return false - } + stmt.ColumnBytes(0, id[:]) return true } - // Get total count in the same select statement to avoid the need for transaction - if _, err := db.Exec( - "select (select count(*) from atxs where epoch = ?1) as total, id from atxs where epoch = ?1;", - enc, dec, - ); err != nil { - return fmt.Errorf("exec epoch %v: %w", epoch, err) + if rows, err := db.Exec(` + select id from atxs + where epoch = ?1 and pubkey = ?2 + limit 1;`, enc, dec); err != nil { + return types.ATXID{}, fmt.Errorf("exec nodeID %v: %w", nodeID, err) + } else if rows == 0 { + return types.ATXID{}, fmt.Errorf("exec nodeID %s: %w", nodeID, sql.ErrNotFound) } - return callbackErr + return id, err } // GetIDsByEpoch gets ATX IDs for a given epoch. @@ -401,7 +380,7 @@ func AddGettingNonce(db sql.Executor, atx *types.VerifiedActivationTx) (*types.V if err == nil { err = add(db, atx, &nonce) if err != nil { - return nil, err + return &nonce, err } else { return &nonce, nil } @@ -809,3 +788,43 @@ func PoetProofRef(ctx context.Context, db sql.Executor, id types.ATXID) (types.P return types.PoetProofRef(atx.NIPost.PostMetadata.Challenge), nil } + +type PrevATXCollision struct { + NodeID1 types.NodeID + ATX1 types.ATXID + + NodeID2 types.NodeID + ATX2 types.ATXID +} + +func PrevATXCollisions(db sql.Executor) ([]PrevATXCollision, error) { + var result []PrevATXCollision + + dec := func(stmt *sql.Statement) bool { + var nodeID1, nodeID2 types.NodeID + stmt.ColumnBytes(0, nodeID1[:]) + stmt.ColumnBytes(1, nodeID2[:]) + + var id1, id2 types.ATXID + stmt.ColumnBytes(2, id1[:]) + stmt.ColumnBytes(3, id2[:]) + + result = append(result, PrevATXCollision{ + NodeID1: nodeID1, + ATX1: id1, + + NodeID2: nodeID2, + ATX2: id2, + }) + return true + } + if _, err := db.Exec(` + SELECT t1.pubkey, t2.pubkey, t1.id, t2.id + FROM atxs t1 + INNER JOIN atxs t2 ON t1.prev_id = t2.prev_id + WHERE t1.id < t2.id;`, nil, dec); err != nil { + return nil, fmt.Errorf("error getting ATXs with same prevATX: %w", err) + } + + return result, nil +} diff --git a/sql/atxs/atxs_test.go b/sql/atxs/atxs_test.go index a3f35cf9..302701a9 100644 --- a/sql/atxs/atxs_test.go +++ b/sql/atxs/atxs_test.go @@ -498,35 +498,6 @@ func TestGetIDsByEpochCached(t *testing.T) { require.Equal(t, 16, db.QueryCount()) // not incremented after Add } -func TestForIDsByEpochEarlyStop(t *testing.T) { - db := sql.InMemory() - - e1 := types.EpochID(1) - m := make(map[types.ATXID]struct{}) - for i := 0; i < 4; i++ { - sig, err := signing.NewEdSigner() - require.NoError(t, err) - atx, err := newAtx(sig, withPublishEpoch(e1)) - require.NoError(t, err) - require.NoError(t, atxs.Add(db, atx)) - m[atx.ID()] = struct{}{} - } - - n := 0 - err := atxs.IterateIDsByEpoch(db, e1, func(total int, id types.ATXID) error { - require.Equal(t, 4, total) - delete(m, id) - n++ - if n >= 2 { - return errors.New("test error") - } - return nil - }) - require.ErrorContains(t, err, "test error") - require.Equal(t, 2, n) - require.Len(t, m, 2) -} - func TestVRFNonce(t *testing.T) { // Arrange db := sql.InMemory() @@ -997,3 +968,57 @@ func TestLatest(t *testing.T) { }) } } + +func Test_PrevATXCollisions(t *testing.T) { + db := sql.InMemory() + sig, err := signing.NewEdSigner() + require.NoError(t, err) + + // create two ATXs with the same PrevATXID + prevATXID := types.RandomATXID() + + atx1, err := newAtx(sig, withPublishEpoch(1), withPrevATXID(prevATXID)) + require.NoError(t, err) + atx2, err := newAtx(sig, withPublishEpoch(2), withPrevATXID(prevATXID)) + require.NoError(t, err) + + require.NoError(t, atxs.Add(db, atx1)) + require.NoError(t, atxs.Add(db, atx2)) + + // verify that the ATXs were added + got1, err := atxs.Get(db, atx1.ID()) + require.NoError(t, err) + require.Equal(t, atx1, got1) + + got2, err := atxs.Get(db, atx2.ID()) + require.NoError(t, err) + require.Equal(t, atx2, got2) + + // add 10 valid ATXs by 10 other smeshers + atxMap := make(map[types.NodeID][]*types.VerifiedActivationTx) + for i := 2; i < 12; i++ { + otherSig, err := signing.NewEdSigner() + require.NoError(t, err) + + if len(atxMap[otherSig.NodeID()]) == 0 { + atx, err := newAtx(otherSig, withPublishEpoch(types.EpochID(i))) + require.NoError(t, err) + require.NoError(t, atxs.Add(db, atx)) + } else { + atx, err := newAtx(otherSig, withPublishEpoch(types.EpochID(i)), + withPrevATXID(atxMap[otherSig.NodeID()][len(atxMap[otherSig.NodeID()])-1].ID()), + ) + require.NoError(t, err) + require.NoError(t, atxs.Add(db, atx)) + } + } + + // get the collisions + got, err := atxs.PrevATXCollisions(db) + require.NoError(t, err) + require.Len(t, got, 1) + + require.Equal(t, sig.NodeID(), got[0].NodeID1) + require.Equal(t, sig.NodeID(), got[0].NodeID2) + require.ElementsMatch(t, []types.ATXID{atx1.ID(), atx2.ID()}, []types.ATXID{got[0].ATX1, got[0].ATX2}) +} diff --git a/sql/identities/identities.go b/sql/identities/identities.go index f8d523d0..48ef4b36 100644 --- a/sql/identities/identities.go +++ b/sql/identities/identities.go @@ -107,7 +107,7 @@ func IterateMalicious( return callbackErr } -// GetMalicious retrives malicious node IDs from the database. +// GetMalicious retrieves malicious node IDs from the database. func GetMalicious(db sql.Executor) (nids []types.NodeID, err error) { if err = IterateMalicious(db, func(total int, nid types.NodeID) error { if nids == nil { diff --git a/sql/migrations/state/0017_atxs_prev_id_nonce_placeholder.sql b/sql/migrations/state/0017_atxs_prev_id_nonce_placeholder.sql index e69de29b..7e4a3435 100644 --- a/sql/migrations/state/0017_atxs_prev_id_nonce_placeholder.sql +++ b/sql/migrations/state/0017_atxs_prev_id_nonce_placeholder.sql @@ -0,0 +1 @@ +-- Migration is done entirely in code diff --git a/sql/migrations/state_0017_migration_test.go b/sql/migrations/state_0017_migration_test.go index a5dbcbbe..9c71537c 100644 --- a/sql/migrations/state_0017_migration_test.go +++ b/sql/migrations/state_0017_migration_test.go @@ -2,6 +2,8 @@ package migrations import ( "context" + "path/filepath" + "strings" "testing" "time" @@ -60,6 +62,40 @@ func addAtx( return vAtx.ID() } +func Test_0017Migration_CompatibleSQL(t *testing.T) { + file := filepath.Join(t.TempDir(), "test1.db") + db, err := sql.Open("file:"+file, + sql.WithMigration(New0017Migration(zaptest.NewLogger(t))), + ) + require.NoError(t, err) + + var sqls1 []string + _, err = db.Exec("SELECT sql FROM sqlite_schema;", nil, func(stmt *sql.Statement) bool { + sql := stmt.ColumnText(0) + sql = strings.Join(strings.Fields(sql), " ") // remove whitespace + sqls1 = append(sqls1, sql) + return true + }) + require.NoError(t, err) + require.NoError(t, db.Close()) + + file = filepath.Join(t.TempDir(), "test2.db") + db, err = sql.Open("file:" + file) + require.NoError(t, err) + + var sqls2 []string + _, err = db.Exec("SELECT sql FROM sqlite_schema;", nil, func(stmt *sql.Statement) bool { + sql := stmt.ColumnText(0) + sql = strings.Join(strings.Fields(sql), " ") // remove whitespace + sqls2 = append(sqls2, sql) + return true + }) + require.NoError(t, err) + require.NoError(t, db.Close()) + + require.Equal(t, sqls1, sqls2) +} + func Test0017Migration(t *testing.T) { for i := 0; i < 10; i++ { db := sql.InMemory()