diff --git a/bitfield.go b/bitfield.go index d6f12b0..2f77646 100644 --- a/bitfield.go +++ b/bitfield.go @@ -3,28 +3,33 @@ package bitfield // NOTE: Don't bother replacing the divisions/modulo with shifts/ands, go is smart. import ( + "fmt" "math/bits" ) // NewBitfield creates a new fixed-sized Bitfield (allocated up-front). -// -// Panics if size is not a multiple of 8. -func NewBitfield(size int) Bitfield { +func NewBitfield(size int) (Bitfield, error) { + if size < 0 { + return nil, fmt.Errorf("bitfield size must be positive; got %d", size) + } if size%8 != 0 { - panic("Bitfield size must be a multiple of 8") + return nil, fmt.Errorf("bitfield size must be a multiple of 8; got %d", size) } - return make([]byte, size/8) + return make([]byte, size/8), nil } // FromBytes constructs a new bitfield from a serialized bitfield. -func FromBytes(size int, bits []byte) Bitfield { - bf := NewBitfield(size) +func FromBytes(size int, bits []byte) (Bitfield, error) { + bf, err := NewBitfield(size) + if err != nil { + return nil, err + } start := len(bf) - len(bits) if start < 0 { - panic("bitfield too small") + return nil, fmt.Errorf("bitfield too small: got %d; need %d", size, len(bits)*8) } copy(bf[start:], bits) - return bf + return bf, nil } func (bf Bitfield) offset(i int) (uint, uint8) { diff --git a/bitfield_test.go b/bitfield_test.go index 5f51382..0ed7f91 100644 --- a/bitfield_test.go +++ b/bitfield_test.go @@ -9,7 +9,8 @@ import ( ) func TestExhaustive24(t *testing.T) { - bf := NewBitfield(24) + bf, err := NewBitfield(24) + assertNoError(t, err) max := 1 << 24 bint := new(big.Int) @@ -58,7 +59,8 @@ func TestExhaustive24(t *testing.T) { } func TestBitfield(t *testing.T) { - bf := NewBitfield(128) + bf, err := NewBitfield(128) + assertNoError(t, err) if bf.OnesBefore(20) != 0 { t.Fatal("expected no bits set") } @@ -91,10 +93,20 @@ func TestBitfield(t *testing.T) { } } +func TestBadSizeFails(t *testing.T) { + for _, size := range [...]int{-8, 2, 1337, -3} { + _, err := NewBitfield(size) + if err == nil { + t.Fatalf("missing error for %d sized bitfield", size) + } + } +} + var benchmarkSize = 512 func BenchmarkBitfield(t *testing.B) { - bf := NewBitfield(benchmarkSize) + bf, err := NewBitfield(benchmarkSize) + assertNoError(t, err) t.ResetTimer() for i := 0; i < t.N; i++ { if bf.Bit(i % benchmarkSize) { @@ -123,13 +135,14 @@ func BenchmarkBitfield(t *testing.B) { } } -func BenchmarkOnes(t *testing.B) { - bf := NewBitfield(benchmarkSize) - t.ResetTimer() - for i := 0; i < t.N; i++ { +func BenchmarkOnes(b *testing.B) { + bf, err := NewBitfield(benchmarkSize) + assertNoError(b, err) + b.ResetTimer() + for i := 0; i < b.N; i++ { for j := 0; j*4 < benchmarkSize; j++ { if bf.Ones() != j { - t.Fatal("bad", i) + b.Fatal("bad", i) } bf.SetBit(j * 4) } @@ -139,14 +152,16 @@ func BenchmarkOnes(t *testing.B) { } } -func BenchmarkBytes(t *testing.B) { - bfa := NewBitfield(211) - bfb := NewBitfield(211) - for j := 0; j*4 < 211; j++ { +func BenchmarkBytes(b *testing.B) { + bfa, err := NewBitfield(216) + assertNoError(b, err) + bfb, err := NewBitfield(216) + assertNoError(b, err) + for j := 0; j*4 < 216; j++ { bfa.SetBit(j * 4) } - t.ResetTimer() - for i := 0; i < t.N; i++ { + b.ResetTimer() + for i := 0; i < b.N; i++ { bfb.SetBytes(bfa.Bytes()) } } @@ -180,3 +195,18 @@ func BenchmarkBigInt(t *testing.B) { } } } + +func FuzzFromBytes(f *testing.F) { + f.Fuzz(func(_ *testing.T, size int, bytes []byte) { + if size > 1<<20 { // We relly on consumers for limit checks, hopefully they understand that a New... factory allocates memory. + return + } + FromBytes(size, bytes) + }) +} + +func assertNoError(t testing.TB, e error) { + if e != nil { + t.Fatal(e) + } +} diff --git a/version.json b/version.json index 77dfc03..a905d1a 100644 --- a/version.json +++ b/version.json @@ -1,3 +1,3 @@ { - "version": "v1.0.0" + "version": "v1.1.0" }