Skip to content

Commit

Permalink
Implement sequence tracking
Browse files Browse the repository at this point in the history
Based on the MIT Kerberos implementation. It tracks the 64 previously seen
sequence numbers for the purposes of out of order delivery and replay
protection.

Fixes jcmturner#419
  • Loading branch information
bodgit committed Dec 22, 2020
1 parent 663478b commit 95c20ce
Show file tree
Hide file tree
Showing 2 changed files with 206 additions and 0 deletions.
91 changes: 91 additions & 0 deletions v8/gssapi/sequence.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package gssapi

import (
"errors"
"math"
"sync"
)

var (
errDuplicateToken = errors.New("duplicate per-message token detected")
errOldToken = errors.New("timed-out per-message token detected")
errUnseqToken = errors.New("reordered (early) per-message token detected")
errGapToken = errors.New("skipped predecessor token(s) detected")
)

// SequenceState tracks previously seen sequence numbers for message replay
// and/or sequence protection
type SequenceState struct {
m sync.Mutex
doReplay bool
doSequence bool
base uint64
next uint64
receiveMask uint64
sequenceMask uint64
}

// NewSequenceState returns a new SequenceState seeded with sequenceNumber
// with doReplay and doSequence controlling replay and sequence protection
// respectively and wide controlling whether sequence numbers are expected to
// wrap at a 32- or 64-bit boundary.
func NewSequenceState(sequenceNumber uint64, doReplay, doSequence, wide bool) *SequenceState {
ss := &SequenceState{
doReplay: doReplay,
doSequence: doSequence,
base: sequenceNumber,
}
if wide {
ss.sequenceMask = math.MaxUint64
} else {
ss.sequenceMask = math.MaxUint32
}
return ss
}

// Check the next sequence number. Sequence protection requires the sequence
// number to increase sequentially with no duplicates or out of order delivery.
// Replay protection relaxes these restrictions to permit limited out of order
// delivery.
func (ss *SequenceState) Check(sequenceNumber uint64) error {
if !ss.doReplay && !ss.doSequence {
return nil
}

ss.m.Lock()
defer ss.m.Unlock()

relativeSequenceNumber := (sequenceNumber - ss.base) & ss.sequenceMask

if relativeSequenceNumber >= ss.next {
offset := relativeSequenceNumber - ss.next
ss.receiveMask = ss.receiveMask<<(offset+1) | 1
ss.next = (relativeSequenceNumber + 1) & ss.sequenceMask

if offset > 0 && ss.doSequence {
return errGapToken
}

return nil
}

offset := ss.next - relativeSequenceNumber

if offset > 64 {
if ss.doSequence {
return errUnseqToken
}
return errOldToken
}

bit := uint64(1) << (offset - 1)
if ss.doReplay && ss.receiveMask&bit != 0 {
return errDuplicateToken
}
ss.receiveMask |= bit
if ss.doSequence {
return errUnseqToken
}

return nil
}
115 changes: 115 additions & 0 deletions v8/gssapi/sequence_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package gssapi

import (
"testing"

"github.com/stretchr/testify/assert"
)

func makeRange(min, max uint64) []uint64 {
a := make([]uint64, max-min+1)
for i := range a {
a[i] = min + uint64(i)
}
return a
}

func TestSequenceState(t *testing.T) {
tables := map[string]struct {
base uint64
doReplay bool
doSequence bool
wide bool
sequence []uint64
err error
}{
"noop": {
0,
false,
false,
false,
makeRange(0, 64),
nil,
},
"ok": {
0,
true,
true,
true,
makeRange(0, 64),
nil,
},
"replay skip": {
0,
true,
false,
true,
append(makeRange(0, 64), 66),
nil,
},
"sequence skip": {
0,
false,
true,
true,
append(makeRange(0, 64), 66),
errGapToken,
},
"replay too old": {
0,
true,
false,
true,
append(makeRange(0, 64), 0),
errOldToken,
},
"sequence too old": {
0,
false,
true,
true,
append(makeRange(0, 64), 0),
errUnseqToken,
},
"replay duplicate": {
0,
true,
false,
true,
append(makeRange(0, 64), 64),
errDuplicateToken,
},
"sequence duplicate": {
0,
false,
true,
true,
append(makeRange(0, 64), 64),
errUnseqToken,
},
"replay out of order": {
0,
true,
false,
true,
append(makeRange(0, 64), 66, 65),
nil,
},
}

for name, table := range tables {
t.Run(name, func(t *testing.T) {
ss := NewSequenceState(table.base, table.doReplay, table.doSequence, table.wide)

var err error
for _, next := range table.sequence {
err = ss.Check(next)
if err != nil {
break
}
}

assert.Equal(t, table.err, err)
})
}
}

0 comments on commit 95c20ce

Please sign in to comment.