Use go slices #49

Open
wants to merge 3 commits into master
Changes from all commits
30 changes: 30 additions & 0 deletions bufferpool.go
@@ -0,0 +1,30 @@
package gozstd

import (
"bytes"
"sync"
)

var compInBufPool = sync.Pool{
New: func() interface{} {
return bytes.NewBuffer(make([]byte, 0, cstreamInBufSize))
},
}

var compOutBufPool = sync.Pool{
New: func() interface{} {
return bytes.NewBuffer(make([]byte, 0, cstreamOutBufSize))
},
}

var decInBufPool = sync.Pool{
New: func() interface{} {
return bytes.NewBuffer(make([]byte, 0, dstreamInBufSize))
},
}

var decOutBufPool = sync.Pool{
New: func() interface{} {
return bytes.NewBuffer(make([]byte, 0, dstreamOutBufSize))
},
}
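
Not part of the diff: the pools above are presumably drained by the stream reader/writer code elsewhere in the package, which this PR does not show. A minimal sketch of the usual Get/Reset/Put pattern, with a hypothetical helper name:

```go
// withCompInBuf is a hypothetical illustration of how such a pool is
// typically used; the real call sites are not part of this diff.
func withCompInBuf(f func(buf *bytes.Buffer)) {
	buf := compInBufPool.Get().(*bytes.Buffer)
	buf.Reset() // drop data left over from the previous user
	f(buf)
	compInBufPool.Put(buf) // return the buffer so later callers can reuse it
}
```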
100 changes: 62 additions & 38 deletions gozstd.go
@@ -7,29 +7,27 @@ package gozstd
#include "zstd.h"
#include "zstd_errors.h"

-#include <stdint.h> // for uintptr_t

// The following *_wrapper functions allow avoiding memory allocations
// during calls from Go.
// See https://github.com/golang/go/issues/24450 .

-static size_t ZSTD_compressCCtx_wrapper(uintptr_t ctx, uintptr_t dst, size_t dstCapacity, uintptr_t src, size_t srcSize, int compressionLevel) {
-return ZSTD_compressCCtx((ZSTD_CCtx*)ctx, (void*)dst, dstCapacity, (const void*)src, srcSize, compressionLevel);
+static size_t ZSTD_compressCCtx_wrapper(void *ctx, void *dst, size_t dstCapacity, const void *src, size_t srcSize, int compressionLevel) {
+return ZSTD_compressCCtx((ZSTD_CCtx*)ctx, dst, dstCapacity, src, srcSize, compressionLevel);
}

-static size_t ZSTD_compress_usingCDict_wrapper(uintptr_t ctx, uintptr_t dst, size_t dstCapacity, uintptr_t src, size_t srcSize, uintptr_t cdict) {
+static size_t ZSTD_compress_usingCDict_wrapper(void *ctx, void *dst, size_t dstCapacity, void *src, size_t srcSize, void *cdict) {
return ZSTD_compress_usingCDict((ZSTD_CCtx*)ctx, (void*)dst, dstCapacity, (const void*)src, srcSize, (const ZSTD_CDict*)cdict);
}

-static size_t ZSTD_decompressDCtx_wrapper(uintptr_t ctx, uintptr_t dst, size_t dstCapacity, uintptr_t src, size_t srcSize) {
+static size_t ZSTD_decompressDCtx_wrapper(void *ctx, void *dst, size_t dstCapacity, void *src, size_t srcSize) {
return ZSTD_decompressDCtx((ZSTD_DCtx*)ctx, (void*)dst, dstCapacity, (const void*)src, srcSize);
}

-static size_t ZSTD_decompress_usingDDict_wrapper(uintptr_t ctx, uintptr_t dst, size_t dstCapacity, uintptr_t src, size_t srcSize, uintptr_t ddict) {
+static size_t ZSTD_decompress_usingDDict_wrapper(void *ctx, void *dst, size_t dstCapacity, void *src, size_t srcSize, void *ddict) {
return ZSTD_decompress_usingDDict((ZSTD_DCtx*)ctx, (void*)dst, dstCapacity, (const void*)src, srcSize, (const ZSTD_DDict*)ddict);
}

-static unsigned long long ZSTD_getFrameContentSize_wrapper(uintptr_t src, size_t srcSize) {
+static unsigned long long ZSTD_getFrameContentSize_wrapper(void *src, size_t srcSize) {
return ZSTD_getFrameContentSize((const void*)src, srcSize);
}
*/
@@ -38,6 +36,7 @@ import "C"
import (
"fmt"
"io"
"reflect"
"runtime"
"sync"
"unsafe"
@@ -46,6 +45,8 @@ import (
// DefaultCompressionLevel is the default compression level.
const DefaultCompressionLevel = 3 // Obtained from ZSTD_CLEVEL_DEFAULT.

const maxFrameContentSize = 256 << 20 // 256 MB

// Compress appends compressed src to dst and returns the result.
func Compress(dst, src []byte) []byte {
return compressDictLevel(dst, src, nil, DefaultCompressionLevel)
@@ -146,36 +147,53 @@ func compress(cctx, cctxDict *cctxWrapper, dst, src []byte, cd *CDict, compressi
return dst
}

// noescape hides a pointer from escape analysis. It is the identity function
// but escape analysis doesn't think the output depends on the input.
// noescape is inlined and currently compiles down to zero instructions.
// This is copied from Go's strings.Builder and allows us to use
// stack-allocated slices.
//go:nosplit
//go:nocheckptr
func noescape(p unsafe.Pointer) unsafe.Pointer {
x := uintptr(p)
return unsafe.Pointer(x ^ 0)
}

func compressInternal(cctx, cctxDict *cctxWrapper, dst, src []byte, cd *CDict, compressionLevel int, mustSucceed bool) C.size_t {
// using noescape will allow this to work with stack-allocated slices
dstHdr := (*reflect.SliceHeader)(noescape(unsafe.Pointer(&dst)))
srcHdr := (*reflect.SliceHeader)(noescape(unsafe.Pointer(&src)))

if cd != nil {
result := C.ZSTD_compress_usingCDict_wrapper(
-C.uintptr_t(uintptr(unsafe.Pointer(cctxDict.cctx))),
-C.uintptr_t(uintptr(unsafe.Pointer(&dst[0]))),
+unsafe.Pointer(cctxDict.cctx),
+unsafe.Pointer(dstHdr.Data),
C.size_t(cap(dst)),
-C.uintptr_t(uintptr(unsafe.Pointer(&src[0]))),
+unsafe.Pointer(srcHdr.Data),
C.size_t(len(src)),
-C.uintptr_t(uintptr(unsafe.Pointer(cd.p))))
+unsafe.Pointer(cd.p))
// Prevent from GC'ing of dst and src during CGO call above.
runtime.KeepAlive(dst)
runtime.KeepAlive(src)
if mustSucceed {
ensureNoError("ZSTD_compress_usingCDict_wrapper", result)
ensureNoError("ZSTD_compress_usingCDict", result)
}
return result
}
result := C.ZSTD_compressCCtx_wrapper(
-C.uintptr_t(uintptr(unsafe.Pointer(cctx.cctx))),
-C.uintptr_t(uintptr(unsafe.Pointer(&dst[0]))),
+unsafe.Pointer(cctx.cctx),
+unsafe.Pointer(dstHdr.Data),
C.size_t(cap(dst)),
-C.uintptr_t(uintptr(unsafe.Pointer(&src[0]))),
+unsafe.Pointer(srcHdr.Data),
C.size_t(len(src)),
C.int(compressionLevel))
// Prevent from GC'ing of dst and src during CGO call above.
runtime.KeepAlive(dst)
runtime.KeepAlive(src)
if mustSucceed {
ensureNoError("ZSTD_compressCCtx_wrapper", result)
ensureNoError("ZSTD_compressCCtx", result)
}

return result
}
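
The reflect.SliceHeader plus noescape pattern above can also be read in isolation. A minimal sketch, assuming a hypothetical helper name that is not part of this change:

```go
// sliceDataPtr reads the slice header through noescape, so taking &b does
// not force the slice to escape to the heap. Callers still need
// runtime.KeepAlive(b) across the CGO call, exactly as compressInternal
// does above.
func sliceDataPtr(b []byte) unsafe.Pointer {
	hdr := (*reflect.SliceHeader)(noescape(unsafe.Pointer(&b)))
	return unsafe.Pointer(hdr.Data)
}
```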

@@ -254,17 +272,15 @@ func decompress(dctx, dctxDict *dctxWrapper, dst, src []byte, dd *DDict) ([]byte
}

// Slow path - resize dst to fit decompressed data.
-decompressBound := int(C.ZSTD_getFrameContentSize_wrapper(
-C.uintptr_t(uintptr(unsafe.Pointer(&src[0]))), C.size_t(len(src))))
-// Prevent from GC'ing of src during CGO call above.
-runtime.KeepAlive(src)
-switch uint64(decompressBound) {
-case uint64(C.ZSTD_CONTENTSIZE_UNKNOWN):
+srcHdr := (*reflect.SliceHeader)(noescape(unsafe.Pointer(&src)))
+contentSize := C.ZSTD_getFrameContentSize_wrapper(unsafe.Pointer(srcHdr.Data), C.size_t(len(src)))
+switch {
+case contentSize == C.ZSTD_CONTENTSIZE_UNKNOWN || contentSize > maxFrameContentSize:
return streamDecompress(dst, src, dd)
-case uint64(C.ZSTD_CONTENTSIZE_ERROR):
+case contentSize == C.ZSTD_CONTENTSIZE_ERROR:
return dst, fmt.Errorf("cannot decompress invalid src")
}
-decompressBound++
+decompressBound := int(contentSize) + 1

if n := dstLen + decompressBound - cap(dst); n > 0 {
// This should be optimized since go 1.11 - see https://golang.org/doc/go1.11#performance-compiler.
@@ -287,24 +303,28 @@ func decompress(dctx, dctxDict *dctxWrapper, dst, src []byte, dd *DDict) ([]byte
}

func decompressInternal(dctx, dctxDict *dctxWrapper, dst, src []byte, dd *DDict) C.size_t {
-var n C.size_t
+var (
+dstHdr = (*reflect.SliceHeader)(noescape(unsafe.Pointer(&dst)))
+srcHdr = (*reflect.SliceHeader)(noescape(unsafe.Pointer(&src)))
+n C.size_t
+)
if dd != nil {
n = C.ZSTD_decompress_usingDDict_wrapper(
-C.uintptr_t(uintptr(unsafe.Pointer(dctxDict.dctx))),
-C.uintptr_t(uintptr(unsafe.Pointer(&dst[0]))),
+unsafe.Pointer(dctxDict.dctx),
+unsafe.Pointer(dstHdr.Data),
C.size_t(cap(dst)),
-C.uintptr_t(uintptr(unsafe.Pointer(&src[0]))),
+unsafe.Pointer(srcHdr.Data),
C.size_t(len(src)),
-C.uintptr_t(uintptr(unsafe.Pointer(dd.p))))
+unsafe.Pointer(dd.p))
} else {
n = C.ZSTD_decompressDCtx_wrapper(
-C.uintptr_t(uintptr(unsafe.Pointer(dctx.dctx))),
-C.uintptr_t(uintptr(unsafe.Pointer(&dst[0]))),
+unsafe.Pointer(dctx.dctx),
+unsafe.Pointer(dstHdr.Data),
C.size_t(cap(dst)),
-C.uintptr_t(uintptr(unsafe.Pointer(&src[0]))),
+unsafe.Pointer(srcHdr.Data),
C.size_t(len(src)))
}
-// Prevent from GC'ing of dst and src during CGO calls above.
+// Prevent from GC'ing of dst and src during CGO call above.
runtime.KeepAlive(dst)
runtime.KeepAlive(src)
return n
Expand All @@ -317,13 +337,17 @@ func errStr(result C.size_t) string {
}

func ensureNoError(funcName string, result C.size_t) {
+if zstdIsError(result) {
+panic(fmt.Errorf("BUG: unexpected error in %s: %s", funcName, errStr(result)))
+}
+}
+
+func zstdIsError(result C.size_t) bool {
if int(result) >= 0 {
// Fast path - avoid calling C function.
-return
+return false
}
-if C.ZSTD_getErrorCode(result) != 0 {
-panic(fmt.Errorf("BUG: unexpected error in %s: %s", funcName, errStr(result)))
-}
+return C.ZSTD_isError(result) != 0
}

func streamDecompress(dst, src []byte, dd *DDict) ([]byte, error) {
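The exported Compress and Decompress signatures are untouched by this change. A minimal round-trip sketch, written as a hypothetical test in the same package (not part of the PR):

```go
func TestCompressDecompressRoundTrip(t *testing.T) {
	data := []byte("hello zstd")
	compressed := Compress(nil, data)
	plain, err := Decompress(nil, compressed)
	if err != nil {
		t.Fatalf("cannot decompress: %s", err)
	}
	if !bytes.Equal(plain, data) {
		t.Fatalf("unexpected round-trip result; got %q; want %q", plain, data)
	}
}
```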
59 changes: 59 additions & 0 deletions gozstd_test.go
@@ -4,6 +4,7 @@ import (
"bytes"
"encoding/hex"
"fmt"
"io"
"math/rand"
"runtime"
"strings"
@@ -54,6 +55,22 @@ func TestDecompressSmallBlockWithoutSingleSegmentFlag(t *testing.T) {
})
}

func TestCompressEmpty(t *testing.T) {
var dst [64]byte
res := Compress(dst[:0], nil)
if len(res) > 0 {
t.Fatalf("unexpected non-empty compressed frame: %X", res)
}
}

func TestDecompressTooLarge(t *testing.T) {
src := []byte{40, 181, 47, 253, 228, 122, 118, 105, 67, 140, 234, 85, 20, 159, 67}
_, err := Decompress(nil, src)
Review comment: From the test name I gather that the error here is that the decompressed size is too large for the dst buf (nil)? It would be a bit easier to read if the dst buf was non-nil, like maybe 1 byte or something.

if err == nil {
t.Fatalf("expecting error when decompressing malformed frame")
}
}
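
A sketch of the variant suggested in the review comment above: the same malformed frame, but decompressed into a small non-nil dst. This is a hypothetical test, not part of the PR:

```go
func TestDecompressTooLargeNonNilDst(t *testing.T) {
	// Same malformed frame as in TestDecompressTooLarge above.
	src := []byte{40, 181, 47, 253, 228, 122, 118, 105, 67, 140, 234, 85, 20, 159, 67}
	// Start from a non-nil dst with one byte of capacity, as suggested.
	dst := make([]byte, 0, 1)
	if _, err := Decompress(dst, src); err == nil {
		t.Fatalf("expecting error when decompressing malformed frame into small dst")
	}
}
```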

func mustUnhex(dataHex string) []byte {
data, err := hex.DecodeString(dataHex)
if err != nil {
Expand All @@ -62,6 +79,48 @@ func mustUnhex(dataHex string) []byte {
return data
}

func TestCompressWithStackMove(t *testing.T) {
var srcBuf [96]byte

n, err := io.ReadFull(rand.New(rand.NewSource(time.Now().Unix())), srcBuf[:])
if err != nil {
t.Fatalf("cannot fill srcBuf with random data: %s", err)
}

// We run this twice: the first run allocates objects in sync.Pool, and
// those calls grow the stack; the second run can then skip the allocations
// and grow the stack right before the CGO call.
// Note that this test might require some go:nosplit annotations
// to force the stack move to happen exactly before the CGO call.
for i := 0; i < 2; i++ {
ch := make(chan struct{})
go func() {
defer close(ch)

var dstBuf [1416]byte

res := Compress(dstBuf[:0], srcBuf[:n])

// make a copy of the result, so the original can remain on the stack
compressedCpy := make([]byte, len(res))
copy(compressedCpy, res)

orig, err := Decompress(nil, compressedCpy)
if err != nil {
panic(fmt.Errorf("cannot decompress: %s", err))
}
if !bytes.Equal(orig, srcBuf[:n]) {
panic(fmt.Errorf("unexpected decompressed data; got %q; want %q", orig, srcBuf[:n]))
}
}()
// wait for the goroutine to finish
<-ch
}

runtime.GC()
}

func TestCompressDecompressDistinctConcurrentDicts(t *testing.T) {
// Build multiple distinct dicts.
var cdicts []*CDict