Skip to content

Commit

Permalink
Integrate review comments
Browse files Browse the repository at this point in the history
* String typed enum

Signed-off-by: Manuel Rüger <[email protected]>
  • Loading branch information
mrueg committed Jun 3, 2024
1 parent adcec1e commit 512d614
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 63 deletions.
109 changes: 56 additions & 53 deletions prometheus/promhttp/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,24 +55,14 @@ const (
processStartTimeHeader = "Process-Start-Time-Unix"
)

type Compression int
type Compression string

const (
Identity Compression = iota
Gzip
Zstd
Identity Compression = "identity"
Gzip Compression = "gzip"
Zstd Compression = "zstd"
)

var compressions = [...]string{
"identity",
"gzip",
"zstd",
}

func (c Compression) String() string {
return compressions[c]
}

var defaultCompressionFormats = []Compression{Identity, Gzip, Zstd}

var gzipPool = sync.Pool{
Expand Down Expand Up @@ -143,6 +133,18 @@ func HandlerForTransactional(reg prometheus.TransactionalGatherer, opts HandlerO
}
}

// Select all supported compression formats
var compressions []string
if !opts.DisableCompression {
offers := defaultCompressionFormats
if len(opts.OfferedCompressions) > 0 {
offers = opts.OfferedCompressions
}
for _, comp := range offers {
compressions = append(compressions, string(comp))
}
}

h := http.HandlerFunc(func(rsp http.ResponseWriter, req *http.Request) {
if !opts.ProcessStartTime.IsZero() {
rsp.Header().Set(processStartTimeHeader, strconv.FormatInt(opts.ProcessStartTime.Unix(), 10))
Expand Down Expand Up @@ -188,13 +190,17 @@ func HandlerForTransactional(reg prometheus.TransactionalGatherer, opts HandlerO
}
rsp.Header().Set(contentTypeHeader, string(contentType))

w, err := GetWriter(req, rsp, opts.DisableCompression, opts.OfferedCompressions)
w, encodingHeader, err := NegotiateEncodingWriter(req, rsp, opts.DisableCompression, compressions)
if err != nil {
if opts.ErrorLog != nil {
opts.ErrorLog.Println("error getting writer", err)
}
// Since the writer received from NegotiateEncodingWriter will be nil, in case there's an error, we set it here
w = io.Writer(rsp)
}

rsp.Header().Set(contentEncodingHeader, encodingHeader)

enc := expfmt.NewEncoder(w, contentType)

// handleError handles the error according to opts.ErrorHandling
Expand Down Expand Up @@ -419,48 +425,45 @@ func httpError(rsp http.ResponseWriter, err error) {
)
}

func GetWriter(r *http.Request, rsp http.ResponseWriter, disableCompression bool, offeredCompressions []Compression) (io.Writer, error) {
w := io.Writer(rsp)
rsp.Header().Set(contentEncodingHeader, "identity")
if !disableCompression {
offers := defaultCompressionFormats
if len(offeredCompressions) > 0 {
offers = offeredCompressions
}
var compressions []string
for _, comp := range offers {
compressions = append(compressions, comp.String())
// NegotiateEncodingWriter reads the Accept-Encoding header from a request and
// selects the right compression based on an allow-list of supported
// compressions. It returns a writer implementing the compression and an the
// correct value that the caller can set in the response header.
func NegotiateEncodingWriter(r *http.Request, rw io.Writer, disableCompression bool, compressions []string) (_ io.Writer, encodingHeaderValue string, _ error) {
w := rw

if disableCompression {
return w, string(Identity), nil
}

// TODO(mrueg): Replace internal/github.com/gddo once https://github.com/golang/go/issues/19307 is implemented.
compression := httputil.NegotiateContentEncoding(r, compressions)

switch compression {
case "zstd":
// TODO(mrueg): Replace klauspost/compress with stdlib implementation once https://github.com/golang/go/issues/62513 is implemented.
z, err := zstd.NewWriter(rw, zstd.WithEncoderLevel(zstd.SpeedFastest))
if err != nil {
return nil, "", err
}
// TODO(mrueg): Replace internal/github.com/gddo once https://github.com/golang/go/issues/19307 is implemented.
compression := httputil.NegotiateContentEncoding(r, compressions)
switch compression {
case "zstd":
rsp.Header().Set(contentEncodingHeader, "zstd")
// TODO(mrueg): Replace klauspost/compress with stdlib implementation once https://github.com/golang/go/issues/62513 is implemented.
z, err := zstd.NewWriter(rsp, zstd.WithEncoderLevel(zstd.SpeedFastest))
if err != nil {
return nil, err
}

z.Reset(w)
defer z.Close()
z.Reset(w)
defer z.Close()

w = z
case "gzip":
rsp.Header().Set(contentEncodingHeader, "gzip")
gz := gzipPool.Get().(*gzip.Writer)
defer gzipPool.Put(gz)
w = z
case "gzip":
gz := gzipPool.Get().(*gzip.Writer)
defer gzipPool.Put(gz)

gz.Reset(w)
defer gz.Close()
gz.Reset(w)
defer gz.Close()

w = gz
case "identity":
// This means the content is not compressed.
default:
// The content encoding was not implemented yet.
return w, fmt.Errorf("content compression format not recognized: %s. Valid formats are: %s", compression, defaultCompressionFormats)
}
w = gz
case "identity":
// This means the content is not compressed.
default:
// The content encoding was not implemented yet.
return nil, "", fmt.Errorf("content compression format not recognized: %s. Valid formats are: %s", compression, defaultCompressionFormats)
}
return w, nil
return w, compression, nil
}
26 changes: 16 additions & 10 deletions prometheus/promhttp/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,51 +332,57 @@ func TestHandlerTimeout(t *testing.T) {
close(c.Block) // To not leak a goroutine.
}

func TestGetWriter(t *testing.T) {
func TestNegotiateEncodingWriter(t *testing.T) {
var defaultCompressions []string

for _, comp := range defaultCompressionFormats {
defaultCompressions = append(defaultCompressions, string(comp))
}

testCases := []struct {
name string
disableCompression bool
offeredCompressions []Compression
offeredCompressions []string
acceptEncoding string
expectedCompression string
err error
}{
{
name: "test without compression enabled",
disableCompression: true,
offeredCompressions: defaultCompressionFormats,
offeredCompressions: defaultCompressions,
acceptEncoding: "",
expectedCompression: "identity",
err: nil,
},
{
name: "test with compression enabled with empty accept-encoding header",
disableCompression: false,
offeredCompressions: defaultCompressionFormats,
offeredCompressions: defaultCompressions,
acceptEncoding: "",
expectedCompression: "identity",
err: nil,
},
{
name: "test with gzip compression requested",
disableCompression: false,
offeredCompressions: defaultCompressionFormats,
offeredCompressions: defaultCompressions,
acceptEncoding: "gzip",
expectedCompression: "gzip",
err: nil,
},
{
name: "test with gzip, zstd compression requested",
disableCompression: false,
offeredCompressions: defaultCompressionFormats,
offeredCompressions: defaultCompressions,
acceptEncoding: "gzip,zstd",
expectedCompression: "gzip",
err: nil,
},
{
name: "test with zstd, gzip compression requested",
disableCompression: false,
offeredCompressions: defaultCompressionFormats,
offeredCompressions: defaultCompressions,
acceptEncoding: "zstd,gzip",
expectedCompression: "gzip",
err: nil,
Expand All @@ -387,14 +393,14 @@ func TestGetWriter(t *testing.T) {
request, _ := http.NewRequest("GET", "/", nil)
request.Header.Add(acceptEncodingHeader, test.acceptEncoding)
rr := httptest.NewRecorder()
_, err := GetWriter(request, rr, test.disableCompression, test.offeredCompressions)
_, encodingHeader, err := NegotiateEncodingWriter(request, rr, test.disableCompression, test.offeredCompressions)

if !errors.Is(err, test.err) {
t.Errorf("got error: %v, expected: %v", err, test.err)
}

if rr.Header().Get(contentEncodingHeader) != test.expectedCompression {
t.Errorf("got different compression type: %v, expected: %v", rr.Header().Get(contentEncodingHeader), test.expectedCompression)
if encodingHeader != test.expectedCompression {
t.Errorf("got different compression type: %v, expected: %v", encodingHeader, test.expectedCompression)
}
}
}
Expand Down

0 comments on commit 512d614

Please sign in to comment.