diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index a8f6bd0..7e268fd 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -14,9 +14,11 @@ // The optional 'workspaceFolder' property is the path VS Code should open by default when // connected. This is typically a file mount in .devcontainer/docker-compose.yml "workspaceFolder": "/app", - // Set *default* container specific settings.json values on container create. + // All containers should stop if we close / reload the VSCode window. + "shutdownAction": "stopCompose", "customizations": { "vscode": { + // Set *default* container specific settings.json values on container create. "settings": { // https://github.com/golang/tools/blob/master/gopls/doc/vscode.md#vscode "go.useLanguageServer": true, @@ -45,6 +47,8 @@ // DISABLED, done via "staticcheck": false, }, + // https://code.visualstudio.com/docs/languages/go#_intellisense + "go.autocompleteUnimportedPackages": true, // https://github.com/golangci/golangci-lint#editor-integration "go.lintTool": "golangci-lint", "go.lintFlags": [ @@ -69,30 +73,6 @@ }, // ensure that the pgFormatter VSCode extension uses the pgFormatter that comes preinstalled in the Dockerfile "pgFormatter.pgFormatterPath": "/usr/local/bin/pg_format" - // "go.lintOnSave": "workspace" - // general build settings in sync with our makefile - // "go.buildFlags": [ - // "-o", - // "bin/app" - // ] - // "sqltools.connections": [ - // { - // "database": "sample", - // "dialect": "PostgreSQL", - // "name": "postgres", - // "password": "9bed16f749d74a3c8bfbced18a7647f5", - // "port": 5432, - // "server": "postgres", - // "username": "dbuser" - // } - // ], - // "sqltools.autoConnectTo": [ - // "postgres" - // ], - // // only use pg_format to actually format! - // "sqltools.formatLanguages": [], - // "sqltools.telemetry": false, - // "sqltools.autoOpenSessionFiles": false }, // Add the IDs of extensions you want installed when the container is created. "extensions": [ @@ -100,12 +80,12 @@ "golang.go", "bradymholt.pgformatter", // optional: - // "766b.go-outliner", + "42crunch.vscode-openapi", "heaths.vscode-guid", "bungcip.better-toml", "eamodio.gitlens", - "casualjim.gotemplate" - // "mtxr.sqltools", + "casualjim.gotemplate", + "yzhang.markdown-all-in-one" ] } }, @@ -115,6 +95,7 @@ // "shutdownAction": "none", // Uncomment the next line to run commands after the container is created - for example installing git. "postCreateCommand": "go version", + // "postCreateCommand": "apt-get update && apt-get install -y git", // Uncomment to connect as a non-root user. See https://aka.ms/vscode-remote/containers/non-root. - "remoteUser": "development" + // "remoteUser": "" } \ No newline at end of file diff --git a/.drone.yml b/.drone.yml index 20a5c9a..ec24175 100644 --- a/.drone.yml +++ b/.drone.yml @@ -30,7 +30,7 @@ alias: - &IMAGE_DEPLOY_ID ${DRONE_REPO,,}:${DRONE_COMMIT_SHA} # Defines which branches will trigger a docker image push our Google Cloud Registry (tags are always published) - - &GCR_PUBLISH_BRANCHES [dev, master, aj/pooling-improvements] + - &GCR_PUBLISH_BRANCHES [dev, master, aj/pooling-improvements, mr/aj-review] # Docker registry publish default settings - &GCR_REGISTRY_SETTINGS @@ -133,7 +133,7 @@ pipeline: environment: IMAGE_TAG: *IMAGE_BUILDER_ID commands: - - "docker build --target builder-integresql --compress -t $${IMAGE_TAG} ." + - "docker build --target builder --compress -t $${IMAGE_TAG} ." <<: *WHEN_BUILD_EVENT "docker build (target integresql)": diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..af28163 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,29 @@ +linters: + enable: + # https://github.com/golangci/golangci-lint#enabled-by-default-linters + # Additional linters you want to activate may be specified here... + + # --- + # https://github.com/mgechev/revive + # replacement for the now deprecated official golint linter, see https://github.com/golang/go/issues/38968 + - revive + + # --- + # https://github.com/maratori/testpackage + # used to enforce blackbox testing + - testpackage + + # --- + # https://github.com/securego/gosec + # inspects source code for security problems by scanning the Go AST. + - gosec + + # --- + # https://github.com/sivchari/tenv + # prefer t.Setenv instead of os.Setenv within test code. + - tenv + + # --- + # https://github.com/polyfloyd/go-errorlint + # ensure we are comparing errors via errors.Is, types/values via errors.As and wrap errors with %w. + - errorlint diff --git a/Dockerfile b/Dockerfile index 6d23916..b7d3416 100644 --- a/Dockerfile +++ b/Dockerfile @@ -18,23 +18,55 @@ ENV MAKEFLAGS "-j 8 --no-print-directory" # e.g. stretch=>stretch-pgdg, buster=>buster-pgdg, bullseye=>bullseye-pgdg RUN echo "deb http://apt.postgresql.org/pub/repos/apt/ bullseye-pgdg main" \ | tee /etc/apt/sources.list.d/pgdg.list \ - && apt install curl ca-certificates gnupg \ - && curl https://www.postgresql.org/media/keys/ACCC4CF8.asc | gpg --dearmor | tee /etc/apt/trusted.gpg.d/apt.postgresql.org.gpg >/dev/null - + && wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc \ + | apt-key add - # Install required system dependencies RUN apt-get update \ && apt-get install -y \ + # + # Mandadory minimal linux packages + # Installed at development stage and app stage + # Do not forget to add mandadory linux packages to the final app Dockerfile stage below! + # + # -- START MANDADORY -- + ca-certificates \ + # --- END MANDADORY --- + # + # Development specific packages + # Only installed at development stage and NOT available in the final Docker stage + # based upon + # https://github.com/microsoft/vscode-remote-try-go/blob/master/.devcontainer/Dockerfile + # https://raw.githubusercontent.com/microsoft/vscode-dev-containers/master/script-library/common-debian.sh + # + # icu-devtools: https://stackoverflow.com/questions/58736399/how-to-get-vscode-liveshare-extension-working-when-running-inside-vscode-remote + # graphviz: https://github.com/google/pprof#building-pprof + # -- START DEVELOPMENT -- + apt-utils \ + dialog \ + openssh-client \ + less \ + iproute2 \ + procps \ + lsb-release \ locales \ sudo \ bash-completion \ bsdmainutils \ + graphviz \ + xz-utils \ postgresql-client-12 \ + icu-devtools \ + tmux \ + rsync \ + # --- END DEVELOPMENT --- + # && apt-get clean \ && rm -rf /var/lib/apt/lists/* -# vscode support: LANG must be supported, requires installing the locale package first -# see https://github.com/Microsoft/vscode/issues/58015 +# env/vscode support: LANG must be supported, requires installing the locale package first +# https://github.com/Microsoft/vscode/issues/58015 +# https://stackoverflow.com/questions/28405902/how-to-set-the-locale-inside-a-debian-ubuntu-docker-container RUN sed -i -e 's/# en_US.UTF-8 UTF-8/en_US.UTF-8 UTF-8/' /etc/locale.gen && \ dpkg-reconfigure --frontend=noninteractive locales && \ update-locale LANG=en_US.UTF-8 @@ -82,6 +114,25 @@ RUN ARCH="$(arch | sed s/aarch64/arm64/ | sed s/x86_64/amd64/)" \ # https://github.com/uw-labs/lichen/tags RUN go install github.com/uw-labs/lichen@v0.1.7 +# watchexec +# https://github.com/watchexec/watchexec/releases +RUN mkdir -p /tmp/watchexec \ + && cd /tmp/watchexec \ + && wget https://github.com/watchexec/watchexec/releases/download/v1.20.6/watchexec-1.20.6-$(arch)-unknown-linux-musl.tar.xz \ + && tar xf watchexec-1.20.6-$(arch)-unknown-linux-musl.tar.xz \ + && cp watchexec-1.20.6-$(arch)-unknown-linux-musl/watchexec /usr/local/bin/watchexec \ + && rm -rf /tmp/watchexec + +# yq +# https://github.com/mikefarah/yq/releases +RUN mkdir -p /tmp/yq \ + && cd /tmp/yq \ + && ARCH="$(arch | sed s/aarch64/arm64/ | sed s/x86_64/amd64/)" \ + && wget "https://github.com/mikefarah/yq/releases/download/v4.30.5/yq_linux_${ARCH}.tar.gz" \ + && tar xzf "yq_linux_${ARCH}.tar.gz" \ + && cp "yq_linux_${ARCH}" /usr/local/bin/yq \ + && rm -rf /tmp/yq + # linux permissions / vscode support: Add user to avoid linux file permission issues # Detail: Inside the container, any mounted files/folders will have the exact same permissions # as outside the container - including the owner user ID (UID) and group ID (GID). @@ -100,7 +151,6 @@ RUN groupadd --gid $USER_GID $USERNAME \ && echo $USERNAME ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/$USERNAME \ && chmod 0440 /etc/sudoers.d/$USERNAME - # vscode support: cached extensions install directory # https://code.visualstudio.com/docs/remote/containers-advanced#_avoiding-extension-reinstalls-on-container-rebuild RUN mkdir -p /home/$USERNAME/.vscode-server/extensions \ @@ -113,7 +163,6 @@ RUN mkdir -p /home/$USERNAME/.vscode-server/extensions \ # Note that this should be the final step after installing all build deps RUN mkdir -p /$GOPATH/pkg && chown -R $USERNAME /$GOPATH - # $GOBIN is where our own compiled binaries will live and other go.mod / VSCode binaries will be installed. # It should always come AFTER our other $PATH segments and should be earliest targeted in stage "builder", # as /app/bin will the shadowed by a volume mount via docker-compose! @@ -133,15 +182,11 @@ COPY Makefile /app/Makefile COPY go.mod /app/go.mod COPY go.sum /app/go.sum COPY tools.go /app/tools.go -RUN make modules && make tools +RUN make modules +COPY tools.go /app/tools.go +RUN make tools COPY . /app/ - -### ----------------------- -# --- Stage: builder-integresql -### ----------------------- - -FROM builder as builder-integresql -RUN make build +RUN make go-build ### ----------------------- # --- Stage: integresql @@ -152,7 +197,7 @@ RUN make build # The :debug image provides a busybox shell to enter. # https://github.com/GoogleContainerTools/distroless#debug-images FROM gcr.io/distroless/base-debian11:debug as integresql -COPY --from=builder-integresql /app/bin/integresql / +COPY --from=builder /app/bin/integresql / # Note that cmd is not supported with these kind of images, no shell included # see https://github.com/GoogleContainerTools/distroless/issues/62 # and https://github.com/GoogleContainerTools/distroless#entrypoints diff --git a/Makefile b/Makefile index 284edf2..4e4701c 100644 --- a/Makefile +++ b/Makefile @@ -4,8 +4,8 @@ # first is default target when running "make" without args build: ##- Default 'make' target: go-format, go-build and lint. - @$(MAKE) format - @$(MAKE) gobuild + @$(MAKE) go-format + @$(MAKE) go-build @$(MAKE) lint # useful to ensure that everything gets resetuped from scratch @@ -22,14 +22,16 @@ info-go: ##- (opt) Prints go.mod updates, module-name and current go version. @go version >> tmp/.info-go @cat tmp/.info-go -format: - go fmt +lint: go-lint ##- Runs golangci-lint and make check-*. -gobuild: - go build -o bin/integresql ./cmd/server +go-format: ##- (opt) Runs go format. + go fmt ./... -lint: - golangci-lint run --fast +go-build: ##- (opt) Runs go build. + go build -ldflags $(LDFLAGS) -o bin/integresql ./cmd/server + +go-lint: ##- (opt) Runs golangci-lint. + golangci-lint run --timeout 5m bench: ##- Run tests, output by package, print coverage. @go test -benchmem=false -run=./... -bench . github.com/allaboutapps/integresql/tests -race -count=4 -v @@ -45,7 +47,10 @@ test: ##- Run tests, output by package, print coverage. # note that we explicitly don't want to use a -coverpkg=./... option, per pkg coverage take precedence go-test-by-pkg: ##- (opt) Run tests, output by package. - gotestsum --format pkgname-and-test-fails --jsonfile /tmp/test.log -- -race -cover -count=1 -coverprofile=/tmp/coverage.out ./... + gotestsum --format pkgname-and-test-fails --format-hide-empty-pkg --jsonfile /tmp/test.log -- -race -cover -count=1 -coverprofile=/tmp/coverage.out ./... + +go-test-by-name: ##- (opt) Run tests, output by testname. + gotestsum --format testname --jsonfile /tmp/test.log -- -race -cover -count=1 -coverprofile=/tmp/coverage.out ./... go-test-print-coverage: ##- (opt) Print overall test coverage (must be done after running tests). @printf "coverage " @@ -176,10 +181,6 @@ LDFLAGS = $(eval LDFLAGS := "\ # required to ensure make fails if one recipe fails (even on parallel jobs) and on pipefails .ONESHELL: -# # normal POSIX bash shell mode -# SHELL = /bin/bash -# .SHELLFLAGS = -cEeuo pipefail - -# wrapped make time tracing shell, use it via MAKE_TRACE_TIME=true make -# SHELL = /bin/rksh -# .SHELLFLAGS = $@ \ No newline at end of file +# normal POSIX bash shell mode +SHELL = /bin/bash +.SHELLFLAGS = -cEeuo pipefail diff --git a/cmd/server/main.go b/cmd/server/main.go index f23b3aa..9ab58f2 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -2,7 +2,7 @@ package main import ( "context" - "log" + "errors" "net/http" "os" "os/signal" @@ -10,21 +10,41 @@ import ( "time" "github.com/allaboutapps/integresql/internal/api" + "github.com/allaboutapps/integresql/internal/config" "github.com/allaboutapps/integresql/internal/router" + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" ) func main() { - s := api.DefaultServerFromEnv() + + cfg := api.DefaultServerConfigFromEnv() + + zerolog.TimeFieldFormat = time.RFC3339Nano + zerolog.SetGlobalLevel(cfg.Logger.Level) + if cfg.Logger.PrettyPrintConsole { + log.Logger = log.Output(zerolog.NewConsoleWriter(func(w *zerolog.ConsoleWriter) { + w.TimeFormat = "15:04:05" + })) + } + + log.Info().Str("version", config.GetFormattedBuildArgs()).Msg("starting...") + + s := api.NewServer(cfg) if err := s.InitManager(context.Background()); err != nil { - log.Fatalf("Failed to initialize manager: %v", err) + log.Fatal().Err(err).Msg("Failed to initialize manager") } router.Init(s) go func() { if err := s.Start(); err != nil { - log.Fatalf("Failed to start server: %v", err) + if errors.Is(err, http.ErrServerClosed) { + log.Info().Msg("Server closed") + } else { + log.Fatal().Err(err).Msg("Failed to start server") + } } }() @@ -35,7 +55,7 @@ func main() { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - if err := s.Shutdown(ctx); err != nil && err != http.ErrServerClosed { - log.Fatalf("Failed to gracefully shut down server: %v", err) + if err := s.Shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) { + log.Fatal().Err(err).Msg("Failed to gracefully shut down server") } } diff --git a/docker-compose.yml b/docker-compose.yml index 75c3afd..70acc3e 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -5,12 +5,26 @@ services: build: context: . target: development - ports: - - "5000:5000" + # ports: + # - "5000:5000" working_dir: /app + # linux permissions / vscode support: we must explicitly run as the development user + user: development volumes: - - .:/app #:delegated - # - ./.pkg:/go/pkg # enable this to reuse the pkg cache + # mount working directory + # https://code.visualstudio.com/docs/remote/containers-advanced#_update-the-mount-consistency-to-delegated-for-macos + # https://docs.docker.com/docker-for-mac/osxfs-caching/#delegated + # the container’s view is authoritative (permit delays before updates on the container appear in the host) + - .:/app:delegated + + # mount cached go pkg downloads + - go-pkg:/go/pkg + + # mount cached vscode container extensions + # https://code.visualstudio.com/docs/remote/containers-advanced#_avoiding-extension-reinstalls-on-container-rebuild + - vscode-extensions:/home/development/.vscode-server/extensions + - vscode-extensions-insiders:/home/development/.vscode-server-insiders/extensions + depends_on: - postgres environment: &SERVICE_ENV @@ -28,7 +42,13 @@ services: - seccomp:unconfined # Overrides default command so things don't shut down after the process ends. - command: /bin/sh -c "while sleep 1000; do :; done" + # Overrides default command so things don't shut down after the process ends. + command: + - /bin/sh + - -c + - | + git config --global --add safe.directory /app + while sleep 1000; do :; done postgres: image: postgres:12.4-alpine # should be the same version as used in .drone.yml, Dockerfile and live @@ -46,3 +66,11 @@ services: volumes: pgvolume: # declare a named volume to persist DB data + + # go: go mod cached downloads + go-pkg: + + # vscode: Avoiding extension reinstalls on container rebuild + # https://code.visualstudio.com/docs/remote/containers-advanced#_avoiding-extension-reinstalls-on-container-rebuild + vscode-extensions: + vscode-extensions-insiders: diff --git a/go.mod b/go.mod index 4f748c6..b4e6479 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/google/uuid v1.3.0 github.com/labstack/echo/v4 v4.10.2 github.com/lib/pq v1.10.9 + github.com/rs/zerolog v1.28.0 github.com/stretchr/testify v1.8.4 golang.org/x/sync v0.3.0 ) diff --git a/go.sum b/go.sum index 5a270e2..3701fcb 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,8 @@ +github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= @@ -17,14 +19,19 @@ github.com/labstack/gommon v0.4.0/go.mod h1:uW6kP17uPlLJsD3ijUYn3/M5bAxtlZhMI6m3 github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-colorable v0.1.11/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= +github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.17 h1:BTarxUcIeDqL27Mc+vyvdWYSL28zpIhv3RoTdsLMPng= github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= +github.com/rs/zerolog v1.28.0 h1:MirSo27VyNi7RJYP3078AA1+Cyzd2GB66qy3aUHvsWY= +github.com/rs/zerolog v1.28.0/go.mod h1:NILgTygv/Uej1ra5XxGf82ZFSLk58MFGAUS2o6usyD0= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= diff --git a/internal/api/middleware/logger.go b/internal/api/middleware/logger.go new file mode 100644 index 0000000..7135b5b --- /dev/null +++ b/internal/api/middleware/logger.go @@ -0,0 +1,307 @@ +package middleware + +import ( + "bufio" + "bytes" + "context" + "io" + "net" + "net/http" + "net/url" + "strings" + "time" + + "github.com/allaboutapps/integresql/pkg/util" + "github.com/labstack/echo/v4" + "github.com/labstack/echo/v4/middleware" + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" +) + +// RequestBodyLogSkipper defines a function to skip logging certain request bodies. +// Returning true skips logging the payload of the request. +type RequestBodyLogSkipper func(req *http.Request) bool + +// DefaultRequestBodyLogSkipper returns true for all requests with Content-Type +// application/x-www-form-urlencoded or multipart/form-data as those might contain +// binary or URL-encoded file uploads unfit for logging purposes. +func DefaultRequestBodyLogSkipper(req *http.Request) bool { + contentType := req.Header.Get(echo.HeaderContentType) + switch { + case strings.HasPrefix(contentType, echo.MIMEApplicationForm), + strings.HasPrefix(contentType, echo.MIMEMultipartForm): + return true + default: + return false + } +} + +// ResponseBodyLogSkipper defines a function to skip logging certain response bodies. +// Returning true skips logging the payload of the response. +type ResponseBodyLogSkipper func(req *http.Request, res *echo.Response) bool + +// DefaultResponseBodyLogSkipper returns false for all responses with Content-Type +// application/json, preventing logging for all other types of payloads as those +// might contain binary or URL-encoded data unfit for logging purposes. +func DefaultResponseBodyLogSkipper(_ *http.Request, res *echo.Response) bool { + contentType := res.Header().Get(echo.HeaderContentType) + switch { + case strings.HasPrefix(contentType, echo.MIMEApplicationJSON): + return false + default: + return true + } +} + +// BodyLogReplacer defines a function to replace certain parts of a body before logging it, +// mainly used to strip sensitive information from a request or response payload. +// The []byte returned should contain a sanitized payload ready for logging. +type BodyLogReplacer func(body []byte) []byte + +// DefaultBodyLogReplacer returns the body received without any modifications. +func DefaultBodyLogReplacer(body []byte) []byte { + return body +} + +// HeaderLogReplacer defines a function to replace certain parts of a header before logging it, +// mainly used to strip sensitive information from a request or response header. +// The http.Header returned should be a sanitized copy of the original header as not to modify +// the request or response while logging. +type HeaderLogReplacer func(header http.Header) http.Header + +// DefaultHeaderLogReplacer replaces all Authorization, X-CSRF-Token and Proxy-Authorization +// header entries with a redacted string, indicating their presence without revealing actual, +// potentially sensitive values in the logs. +func DefaultHeaderLogReplacer(header http.Header) http.Header { + sanitizedHeader := http.Header{} + + for k, vv := range header { + shouldRedact := strings.EqualFold(k, echo.HeaderAuthorization) || + strings.EqualFold(k, echo.HeaderXCSRFToken) || + strings.EqualFold(k, "Proxy-Authorization") + + for _, v := range vv { + if shouldRedact { + sanitizedHeader.Add(k, "*****REDACTED*****") + } else { + sanitizedHeader.Add(k, v) + } + } + } + + return sanitizedHeader +} + +// QueryLogReplacer defines a function to replace certain parts of a URL query before logging it, +// mainly used to strip sensitive information from a request query. +// The url.Values returned should be a sanitized copy of the original query as not to modify the +// request while logging. +type QueryLogReplacer func(query url.Values) url.Values + +// DefaultQueryLogReplacer returns the query received without any modifications. +func DefaultQueryLogReplacer(query url.Values) url.Values { + return query +} + +var ( + DefaultLoggerConfig = LoggerConfig{ + Skipper: middleware.DefaultSkipper, + Level: zerolog.DebugLevel, + LogRequestBody: false, + LogRequestHeader: false, + LogRequestQuery: false, + RequestBodyLogSkipper: DefaultRequestBodyLogSkipper, + RequestBodyLogReplacer: DefaultBodyLogReplacer, + RequestHeaderLogReplacer: DefaultHeaderLogReplacer, + RequestQueryLogReplacer: DefaultQueryLogReplacer, + LogResponseBody: false, + LogResponseHeader: false, + ResponseBodyLogSkipper: DefaultResponseBodyLogSkipper, + ResponseBodyLogReplacer: DefaultBodyLogReplacer, + } +) + +type LoggerConfig struct { + Skipper middleware.Skipper + Level zerolog.Level + LogRequestBody bool + LogRequestHeader bool + LogRequestQuery bool + RequestBodyLogSkipper RequestBodyLogSkipper + RequestBodyLogReplacer BodyLogReplacer + RequestHeaderLogReplacer HeaderLogReplacer + RequestQueryLogReplacer QueryLogReplacer + LogResponseBody bool + LogResponseHeader bool + ResponseBodyLogSkipper ResponseBodyLogSkipper + ResponseBodyLogReplacer BodyLogReplacer + ResponseHeaderLogReplacer HeaderLogReplacer +} + +// Logger with default logger output and configuration +func Logger() echo.MiddlewareFunc { + return LoggerWithConfig(DefaultLoggerConfig, nil) +} + +// LoggerWithConfig returns a new MiddlewareFunc which creates a logger with the desired configuration. +// If output is set to nil, the default output is used. If more output params are provided, the first is being used. +func LoggerWithConfig(config LoggerConfig, output ...io.Writer) echo.MiddlewareFunc { + if config.Skipper == nil { + config.Skipper = DefaultLoggerConfig.Skipper + } + if config.RequestBodyLogSkipper == nil { + config.RequestBodyLogSkipper = DefaultRequestBodyLogSkipper + } + if config.RequestBodyLogReplacer == nil { + config.RequestBodyLogReplacer = DefaultBodyLogReplacer + } + if config.RequestHeaderLogReplacer == nil { + config.RequestHeaderLogReplacer = DefaultHeaderLogReplacer + } + if config.RequestQueryLogReplacer == nil { + config.RequestQueryLogReplacer = DefaultQueryLogReplacer + } + if config.ResponseBodyLogSkipper == nil { + config.ResponseBodyLogSkipper = DefaultResponseBodyLogSkipper + } + if config.ResponseBodyLogReplacer == nil { + config.ResponseBodyLogReplacer = DefaultBodyLogReplacer + } + if config.ResponseHeaderLogReplacer == nil { + config.ResponseHeaderLogReplacer = DefaultHeaderLogReplacer + } + + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if config.Skipper(c) { + return next(c) + } + + req := c.Request() + res := c.Response() + + id := req.Header.Get(echo.HeaderXRequestID) + if len(id) == 0 { + id = res.Header().Get(echo.HeaderXRequestID) + } + + in := req.Header.Get(echo.HeaderContentLength) + if len(in) == 0 { + in = "0" + } + + l := log.With(). + Dict("req", zerolog.Dict(). + Str("id", id). + Str("host", req.Host). + Str("method", req.Method). + Str("url", req.URL.String()). + Str("bytes_in", in), + ).Logger() + + if len(output) > 0 { + l = l.Output(output[0]) + } + + le := l.WithLevel(config.Level) + req = req.WithContext(l.WithContext(context.WithValue(req.Context(), util.CTXKeyRequestID, id))) + + if config.LogRequestBody && !config.RequestBodyLogSkipper(req) { + var reqBody []byte + var err error + if req.Body != nil { + reqBody, err = io.ReadAll(req.Body) + if err != nil { + l.Error().Err(err).Msg("Failed to read body while logging request") + return err + } + + req.Body = io.NopCloser(bytes.NewBuffer(reqBody)) + } + + le = le.Bytes("req_body", config.RequestBodyLogReplacer(reqBody)) + } + if config.LogRequestHeader { + header := zerolog.Dict() + for k, v := range config.RequestHeaderLogReplacer(req.Header) { + header.Strs(k, v) + } + + le = le.Dict("req_header", header) + } + if config.LogRequestQuery { + query := zerolog.Dict() + for k, v := range req.URL.Query() { + query.Strs(k, v) + } + + le = le.Dict("req_query", query) + } + + le.Msg("Request received") + + c.SetRequest(req) + + var resBody bytes.Buffer + if config.LogResponseBody { + mw := io.MultiWriter(res.Writer, &resBody) + writer := &bodyDumpResponseWriter{Writer: mw, ResponseWriter: res.Writer} + res.Writer = writer + } + + start := time.Now() + err := next(c) + if err != nil { + c.Error(err) + } + stop := time.Now() + + // Retrieve logger from context again since other middlewares might have enhanced it + ll := util.LogFromEchoContext(c) + lle := ll.WithLevel(config.Level). + Dict("res", zerolog.Dict(). + Int("status", res.Status). + Int64("bytes_out", res.Size). + TimeDiff("duration_ms", stop, start). + Err(err), + ) + + if config.LogResponseBody && !config.ResponseBodyLogSkipper(req, res) { + lle = lle.Bytes("res_body", config.ResponseBodyLogReplacer(resBody.Bytes())) + } + if config.LogResponseHeader { + header := zerolog.Dict() + for k, v := range config.ResponseHeaderLogReplacer(res.Header()) { + header.Strs(k, v) + } + + lle = lle.Dict("res_header", header) + } + + lle.Msg("Response sent") + + return nil + } + } +} + +type bodyDumpResponseWriter struct { + io.Writer + http.ResponseWriter +} + +func (w *bodyDumpResponseWriter) WriteHeader(code int) { + w.ResponseWriter.WriteHeader(code) +} + +func (w *bodyDumpResponseWriter) Write(b []byte) (int, error) { + return w.Writer.Write(b) +} + +func (w *bodyDumpResponseWriter) Flush() { + w.ResponseWriter.(http.Flusher).Flush() +} + +func (w *bodyDumpResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return w.ResponseWriter.(http.Hijacker).Hijack() +} diff --git a/internal/api/server.go b/internal/api/server.go index 23f6450..e4a4689 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -8,6 +8,7 @@ import ( "net" "time" + // #nosec G108 - pprof handlers (conditionally made available via http.DefaultServeMux within router) _ "net/http/pprof" "github.com/allaboutapps/integresql/pkg/manager" diff --git a/internal/api/server_config.go b/internal/api/server_config.go index 41194f9..316cace 100644 --- a/internal/api/server_config.go +++ b/internal/api/server_config.go @@ -1,17 +1,61 @@ package api -import "github.com/allaboutapps/integresql/pkg/util" +import ( + "github.com/allaboutapps/integresql/pkg/util" + "github.com/rs/zerolog" +) type ServerConfig struct { Address string Port int DebugEndpoints bool + Logger LoggerConfig + Echo EchoConfig +} + +type EchoConfig struct { + Debug bool + ListenAddress string + EnableCORSMiddleware bool + EnableLoggerMiddleware bool + EnableRecoverMiddleware bool + EnableRequestIDMiddleware bool + EnableTrailingSlashMiddleware bool +} + +type LoggerConfig struct { + Level zerolog.Level + RequestLevel zerolog.Level + LogRequestBody bool + LogRequestHeader bool + LogRequestQuery bool + LogResponseBody bool + LogResponseHeader bool + PrettyPrintConsole bool } func DefaultServerConfigFromEnv() ServerConfig { return ServerConfig{ Address: util.GetEnv("INTEGRESQL_ADDRESS", ""), Port: util.GetEnvAsInt("INTEGRESQL_PORT", 5000), - DebugEndpoints: util.GetEnvAsBool("INTEGRESQL_DEBUG_ENDPOINTS", true), + DebugEndpoints: util.GetEnvAsBool("INTEGRESQL_DEBUG_ENDPOINTS", true), // https://golang.org/pkg/net/http/pprof/ + Echo: EchoConfig{ + Debug: util.GetEnvAsBool("INTEGRESQL_ECHO_DEBUG", false), + EnableCORSMiddleware: util.GetEnvAsBool("INTEGRESQL_ECHO_ENABLE_CORS_MIDDLEWARE", true), + EnableLoggerMiddleware: util.GetEnvAsBool("INTEGRESQL_ECHO_ENABLE_LOGGER_MIDDLEWARE", true), + EnableRecoverMiddleware: util.GetEnvAsBool("INTEGRESQL_ECHO_ENABLE_RECOVER_MIDDLEWARE", true), + EnableRequestIDMiddleware: util.GetEnvAsBool("INTEGRESQL_ECHO_ENABLE_REQUEST_ID_MIDDLEWARE", true), + EnableTrailingSlashMiddleware: util.GetEnvAsBool("INTEGRESQL_ECHO_ENABLE_TRAILING_SLASH_MIDDLEWARE", true), + }, + Logger: LoggerConfig{ + Level: util.LogLevelFromString(util.GetEnv("INTEGRESQL_LOGGER_LEVEL", zerolog.InfoLevel.String())), + RequestLevel: util.LogLevelFromString(util.GetEnv("INTEGRESQL_LOGGER_REQUEST_LEVEL", zerolog.DebugLevel.String())), + LogRequestBody: util.GetEnvAsBool("INTEGRESQL_LOGGER_LOG_REQUEST_BODY", false), + LogRequestHeader: util.GetEnvAsBool("INTEGRESQL_LOGGER_LOG_REQUEST_HEADER", false), + LogRequestQuery: util.GetEnvAsBool("INTEGRESQL_LOGGER_LOG_REQUEST_QUERY", false), + LogResponseBody: util.GetEnvAsBool("INTEGRESQL_LOGGER_LOG_RESPONSE_BODY", false), + LogResponseHeader: util.GetEnvAsBool("INTEGRESQL_LOGGER_LOG_RESPONSE_HEADER", false), + PrettyPrintConsole: util.GetEnvAsBool("INTEGRESQL_LOGGER_PRETTY_PRINT_CONSOLE", false), + }, } } diff --git a/internal/api/templates/templates.go b/internal/api/templates/templates.go index f32446a..a81d1db 100644 --- a/internal/api/templates/templates.go +++ b/internal/api/templates/templates.go @@ -2,6 +2,7 @@ package templates import ( "context" + "errors" "net/http" "strconv" "time" @@ -33,14 +34,14 @@ func postInitializeTemplate(s *api.Server) echo.HandlerFunc { template, err := s.Manager.InitializeTemplateDatabase(ctx, payload.Hash) if err != nil { - switch err { - case manager.ErrManagerNotReady: + if errors.Is(err, manager.ErrManagerNotReady) { return echo.ErrServiceUnavailable - case manager.ErrTemplateAlreadyInitialized: + } else if errors.Is(err, manager.ErrTemplateAlreadyInitialized) { return echo.NewHTTPError(http.StatusLocked, "template is already initialized") - default: - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } + + // default 500 + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } return c.JSON(http.StatusOK, &template) @@ -55,17 +56,17 @@ func putFinalizeTemplate(s *api.Server) echo.HandlerFunc { defer cancel() if _, err := s.Manager.FinalizeTemplateDatabase(ctx, hash); err != nil { - switch err { - case manager.ErrTemplateAlreadyInitialized: + if errors.Is(err, manager.ErrTemplateAlreadyInitialized) { // template is initialized, we ignore this error return c.NoContent(http.StatusNoContent) - case manager.ErrManagerNotReady: + } else if errors.Is(err, manager.ErrManagerNotReady) { return echo.ErrServiceUnavailable - case manager.ErrTemplateNotFound: + } else if errors.Is(err, manager.ErrTemplateNotFound) { return echo.NewHTTPError(http.StatusNotFound, "template not found") - default: - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } + + // default 500 + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } return c.NoContent(http.StatusNoContent) @@ -80,14 +81,14 @@ func deleteDiscardTemplate(s *api.Server) echo.HandlerFunc { defer cancel() if err := s.Manager.DiscardTemplateDatabase(ctx, hash); err != nil { - switch err { - case manager.ErrManagerNotReady: + if errors.Is(err, manager.ErrManagerNotReady) { return echo.ErrServiceUnavailable - case manager.ErrTemplateNotFound: + } else if errors.Is(err, manager.ErrTemplateNotFound) { return echo.NewHTTPError(http.StatusNotFound, "template not found") - default: - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } + + // default 500 + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } return c.NoContent(http.StatusNoContent) @@ -103,24 +104,24 @@ func getTestDatabase(s *api.Server) echo.HandlerFunc { test, err := s.Manager.GetTestDatabase(ctx, hash) if err != nil { - switch err { - case manager.ErrManagerNotReady: + + if errors.Is(err, manager.ErrManagerNotReady) { return echo.ErrServiceUnavailable - case manager.ErrTemplateNotFound: + } else if errors.Is(err, manager.ErrTemplateNotFound) { return echo.NewHTTPError(http.StatusNotFound, "template not found") - case manager.ErrTemplateDiscarded: + } else if errors.Is(err, manager.ErrTemplateDiscarded) { return echo.NewHTTPError(http.StatusGone, "template was just discarded") - case pool.ErrPoolFull: - return echo.NewHTTPError(http.StatusInsufficientStorage, "pool is full and can't be extended") - default: - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } + + // default 500 + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } return c.JSON(http.StatusOK, &test) } } +// deprecated func deleteReturnTestDatabase(s *api.Server) echo.HandlerFunc { return postUnlockTestDatabase(s) } @@ -137,18 +138,18 @@ func postUnlockTestDatabase(s *api.Server) echo.HandlerFunc { defer cancel() if err := s.Manager.ReturnTestDatabase(ctx, hash, id); err != nil { - switch err { - case manager.ErrManagerNotReady: + if errors.Is(err, manager.ErrManagerNotReady) { return echo.ErrServiceUnavailable - case manager.ErrTemplateNotFound: + } else if errors.Is(err, manager.ErrTemplateNotFound) { return echo.NewHTTPError(http.StatusNotFound, "template not found") - case manager.ErrTestNotFound: + } else if errors.Is(err, manager.ErrTestNotFound) { return echo.NewHTTPError(http.StatusNotFound, "test database not found") - case pool.ErrTestDBInUse: + } else if errors.Is(err, pool.ErrTestDBInUse) { return echo.NewHTTPError(http.StatusLocked, pool.ErrTestDBInUse.Error()) - default: - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } + + // default 500 + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } return c.NoContent(http.StatusNoContent) @@ -164,18 +165,19 @@ func postRecreateTestDatabase(s *api.Server) echo.HandlerFunc { } if err := s.Manager.RecreateTestDatabase(c.Request().Context(), hash, id); err != nil { - switch err { - case manager.ErrManagerNotReady: + + if errors.Is(err, manager.ErrManagerNotReady) { return echo.ErrServiceUnavailable - case manager.ErrTemplateNotFound: + } else if errors.Is(err, manager.ErrTemplateNotFound) { return echo.NewHTTPError(http.StatusNotFound, "template not found") - case manager.ErrTestNotFound: + } else if errors.Is(err, manager.ErrTestNotFound) { return echo.NewHTTPError(http.StatusNotFound, "test database not found") - case pool.ErrTestDBInUse: + } else if errors.Is(err, pool.ErrTestDBInUse) { return echo.NewHTTPError(http.StatusLocked, pool.ErrTestDBInUse.Error()) - default: - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } + + // default 500 + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } return c.NoContent(http.StatusNoContent) diff --git a/internal/config/build_args.go b/internal/config/build_args.go new file mode 100644 index 0000000..7972bc2 --- /dev/null +++ b/internal/config/build_args.go @@ -0,0 +1,18 @@ +package config + +import "fmt" + +// The following vars are automatically injected via -ldflags. +// See Makefile target "make go-build" and make var $(LDFLAGS). +// No need to change them here. +// https://www.digitalocean.com/community/tutorials/using-ldflags-to-set-version-information-for-go-applications +var ( + ModuleName = "build.local/misses/ldflags" // e.g. "allaboutapps.dev/aw/go-starter" + Commit = "< 40 chars git commit hash via ldflags >" // e.g. "59cb7684dd0b0f38d68cd7db657cb614feba8f7e" + BuildDate = "1970-01-01T00:00:00+00:00" // e.g. "1970-01-01T00:00:00+00:00" +) + +// GetFormattedBuildArgs returns string representation of buildsargs set via ldflags " @ ()" +func GetFormattedBuildArgs() string { + return fmt.Sprintf("%v @ %v (%v)", ModuleName, Commit, BuildDate) +} diff --git a/internal/router/echo_logger.go b/internal/router/echo_logger.go new file mode 100644 index 0000000..0931855 --- /dev/null +++ b/internal/router/echo_logger.go @@ -0,0 +1,13 @@ +package router + +import "github.com/rs/zerolog" + +type echoLogger struct { + level zerolog.Level + log zerolog.Logger +} + +func (l *echoLogger) Write(p []byte) (n int, err error) { + l.log.WithLevel(l.level).Msgf("%s", p) + return len(p), nil +} diff --git a/internal/router/router.go b/internal/router/router.go index 35d5505..4fac61d 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -5,28 +5,67 @@ import ( "github.com/allaboutapps/integresql/internal/api" "github.com/allaboutapps/integresql/internal/api/admin" + "github.com/allaboutapps/integresql/internal/api/middleware" "github.com/allaboutapps/integresql/internal/api/templates" "github.com/labstack/echo/v4" echoMiddleware "github.com/labstack/echo/v4/middleware" + "github.com/rs/zerolog/log" ) func Init(s *api.Server) { s.Echo = echo.New() - s.Echo.Debug = false + s.Echo.Debug = s.Config.Echo.Debug s.Echo.HideBanner = true + s.Echo.Logger.SetOutput(&echoLogger{level: s.Config.Logger.RequestLevel, log: log.With().Str("component", "echo").Logger()}) - s.Echo.Pre(echoMiddleware.RemoveTrailingSlash()) + // --- + // General middleware + if s.Config.Echo.EnableTrailingSlashMiddleware { + s.Echo.Pre(echoMiddleware.RemoveTrailingSlash()) + } else { + log.Warn().Msg("Disabling trailing slash middleware due to environment config") + } - s.Echo.Use(echoMiddleware.Recover()) - s.Echo.Use(echoMiddleware.RequestID()) - s.Echo.Use(echoMiddleware.Logger()) + if s.Config.Echo.EnableRecoverMiddleware { + s.Echo.Use(echoMiddleware.Recover()) + } else { + log.Warn().Msg("Disabling recover middleware due to environment config") + } - admin.InitRoutes(s) - templates.InitRoutes(s) + if s.Config.Echo.EnableRequestIDMiddleware { + s.Echo.Use(echoMiddleware.RequestID()) + } else { + log.Warn().Msg("Disabling request ID middleware due to environment config") + } + + if s.Config.Echo.EnableLoggerMiddleware { + s.Echo.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{ + Level: s.Config.Logger.RequestLevel, + LogRequestBody: s.Config.Logger.LogRequestBody, + LogRequestHeader: s.Config.Logger.LogRequestHeader, + LogRequestQuery: s.Config.Logger.LogRequestQuery, + LogResponseBody: s.Config.Logger.LogResponseBody, + LogResponseHeader: s.Config.Logger.LogResponseHeader, + RequestBodyLogSkipper: func(req *http.Request) bool { + return middleware.DefaultRequestBodyLogSkipper(req) + }, + ResponseBodyLogSkipper: func(req *http.Request, res *echo.Response) bool { + return middleware.DefaultResponseBodyLogSkipper(req, res) + }, + Skipper: func(c echo.Context) bool { + return false + }, + })) + } else { + log.Warn().Msg("Disabling logger middleware due to environment config") + } // enable debug endpoints only if requested if s.Config.DebugEndpoints { s.Echo.GET("/debug/*", echo.WrapHandler(http.DefaultServeMux)) } + + admin.InitRoutes(s) + templates.InitRoutes(s) } diff --git a/internal/router/router_test.go b/internal/router/router_test.go new file mode 100644 index 0000000..c393096 --- /dev/null +++ b/internal/router/router_test.go @@ -0,0 +1,35 @@ +package router_test + +import ( + "testing" + + "github.com/allaboutapps/integresql/internal/api" + "github.com/allaboutapps/integresql/internal/test" + "github.com/stretchr/testify/require" +) + +func TestPprofEnabledNoAuth(t *testing.T) { + config := api.DefaultServerConfigFromEnv() + + // these are typically our default values, however we force set them here to ensure those are set while test execution. + config.DebugEndpoints = true + + test.WithTestServerConfigurable(t, config, func(s *api.Server) { + res := test.PerformRequest(t, s, "GET", "/debug/pprof/heap/", nil, nil) + require.Equal(t, 200, res.Result().StatusCode) + + // index + res = test.PerformRequest(t, s, "GET", "/debug/pprof/", nil, nil) + require.Equal(t, 301, res.Result().StatusCode) + }) +} + +func TestPprofDisabled(t *testing.T) { + config := api.DefaultServerConfigFromEnv() + config.DebugEndpoints = false + + test.WithTestServerConfigurable(t, config, func(s *api.Server) { + res := test.PerformRequest(t, s, "GET", "/debug/pprof/heap", nil, nil) + require.Equal(t, 404, res.Result().StatusCode) + }) +} diff --git a/internal/test/helper_request.go b/internal/test/helper_request.go new file mode 100644 index 0000000..7693358 --- /dev/null +++ b/internal/test/helper_request.go @@ -0,0 +1,132 @@ +package test + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/allaboutapps/integresql/internal/api" + "github.com/labstack/echo/v4" +) + +type GenericPayload map[string]interface{} +type GenericArrayPayload []interface{} + +func (g GenericPayload) Reader(t *testing.T) *bytes.Reader { + t.Helper() + + b, err := json.Marshal(g) + if err != nil { + t.Fatalf("failed to serialize payload: %v", err) + } + + return bytes.NewReader(b) +} + +func (g GenericArrayPayload) Reader(t *testing.T) *bytes.Reader { + t.Helper() + + b, err := json.Marshal(g) + if err != nil { + t.Fatalf("failed to serialize payload: %v", err) + } + + return bytes.NewReader(b) +} + +func PerformRequestWithParams(t *testing.T, s *api.Server, method string, path string, body GenericPayload, headers http.Header, queryParams map[string]string) *httptest.ResponseRecorder { + t.Helper() + + if body == nil { + return PerformRequestWithRawBody(t, s, method, path, nil, headers, queryParams) + } + + return PerformRequestWithRawBody(t, s, method, path, body.Reader(t), headers, queryParams) +} + +func PerformRequestWithArrayAndParams(t *testing.T, s *api.Server, method string, path string, body GenericArrayPayload, headers http.Header, queryParams map[string]string) *httptest.ResponseRecorder { + t.Helper() + + if body == nil { + return PerformRequestWithRawBody(t, s, method, path, nil, headers, queryParams) + } + + return PerformRequestWithRawBody(t, s, method, path, body.Reader(t), headers, queryParams) +} + +func PerformRequestWithRawBody(t *testing.T, s *api.Server, method string, path string, body io.Reader, headers http.Header, queryParams map[string]string) *httptest.ResponseRecorder { + t.Helper() + + req := httptest.NewRequest(method, path, body) + + if headers != nil { + req.Header = headers + } + if body != nil && len(req.Header.Get(echo.HeaderContentType)) == 0 { + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + } + + if queryParams != nil { + q := req.URL.Query() + for k, v := range queryParams { + q.Add(k, v) + } + + req.URL.RawQuery = q.Encode() + } + + res := httptest.NewRecorder() + + s.Echo.ServeHTTP(res, req) + + return res +} + +func PerformRequest(t *testing.T, s *api.Server, method string, path string, body GenericPayload, headers http.Header) *httptest.ResponseRecorder { + t.Helper() + + return PerformRequestWithParams(t, s, method, path, body, headers, nil) +} + +func PerformRequestWithArray(t *testing.T, s *api.Server, method string, path string, body GenericArrayPayload, headers http.Header) *httptest.ResponseRecorder { + t.Helper() + + return PerformRequestWithArrayAndParams(t, s, method, path, body, headers, nil) +} + +func ParseResponseBody(t *testing.T, res *httptest.ResponseRecorder, v interface{}) { + t.Helper() + + if err := json.NewDecoder(res.Result().Body).Decode(&v); err != nil { + t.Fatalf("Failed to parse response body: %v", err) + } +} + +// func ParseResponseAndValidate(t *testing.T, res *httptest.ResponseRecorder, v runtime.Validatable) { +// t.Helper() + +// ParseResponseBody(t, res, &v) + +// if err := v.Validate(strfmt.Default); err != nil { +// t.Fatalf("Failed to validate response: %v", err) +// } +// } + +func HeadersWithAuth(t *testing.T, token string) http.Header { + t.Helper() + + return HeadersWithConfigurableAuth(t, "Bearer", token) +} + +func HeadersWithConfigurableAuth(t *testing.T, scheme string, token string) http.Header { + t.Helper() + + headers := http.Header{} + headers.Set(echo.HeaderAuthorization, fmt.Sprintf("%s %s", scheme, token)) + + return headers +} diff --git a/internal/test/test_server.go b/internal/test/test_server.go new file mode 100644 index 0000000..45c108c --- /dev/null +++ b/internal/test/test_server.go @@ -0,0 +1,62 @@ +package test + +import ( + "context" + "testing" + + "github.com/allaboutapps/integresql/internal/api" + "github.com/allaboutapps/integresql/internal/router" +) + +// WithTestServer returns a fully configured server (using the default server config). +func WithTestServer(t *testing.T, closure func(s *api.Server)) { + t.Helper() + defaultConfig := api.DefaultServerConfigFromEnv() + WithTestServerConfigurable(t, defaultConfig, closure) +} + +// WithTestServerConfigurable returns a fully configured server, allowing for configuration using the provided server config. +func WithTestServerConfigurable(t *testing.T, config api.ServerConfig, closure func(s *api.Server)) { + t.Helper() + ctx := context.Background() + WithTestServerConfigurableContext(ctx, t, config, closure) +} + +// WithTestServerConfigurableContext returns a fully configured server, allowing for configuration using the provided server config. +// The provided context will be used during setup (instead of the default background context). +func WithTestServerConfigurableContext(ctx context.Context, t *testing.T, config api.ServerConfig, closure func(s *api.Server)) { + t.Helper() + execClosureNewTestServer(ctx, t, config, closure) + +} + +// Executes closure on a new test server +func execClosureNewTestServer(ctx context.Context, t *testing.T, config api.ServerConfig, closure func(s *api.Server)) { + t.Helper() + + // https://stackoverflow.com/questions/43424787/how-to-use-next-available-port-in-http-listenandserve + // You may use port 0 to indicate you're not specifying an exact port but you want a free, available port selected by the system + config.Address = ":0" + + s := api.NewServer(config) + + if err := s.InitManager(ctx); err != nil { + t.Fatalf("failed to start manager: %v", err) + } + + router.Init(s) + + closure(s) + + // echo is managed and should close automatically after running the test + if err := s.Echo.Shutdown(ctx); err != nil { + t.Fatalf("failed to shutdown server: %v", err) + } + + if err := s.Manager.Disconnect(ctx, true); err != nil { + t.Fatalf("failed to shutdown manager: %v", err) + } + + // disallow any further refs to managed object after running the test + s = nil +} diff --git a/pkg/db/database_config_test.go b/pkg/db/database_config_internal_test.go similarity index 100% rename from pkg/db/database_config_test.go rename to pkg/db/database_config_internal_test.go diff --git a/pkg/manager/helpers_test.go b/pkg/manager/helpers_test.go index 27c5048..00baa06 100644 --- a/pkg/manager/helpers_test.go +++ b/pkg/manager/helpers_test.go @@ -51,7 +51,7 @@ func disconnectManager(t *testing.T, m *manager.Manager) { } -func initTemplateDB(ctx context.Context, errs chan<- error, m *manager.Manager) { +func initTemplateDB(_ context.Context, errs chan<- error, m *manager.Manager) { template, err := m.InitializeTemplateDatabase(context.Background(), "hashinghash") if err != nil { @@ -159,7 +159,7 @@ func verifyTestDB(t *testing.T, test db.TestDatabase) { } } -func getTestDB(ctx context.Context, errs chan<- error, m *manager.Manager) { +func getTestDB(_ context.Context, errs chan<- error, m *manager.Manager) { _, err := m.GetTestDatabase(context.Background(), "hashinghash") errs <- err diff --git a/pkg/manager/manager.go b/pkg/manager/manager.go index 72746a0..67f13f0 100644 --- a/pkg/manager/manager.go +++ b/pkg/manager/manager.go @@ -3,6 +3,7 @@ package manager import ( "context" "database/sql" + "encoding/json" "errors" "fmt" "runtime/trace" @@ -11,7 +12,10 @@ import ( "github.com/allaboutapps/integresql/pkg/db" "github.com/allaboutapps/integresql/pkg/pool" "github.com/allaboutapps/integresql/pkg/templates" + "github.com/allaboutapps/integresql/pkg/util" "github.com/lib/pq" + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" ) var ( @@ -37,10 +41,12 @@ func New(config ManagerConfig) (*Manager, ManagerConfig) { if config.DatabasePrefix != "" { testDBPrefix = testDBPrefix + fmt.Sprintf("%s_", config.DatabasePrefix) } - if config.TestDatabasePrefix != "" { - testDBPrefix = testDBPrefix + fmt.Sprintf("%s_", config.TestDatabasePrefix) + if config.PoolConfig.TestDBNamePrefix != "" { + testDBPrefix = testDBPrefix + fmt.Sprintf("%s_", config.PoolConfig.TestDBNamePrefix) } + config.PoolConfig.TestDBNamePrefix = testDBPrefix + if len(config.TestDatabaseOwner) == 0 { config.TestDatabaseOwner = config.ManagerDatabaseConfig.Username } @@ -50,30 +56,32 @@ func New(config ManagerConfig) (*Manager, ManagerConfig) { } // at least one test database needs to be present initially - if config.TestDatabaseInitialPoolSize == 0 { - config.TestDatabaseInitialPoolSize = 1 + if config.PoolConfig.InitialPoolSize == 0 { + config.PoolConfig.InitialPoolSize = 1 + } + + if config.PoolConfig.InitialPoolSize > config.PoolConfig.MaxPoolSize && config.PoolConfig.MaxPoolSize > 0 { + config.PoolConfig.InitialPoolSize = config.PoolConfig.MaxPoolSize } - if config.TestDatabaseInitialPoolSize > config.TestDatabaseMaxPoolSize && config.TestDatabaseMaxPoolSize > 0 { - config.TestDatabaseInitialPoolSize = config.TestDatabaseMaxPoolSize + if config.PoolConfig.MaxParallelTasks < 1 { + config.PoolConfig.MaxParallelTasks = 1 } - if config.PoolMaxParallelTasks < 1 { - config.PoolMaxParallelTasks = 1 + // debug log final derived config + c, err := json.Marshal(config) + + if err != nil { + log.Fatal().Err(err).Msg("Failed to marshal the env") } + log.Debug().RawJSON("config", c).Msg("manager.New") + m := &Manager{ config: config, db: nil, templates: templates.NewCollection(), - pool: pool.NewPoolCollection( - pool.PoolConfig{ - MaxPoolSize: config.TestDatabaseMaxPoolSize, - InitialPoolSize: config.TestDatabaseInitialPoolSize, - TestDBNamePrefix: testDBPrefix, - PoolMaxParallelTasks: config.PoolMaxParallelTasks, - }, - ), + pool: pool.NewPoolCollection(config.PoolConfig), } return m, m.config @@ -85,38 +93,55 @@ func DefaultFromEnv() *Manager { } func (m *Manager) Connect(ctx context.Context) error { + + log := m.getManagerLogger(ctx, "Connect") + if m.db != nil { - return errors.New("manager is already connected") + err := errors.New("manager is already connected") + log.Error().Err(err) + return err } db, err := sql.Open("postgres", m.config.ManagerDatabaseConfig.ConnectionString()) if err != nil { + log.Error().Err(err).Msg("unable to connect") return err } if err := db.PingContext(ctx); err != nil { + log.Error().Err(err).Msg("unable to ping") return err } m.db = db + log.Debug().Msg("connected.") + return nil } func (m *Manager) Disconnect(ctx context.Context, ignoreCloseError bool) error { + + log := m.getManagerLogger(ctx, "Disconnect").With().Bool("ignoreCloseError", ignoreCloseError).Logger() + if m.db == nil { - return errors.New("manager is not connected") + err := errors.New("manager is not connected") + log.Error().Err(err) + return err } // stop the pool before closing DB connection m.pool.Stop() if err := m.db.Close(); err != nil && !ignoreCloseError { + log.Error().Err(err) return err } m.db = nil + log.Warn().Msg("disconnected.") + return nil } @@ -133,37 +158,53 @@ func (m Manager) Ready() bool { } func (m *Manager) Initialize(ctx context.Context) error { + + log := m.getManagerLogger(ctx, "Initialize") + if !m.Ready() { if err := m.Connect(ctx); err != nil { + log.Error().Err(err) return err } } - rows, err := m.db.QueryContext(ctx, "SELECT datname FROM pg_database WHERE datname LIKE $1", fmt.Sprintf("%s_%s_%%", m.config.DatabasePrefix, m.config.TestDatabasePrefix)) + rows, err := m.db.QueryContext(ctx, "SELECT datname FROM pg_database WHERE datname LIKE $1", fmt.Sprintf("%s_%s_%%", m.config.DatabasePrefix, m.config.PoolConfig.TestDBNamePrefix)) if err != nil { + log.Error().Err(err) return err } defer rows.Close() + log.Debug().Msg("Dropping unmanaged dbs...") + for rows.Next() { var dbName string if err := rows.Scan(&dbName); err != nil { return err } + log.Warn().Str("dbName", dbName).Msg("Dropping...") + if _, err := m.db.Exec(fmt.Sprintf("DROP DATABASE %s", pq.QuoteIdentifier(dbName))); err != nil { + log.Error().Str("dbName", dbName).Err(err) return err } } + log.Info().Msg("initialized.") + return nil } func (m Manager) InitializeTemplateDatabase(ctx context.Context, hash string) (db.TemplateDatabase, error) { ctx, task := trace.NewTask(ctx, "initialize_template_db") + + log := m.getManagerLogger(ctx, "InitializeTemplateDatabase").With().Str("hash", hash).Logger() + defer task.End() if !m.Ready() { + log.Error().Msg("not ready") return db.TemplateDatabase{}, ErrManagerNotReady } @@ -188,6 +229,8 @@ func (m Manager) InitializeTemplateDatabase(ctx context.Context, hash string) (d reg := trace.StartRegion(ctx, "drop_and_create_db") if err := m.dropAndCreateDatabase(ctx, dbName, m.config.ManagerDatabaseConfig.Username, m.config.TemplateDatabaseTemplate); err != nil { + + log.Error().Err(err).Msg("triggering unsafe remove after dropAndCreateDatabase failed...") m.templates.RemoveUnsafe(ctx, hash) return db.TemplateDatabase{}, err @@ -197,6 +240,8 @@ func (m Manager) InitializeTemplateDatabase(ctx context.Context, hash string) (d // if template config has been overwritten, the existing pool needs to be removed err := m.pool.RemoveAllWithHash(ctx, hash, m.dropTestPoolDB) if err != nil && !errors.Is(err, pool.ErrUnknownHash) { + + log.Error().Err(err).Msg("triggering unsafe remove after RemoveAllWithHash failed...") m.templates.RemoveUnsafe(ctx, hash) return db.TemplateDatabase{}, err @@ -213,14 +258,18 @@ func (m Manager) InitializeTemplateDatabase(ctx context.Context, hash string) (d func (m Manager) DiscardTemplateDatabase(ctx context.Context, hash string) error { ctx, task := trace.NewTask(ctx, "discard_template_db") + log := m.getManagerLogger(ctx, "DiscardTemplateDatabase").With().Str("hash", hash).Logger() + defer task.End() if !m.Ready() { + log.Error().Msg("not ready") return ErrManagerNotReady } // first remove all DB with this hash if err := m.pool.RemoveAllWithHash(ctx, hash, m.dropTestPoolDB); err != nil && !errors.Is(err, pool.ErrUnknownHash) { + log.Error().Err(err).Msg("remove all err") return err } @@ -229,6 +278,9 @@ func (m Manager) DiscardTemplateDatabase(ctx context.Context, hash string) error if !found { // even if a template is not found in the collection, it might still exist in the DB + + log.Warn().Msg("template not found, checking for existance...") + dbName = m.makeTemplateDatabaseName(hash) exists, err := m.checkDatabaseExists(ctx, dbName) if err != nil { @@ -242,19 +294,26 @@ func (m Manager) DiscardTemplateDatabase(ctx context.Context, hash string) error template.SetState(ctx, templates.TemplateStateDiscarded) } + log.Debug().Msg("found template database, dropping...") + return m.dropDatabase(ctx, dbName) } func (m Manager) FinalizeTemplateDatabase(ctx context.Context, hash string) (db.TemplateDatabase, error) { ctx, task := trace.NewTask(ctx, "finalize_template_db") + + log := m.getManagerLogger(ctx, "FinalizeTemplateDatabase").With().Str("hash", hash).Logger() + defer task.End() if !m.Ready() { + log.Error().Msg("not ready") return db.TemplateDatabase{}, ErrManagerNotReady } template, found := m.templates.Get(ctx, hash) if !found { + log.Error().Msg("bailout: template not found") return db.TemplateDatabase{}, ErrTemplateNotFound } @@ -263,29 +322,36 @@ func (m Manager) FinalizeTemplateDatabase(ctx context.Context, hash string) (db. // early bailout if we are already ready (multiple calls) if state == templates.TemplateStateFinalized { + log.Warn().Msg("bailout: template already finalized") return db.TemplateDatabase{Database: template.Database}, ErrTemplateAlreadyInitialized } // Disallow transition from discarded to ready if state == templates.TemplateStateDiscarded { + log.Error().Msg("bailout: template discarded!") return db.TemplateDatabase{}, ErrTemplateDiscarded } // Init a pool with this hash + log.Trace().Msg("init hash pool...") m.pool.InitHashPool(ctx, template.Database, m.recreateTestPoolDB) lockedTemplate.SetState(ctx, templates.TemplateStateFinalized) + log.Debug().Msg("Template database finalized successfully.") return db.TemplateDatabase{Database: template.Database}, nil } // GetTestDatabase tries to get a ready test DB from an existing pool. -// If no DB is ready after the preconfigured timeout, ErrTimeout is returned. func (m Manager) GetTestDatabase(ctx context.Context, hash string) (db.TestDatabase, error) { ctx, task := trace.NewTask(ctx, "get_test_db") + + log := m.getManagerLogger(ctx, "GetTestDatabase").With().Str("hash", hash).Logger() + defer task.End() if !m.Ready() { + log.Error().Msg("not ready") return db.TestDatabase{}, ErrManagerNotReady } @@ -308,6 +374,7 @@ func (m Manager) GetTestDatabase(ctx context.Context, hash string) (db.TestDatab // Template exists, but the pool is not there - // it must have been removed. // It needs to be reinitialized. + log.Warn().Err(err).Msg("ErrUnknownHash, going to InitHashPool and recursively calling us again...") m.pool.InitHashPool(ctx, template.Database, m.recreateTestPoolDB) testDB, err = m.pool.GetTestDatabase(ctx, template.TemplateHash, m.config.TestDatabaseGetTimeout) @@ -332,7 +399,7 @@ func (m Manager) ReturnTestDatabase(ctx context.Context, hash string, id int) er // check if the template exists and is finalized template, found := m.templates.Get(ctx, hash) if !found { - return m.dropDatabaseWithID(ctx, hash, id) + return ErrTemplateNotFound } if template.WaitUntilFinalized(ctx, m.config.TemplateFinalizeTimeout) != @@ -342,19 +409,7 @@ func (m Manager) ReturnTestDatabase(ctx context.Context, hash string, id int) er } // template is ready, we can return unchanged testDB to the pool - if err := m.pool.ReturnTestDatabase(ctx, hash, id); err != nil { - if !(errors.Is(err, pool.ErrInvalidIndex) || - errors.Is(err, pool.ErrUnknownHash)) { - // other error is an internal error - return err - } - - // db is not tracked in the pool - // try to drop it if exists - return m.dropDatabaseWithID(ctx, hash, id) - } - - return nil + return m.pool.ReturnTestDatabase(ctx, hash, id) } // RecreateTestDatabase recreates the test DB according to the template and returns it back to the pool. @@ -369,7 +424,7 @@ func (m *Manager) RecreateTestDatabase(ctx context.Context, hash string, id int) // check if the template exists and is finalized template, found := m.templates.Get(ctx, hash) if !found { - return m.dropDatabaseWithID(ctx, hash, id) + return ErrTemplateNotFound } if template.WaitUntilFinalized(ctx, m.config.TemplateFinalizeTimeout) != @@ -377,27 +432,21 @@ func (m *Manager) RecreateTestDatabase(ctx context.Context, hash string, id int) return ErrInvalidTemplateState } - // template is ready, we can returb the testDB to the pool and have it cleaned up - if err := m.pool.RecreateTestDatabase(ctx, hash, id); err != nil { - if !(errors.Is(err, pool.ErrInvalidIndex) || - errors.Is(err, pool.ErrUnknownHash)) { - // other error is an internal error - return err - } - - // db is not tracked in the pool - // try to drop it if exists - return m.dropDatabaseWithID(ctx, hash, id) - } - - return nil + // template is ready, we can return the testDB to the pool and have it cleaned up + return m.pool.RecreateTestDatabase(ctx, hash, id) } func (m Manager) ClearTrackedTestDatabases(ctx context.Context, hash string) error { + + log := m.getManagerLogger(ctx, "ClearTrackedTestDatabases").With().Str("hash", hash).Logger() + if !m.Ready() { + log.Error().Msg("not ready") return ErrManagerNotReady } + log.Warn().Msg("clearing...") + err := m.pool.RemoveAllWithHash(ctx, hash, m.dropTestPoolDB) if errors.Is(err, pool.ErrUnknownHash) { return ErrTemplateNotFound @@ -407,38 +456,27 @@ func (m Manager) ClearTrackedTestDatabases(ctx context.Context, hash string) err } func (m Manager) ResetAllTracking(ctx context.Context) error { + + log := m.getManagerLogger(ctx, "ResetAllTracking") + if !m.Ready() { + log.Error().Msg("not ready") return ErrManagerNotReady } + log.Warn().Msg("resetting...") + // remove all templates to disallow any new test DB creation from existing templates m.templates.RemoveAll(ctx) - if err := m.pool.RemoveAll(ctx, m.dropTestPoolDB); err != nil { - return err - } - - return nil -} - -func (m Manager) dropDatabaseWithID(ctx context.Context, hash string, id int) error { - dbName := m.pool.MakeDBName(hash, id) - exists, err := m.checkDatabaseExists(ctx, dbName) - if err != nil { - return err - } - - if !exists { - return ErrTestNotFound - } - - return m.dropDatabase(ctx, dbName) + return m.pool.RemoveAll(ctx, m.dropTestPoolDB) } func (m Manager) checkDatabaseExists(ctx context.Context, dbName string) (bool, error) { var exists bool - // fmt.Printf("SELECT 1 AS exists FROM pg_database WHERE datname = %s\n", dbName) + log := m.getManagerLogger(ctx, "checkDatabaseExists") + log.Trace().Msgf("SELECT 1 AS exists FROM pg_database WHERE datname = %s\n", dbName) if err := m.db.QueryRowContext(ctx, "SELECT 1 AS exists FROM pg_database WHERE datname = $1", dbName).Scan(&exists); err != nil { if err == sql.ErrNoRows { @@ -474,7 +512,8 @@ func (m Manager) createDatabase(ctx context.Context, dbName string, owner string defer trace.StartRegion(ctx, "create_db").End() - // fmt.Printf("CREATE DATABASE %s WITH OWNER %s TEMPLATE %s\n", pq.QuoteIdentifier(dbName), pq.QuoteIdentifier(owner), pq.QuoteIdentifier(template)) + log := m.getManagerLogger(ctx, "createDatabase") + log.Trace().Msgf("CREATE DATABASE %s WITH OWNER %s TEMPLATE %s\n", pq.QuoteIdentifier(dbName), pq.QuoteIdentifier(owner), pq.QuoteIdentifier(template)) if _, err := m.db.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s WITH OWNER %s TEMPLATE %s", pq.QuoteIdentifier(dbName), pq.QuoteIdentifier(owner), pq.QuoteIdentifier(template))); err != nil { return err @@ -506,7 +545,8 @@ func (m Manager) dropDatabase(ctx context.Context, dbName string) error { defer trace.StartRegion(ctx, "drop_db").End() - // fmt.Printf("DROP DATABASE IF EXISTS %s\n", pq.QuoteIdentifier(dbName)) + log := m.getManagerLogger(ctx, "dropDatabase") + log.Trace().Msgf("DROP DATABASE IF EXISTS %s\n", pq.QuoteIdentifier(dbName)) if _, err := m.db.ExecContext(ctx, fmt.Sprintf("DROP DATABASE IF EXISTS %s", pq.QuoteIdentifier(dbName))); err != nil { if strings.Contains(err.Error(), "is being accessed by other users") { @@ -534,3 +574,7 @@ func (m Manager) dropAndCreateDatabase(ctx context.Context, dbName string, owner func (m Manager) makeTemplateDatabaseName(hash string) string { return fmt.Sprintf("%s_%s_%s", m.config.DatabasePrefix, m.config.TemplateDatabasePrefix, hash) } + +func (m Manager) getManagerLogger(ctx context.Context, managerFunction string) zerolog.Logger { + return util.LogFromContext(ctx).With().Str("managerFn", managerFunction).Logger() +} diff --git a/pkg/manager/manager_config.go b/pkg/manager/manager_config.go index 3269375..d59f000 100644 --- a/pkg/manager/manager_config.go +++ b/pkg/manager/manager_config.go @@ -5,23 +5,23 @@ import ( "time" "github.com/allaboutapps/integresql/pkg/db" + "github.com/allaboutapps/integresql/pkg/pool" "github.com/allaboutapps/integresql/pkg/util" ) -type ManagerConfig struct { - ManagerDatabaseConfig db.DatabaseConfig +// we explicitly want to access this struct via manager.ManagerConfig, thus we disable revive for the next line +type ManagerConfig struct { //nolint:revive + ManagerDatabaseConfig db.DatabaseConfig `json:"-"` // sensitive TemplateDatabaseTemplate string - DatabasePrefix string - TemplateDatabasePrefix string - TestDatabasePrefix string - TestDatabaseOwner string - TestDatabaseOwnerPassword string - TestDatabaseInitialPoolSize int // Initial number of ready DBs prepared in background - TestDatabaseMaxPoolSize int // Maximal pool size that won't be exceeded - TemplateFinalizeTimeout time.Duration // Time to wait for a template to transition into the 'finalized' state - TestDatabaseGetTimeout time.Duration // Time to wait for a ready database - PoolMaxParallelTasks int // Maximal number of pool tasks running in parallel. Must be a number greater or equal 1. + DatabasePrefix string + TemplateDatabasePrefix string + TestDatabaseOwner string + TestDatabaseOwnerPassword string `json:"-"` // sensitive + TemplateFinalizeTimeout time.Duration // Time to wait for a template to transition into the 'finalized' state + TestDatabaseGetTimeout time.Duration // Time to wait for a ready database + + PoolConfig pool.PoolConfig } func DefaultManagerConfigFromEnv() ManagerConfig { @@ -50,18 +50,20 @@ func DefaultManagerConfigFromEnv() ManagerConfig { // DatabasePrefix_TemplateDatabasePrefix_HASH TemplateDatabasePrefix: util.GetEnv("INTEGRESQL_TEMPLATE_DB_PREFIX", "template"), - // DatabasePrefix_TestDatabasePrefix_HASH_ID - TestDatabasePrefix: util.GetEnv("INTEGRESQL_TEST_DB_PREFIX", "test"), - - // reuse the same user (PGUSER) and passwort (PGPASSWORT) for the test / template databases by default + // we reuse the same user (PGUSER) and passwort (PGPASSWORT) for the test / template databases by default TestDatabaseOwner: util.GetEnv("INTEGRESQL_TEST_PGUSER", util.GetEnv("INTEGRESQL_PGUSER", util.GetEnv("PGUSER", "postgres"))), TestDatabaseOwnerPassword: util.GetEnv("INTEGRESQL_TEST_PGPASSWORD", util.GetEnv("INTEGRESQL_PGPASSWORD", util.GetEnv("PGPASSWORD", ""))), - // TestDatabaseInitialPoolSize: util.GetEnvAsInt("INTEGRESQL_TEST_INITIAL_POOL_SIZE", 10), - TestDatabaseInitialPoolSize: util.GetEnvAsInt("INTEGRESQL_TEST_INITIAL_POOL_SIZE", runtime.NumCPU()), - // TestDatabaseMaxPoolSize: util.GetEnvAsInt("INTEGRESQL_TEST_MAX_POOL_SIZE", 500), - TestDatabaseMaxPoolSize: util.GetEnvAsInt("INTEGRESQL_TEST_MAX_POOL_SIZE", runtime.NumCPU()*4), - TemplateFinalizeTimeout: time.Millisecond * time.Duration(util.GetEnvAsInt("INTEGRESQL_TEMPLATE_FINALIZE_TIMEOUT_MS", 5*60*10e3 /*5 min*/)), - TestDatabaseGetTimeout: time.Millisecond * time.Duration(util.GetEnvAsInt("INTEGRESQL_TEST_DB_GET_TIMEOUT_MS", 1*60*10e3 /*1 min, timeout hardcoded also in GET request handler*/)), - PoolMaxParallelTasks: util.GetEnvAsInt("INTEGRESQL_POOL_MAX_PARALLEL_TASKS", runtime.NumCPU()), + TemplateFinalizeTimeout: time.Millisecond * time.Duration(util.GetEnvAsInt("INTEGRESQL_TEMPLATE_FINALIZE_TIMEOUT_MS", 5*60*1000 /*5 min*/)), + TestDatabaseGetTimeout: time.Millisecond * time.Duration(util.GetEnvAsInt("INTEGRESQL_TEST_DB_GET_TIMEOUT_MS", 1*60*1000 /*1 min, timeout hardcoded also in GET request handler*/)), + + PoolConfig: pool.PoolConfig{ + InitialPoolSize: util.GetEnvAsInt("INTEGRESQL_TEST_INITIAL_POOL_SIZE", runtime.NumCPU()), // previously default 10 + MaxPoolSize: util.GetEnvAsInt("INTEGRESQL_TEST_MAX_POOL_SIZE", runtime.NumCPU()*4), // previously default 500 + TestDBNamePrefix: util.GetEnv("INTEGRESQL_TEST_DB_PREFIX", "test"), // DatabasePrefix_TestDBNamePrefix_HASH_ID + MaxParallelTasks: util.GetEnvAsInt("INTEGRESQL_POOL_MAX_PARALLEL_TASKS", runtime.NumCPU()), + TestDatabaseRetryRecreateSleepMin: time.Millisecond * time.Duration(util.GetEnvAsInt("INTEGRESQL_TEST_DB_RETRY_RECREATE_SLEEP_MIN_MS", 250 /*250 ms*/)), + TestDatabaseRetryRecreateSleepMax: time.Millisecond * time.Duration(util.GetEnvAsInt("INTEGRESQL_TEST_DB_RETRY_RECREATE_SLEEP_MAX_MS", 1000*3 /*3 sec*/)), + TestDatabaseMinimalLifetime: time.Millisecond * time.Duration(util.GetEnvAsInt("INTEGRESQL_TEST_DB_MINIMAL_LIFETIME_MS", 250 /*250 ms*/)), + }, } } diff --git a/pkg/manager/manager_test.go b/pkg/manager/manager_test.go index 8709cb8..8765fd2 100644 --- a/pkg/manager/manager_test.go +++ b/pkg/manager/manager_test.go @@ -129,7 +129,7 @@ func TestManagerInitializeTemplateDatabaseTimeout(t *testing.T) { defer cancel() _, err := m.InitializeTemplateDatabase(ctxt, hash) - if err != context.DeadlineExceeded { + if !errors.Is(err, context.DeadlineExceeded) { t.Fatalf("received unexpected error, got %v, want %v", err, context.DeadlineExceeded) } } @@ -173,7 +173,7 @@ func TestManagerInitializeTemplateDatabaseConcurrently(t *testing.T) { if err == nil { success++ } else { - if err == manager.ErrTemplateAlreadyInitialized { + if errors.Is(err, manager.ErrTemplateAlreadyInitialized) { failed++ } else { errored++ @@ -310,8 +310,8 @@ func TestManagerGetTestDatabaseExtendPool(t *testing.T) { cfg := manager.DefaultManagerConfigFromEnv() cfg.TestDatabaseGetTimeout = 300 * time.Millisecond - cfg.TestDatabaseInitialPoolSize = 0 // this will be autotransformed to 1 during init - cfg.TestDatabaseMaxPoolSize = 10 + cfg.PoolConfig.InitialPoolSize = 0 // this will be autotransformed to 1 during init + cfg.PoolConfig.MaxPoolSize = 10 m, _ := testManagerWithConfig(cfg) if err := m.Initialize(ctx); err != nil { @@ -335,7 +335,7 @@ func TestManagerGetTestDatabaseExtendPool(t *testing.T) { previousID := -1 // assert than one by one pool will be extended - for i := 0; i < cfg.TestDatabaseMaxPoolSize; i++ { + for i := 0; i < cfg.PoolConfig.MaxPoolSize; i++ { testDB, err := m.GetTestDatabase(ctx, hash) assert.NoError(t, err) assert.Equal(t, previousID+1, testDB.ID) @@ -384,7 +384,10 @@ func TestManagerFinalizeTemplateAndGetTestDatabaseConcurrently(t *testing.T) { return nil }) - g.Wait() + if err := g.Wait(); err != nil { + t.Fatal(err) + } + first := <-testCh assert.Equal(t, "FINALIZE", first) } @@ -506,7 +509,7 @@ func TestManagerDiscardTemplateDatabase(t *testing.T) { if err == nil { success++ } else { - // fmt.Println(err) + // t.Log(err) errored++ } } @@ -574,7 +577,7 @@ func TestManagerDiscardThenReinitializeTemplateDatabase(t *testing.T) { if err == nil { success++ } else { - // fmt.Println(err) + t.Log(err) errored++ } } @@ -606,8 +609,8 @@ func TestManagerGetAndReturnTestDatabase(t *testing.T) { ctx := context.Background() cfg := manager.DefaultManagerConfigFromEnv() - cfg.TestDatabaseInitialPoolSize = 3 - cfg.TestDatabaseMaxPoolSize = 3 + cfg.PoolConfig.InitialPoolSize = 3 + cfg.PoolConfig.MaxPoolSize = 3 cfg.TestDatabaseGetTimeout = 200 * time.Millisecond m, _ := testManagerWithConfig(cfg) @@ -631,7 +634,7 @@ func TestManagerGetAndReturnTestDatabase(t *testing.T) { } // request many more databases than initally added - for i := 0; i <= cfg.TestDatabaseMaxPoolSize*3; i++ { + for i := 0; i <= cfg.PoolConfig.MaxPoolSize*3; i++ { test, err := m.GetTestDatabase(ctx, hash) assert.NoError(t, err) assert.NotEmpty(t, test) @@ -648,9 +651,9 @@ func TestManagerGetAndRecreateTestDatabase(t *testing.T) { ctx := context.Background() cfg := manager.DefaultManagerConfigFromEnv() - cfg.TestDatabaseInitialPoolSize = 10 - cfg.TestDatabaseMaxPoolSize = 15 - cfg.TestDatabaseGetTimeout = 200 * time.Millisecond + cfg.PoolConfig.InitialPoolSize = 8 + cfg.PoolConfig.MaxPoolSize = 8 + cfg.TestDatabaseGetTimeout = 1000 * time.Millisecond m, _ := testManagerWithConfig(cfg) if err := m.Initialize(ctx); err != nil { @@ -673,8 +676,11 @@ func TestManagerGetAndRecreateTestDatabase(t *testing.T) { } // request many more databases than initally added - for i := 0; i <= cfg.TestDatabaseMaxPoolSize*3; i++ { + for i := 0; i <= cfg.PoolConfig.MaxPoolSize*5; i++ { test, err := m.GetTestDatabase(ctx, hash) + + t.Logf("open %v", test.ID) + assert.NoError(t, err) assert.NotEmpty(t, test) @@ -692,6 +698,8 @@ func TestManagerGetAndRecreateTestDatabase(t *testing.T) { require.NoError(t, err) assert.NoError(t, db.QueryRowContext(ctx, "SELECT COUNT(*) FROM pilots WHERE name = 'Anna'").Scan(&res)) assert.Equal(t, 1, res) + + t.Logf("close %v", test.ID) db.Close() // recreate testDB after usage @@ -707,9 +715,9 @@ func TestManagerGetTestDatabaseDontReturn(t *testing.T) { ctx := context.Background() cfg := manager.DefaultManagerConfigFromEnv() - cfg.TestDatabaseInitialPoolSize = 5 - cfg.TestDatabaseMaxPoolSize = 5 - cfg.TestDatabaseGetTimeout = time.Second + cfg.PoolConfig.InitialPoolSize = 5 + cfg.PoolConfig.MaxPoolSize = 5 + cfg.TestDatabaseGetTimeout = time.Second * 5 m, _ := testManagerWithConfig(cfg) if err := m.Initialize(ctx); err != nil { @@ -732,7 +740,7 @@ func TestManagerGetTestDatabaseDontReturn(t *testing.T) { } var wg sync.WaitGroup - for i := 0; i < cfg.TestDatabaseMaxPoolSize*5; i++ { + for i := 0; i < cfg.PoolConfig.MaxPoolSize*5; i++ { wg.Add(1) go func(i int) { defer wg.Done() @@ -777,8 +785,8 @@ func TestManagerReturnTestDatabase(t *testing.T) { ctx := context.Background() cfg := manager.DefaultManagerConfigFromEnv() - cfg.TestDatabaseInitialPoolSize = 1 - cfg.TestDatabaseMaxPoolSize = 10 + cfg.PoolConfig.InitialPoolSize = 1 + cfg.PoolConfig.MaxPoolSize = 10 cfg.TestDatabaseGetTimeout = 200 * time.Millisecond m, _ := testManagerWithConfig(cfg) @@ -814,19 +822,28 @@ func TestManagerReturnTestDatabase(t *testing.T) { // finally return it assert.NoError(t, m.ReturnTestDatabase(ctx, hash, testDB1.ID)) + // regetting these databases is quite random. Let's try to get the same id again... // on first GET call the pool has been extended // we will get the newly created DB testDB2, err := m.GetTestDatabase(ctx, hash) assert.NoError(t, err) - assert.NotEqual(t, testDB1.ID, testDB2.ID) // next in 'ready' channel should be the returned DB testDB3, err := m.GetTestDatabase(ctx, hash) assert.NoError(t, err) - assert.Equal(t, testDB1.ID, testDB3.ID) + + // restored db + var targetConnectionString string + if testDB2.ID == testDB1.ID { + targetConnectionString = testDB2.Config.ConnectionString() + } else if testDB3.ID == testDB1.ID { + targetConnectionString = testDB3.Config.ConnectionString() + } else { + t.Fatal("We should have been able to get the previously returned database.") + } // assert that it hasn't been cleaned but just reused directly - db, err = sql.Open("postgres", testDB3.Config.ConnectionString()) + db, err = sql.Open("postgres", targetConnectionString) require.NoError(t, err) require.NoError(t, db.PingContext(ctx)) @@ -873,7 +890,7 @@ func TestManagerReturnUntrackedTemplateDatabase(t *testing.T) { } id := 321 - dbName := fmt.Sprintf("%s_%s_%s_%d", config.DatabasePrefix, config.TestDatabasePrefix, hash, id) + dbName := fmt.Sprintf("%s_%s_%s_%d", config.DatabasePrefix, config.PoolConfig.TestDBNamePrefix, hash, id) if _, err := db.ExecContext(ctx, fmt.Sprintf("DROP DATABASE IF EXISTS %s", pq.QuoteIdentifier(dbName))); err != nil { t.Fatalf("failed to manually drop template database %q: %v", dbName, err) @@ -882,8 +899,8 @@ func TestManagerReturnUntrackedTemplateDatabase(t *testing.T) { t.Fatalf("failed to manually create template database %q: %v", dbName, err) } - if err := m.ReturnTestDatabase(ctx, hash, id); err != nil { - t.Fatalf("failed to return manually created test database: %v", err) + if err := m.ReturnTestDatabase(ctx, hash, id); err == nil { + t.Fatalf("succeeded to return manually created test database: %v", err) // this should not work! } } @@ -976,7 +993,7 @@ func TestManagerClearTrackedTestDatabases(t *testing.T) { cfg := manager.DefaultManagerConfigFromEnv() // there are no db added in background - cfg.TestDatabaseInitialPoolSize = 0 + cfg.PoolConfig.InitialPoolSize = 0 m, _ := testManagerWithConfig(cfg) if err := m.Initialize(ctx); err != nil { diff --git a/pkg/pool/pool.go b/pkg/pool/pool.go index 25a6ae9..b34cae8 100644 --- a/pkg/pool/pool.go +++ b/pkg/pool/pool.go @@ -3,12 +3,13 @@ package pool import ( "context" "errors" - "fmt" "runtime/trace" "sync" "time" "github.com/allaboutapps/integresql/pkg/db" + "github.com/allaboutapps/integresql/pkg/util" + "github.com/rs/zerolog" ) var ( @@ -22,30 +23,40 @@ var ( type dbState int // Indicates a current DB state. const ( - dbStateReady dbState = iota // Initialized according to a template and ready to be picked up. - dbStateDirty // Taken by a client and potentially currently in use. + dbStateReady dbState = iota // Initialized according to a template and ready to be picked up. + dbStateDirty // Taken by a client and potentially currently in use. + dbStateRecreating // In the process of being recreated (to prevent concurrent cleans) ) -const minConcurrentTasksNum = 1 - type existingDB struct { state dbState db.TestDatabase + + // To prevent auto-cleans of a testdatabase on the dirty channel directly after it was issued as ready, + // each testdatabase gets a timestamp assigned after which auto-cleaning it generally allowed (unlock + // and recreate do not respect this). This timeout is typically very low and should only be neccessary + // to be tweaked in scenarios in which the pool is overloaded by requests. + // Prefer to tweak InitialPoolSize (the always ready dbs) and MaxPoolSize instead if you have issues here. + blockAutoCleanDirtyUntil time.Time + + // increased after each recreation, useful for sleepy recreating workers to check if we still operate on the same gen. + generation uint } type workerTask string const ( - workerTaskStop = "STOP" - workerTaskExtend = "EXTEND" - workerTaskCleanDirty = "CLEAN_DIRTY" + workerTaskStop = "STOP" + workerTaskExtend = "EXTEND" + workerTaskAutoCleanDirty = "CLEAN_DIRTY" ) // HashPool holds a test DB pool for a certain hash. Each HashPool is running cleanup workers in background. type HashPool struct { - dbs []existingDB - ready chan int // ID of initalized DBs according to a template, ready to pick them up - dirty chan int // ID of DBs that were given away and need to be recreated to reuse them + dbs []existingDB + ready chan int // ID of initalized DBs according to a template, ready to pick them up + dirty chan int // ID of DBs that were given away and need to be recreated to reuse them + recreating chan struct{} // tracks currently running recreating ops recreateDB recreateTestDBFunc templateDB db.Database @@ -54,28 +65,26 @@ type HashPool struct { sync.RWMutex wg sync.WaitGroup - tasksChan chan string - running bool + tasksChan chan workerTask + running bool + workerContext context.Context // the ctx all background workers will receive (nil if not yet started) } // NewHashPool creates new hash pool with the given config. // Starts the workers to extend the pool in background up to requested inital number. func NewHashPool(cfg PoolConfig, templateDB db.Database, initDBFunc RecreateDBFunc) *HashPool { - if cfg.PoolMaxParallelTasks < minConcurrentTasksNum { - cfg.PoolMaxParallelTasks = minConcurrentTasksNum - } - pool := &HashPool{ - dbs: make([]existingDB, 0, cfg.MaxPoolSize), - ready: make(chan int, cfg.MaxPoolSize), - dirty: make(chan int, cfg.MaxPoolSize), + dbs: make([]existingDB, 0, cfg.MaxPoolSize), + ready: make(chan int, cfg.MaxPoolSize), + dirty: make(chan int, cfg.MaxPoolSize), + recreating: make(chan struct{}, cfg.MaxPoolSize), recreateDB: makeActualRecreateTestDBFunc(templateDB.Config.Database, initDBFunc), templateDB: templateDB, PoolConfig: cfg, - tasksChan: make(chan string, cfg.MaxPoolSize+1), + tasksChan: make(chan workerTask, cfg.MaxPoolSize+1), running: false, } @@ -83,14 +92,23 @@ func NewHashPool(cfg PoolConfig, templateDB db.Database, initDBFunc RecreateDBFu } func (pool *HashPool) Start() { + + log := pool.getPoolLogger(context.Background(), "Start") pool.Lock() + log.Debug().Msg("starting...") + defer pool.Unlock() if pool.running { + log.Warn().Msg("bailout already running!") return } pool.running = true + + ctx, cancel := context.WithCancel(context.Background()) + pool.workerContext = ctx + for i := 0; i < pool.InitialPoolSize; i++ { pool.tasksChan <- workerTaskExtend } @@ -98,13 +116,20 @@ func (pool *HashPool) Start() { pool.wg.Add(1) go func() { defer pool.wg.Done() - pool.controlLoop() + pool.controlLoop(ctx, cancel) }() + + log.Info().Msg("started!") } func (pool *HashPool) Stop() { + + log := pool.getPoolLogger(context.Background(), "Stop") + log.Debug().Msg("stopping...") + pool.Lock() if !pool.running { + log.Warn().Msg("bailout already stopped!") return } pool.running = false @@ -112,24 +137,30 @@ func (pool *HashPool) Stop() { pool.tasksChan <- workerTaskStop pool.wg.Wait() + pool.workerContext = nil + log.Warn().Msg("stopped!") } -func (pool *HashPool) GetTestDatabase(ctx context.Context, hash string, timeout time.Duration) (db db.TestDatabase, err error) { +func (pool *HashPool) GetTestDatabase(ctx context.Context, timeout time.Duration) (db db.TestDatabase, err error) { var index int - // fmt.Printf("pool#%s: waiting for ready ID...\n", hash) + log := pool.getPoolLogger(ctx, "GetTestDatabase") + log.Trace().Msg("waiting for ready ID...") select { case <-time.After(timeout): err = ErrTimeout + log.Error().Err(err).Dur("timeout", timeout).Msg("timeout") return case <-ctx.Done(): err = ctx.Err() + log.Warn().Err(err).Msg("ctx done") return case index = <-pool.ready: } - // fmt.Printf("pool#%s: got ready ID=%v\n", hash, index) + log = log.With().Int("id", index).Logger() + log.Trace().Msg("got ready testdatabase!") reg := trace.StartRegion(ctx, "wait_for_lock_hash_pool") pool.Lock() @@ -139,96 +170,104 @@ func (pool *HashPool) GetTestDatabase(ctx context.Context, hash string, timeout // sanity check, should never happen if index < 0 || index >= len(pool.dbs) { err = ErrInvalidIndex + log.Error().Err(err).Int("dbs", len(pool.dbs)).Msg("index out of bounds!") return } testDB := pool.dbs[index] // sanity check, should never happen - we got this index from 'ready' channel if testDB.state != dbStateReady { - - // fmt.Printf("pool#%s: GetTestDatabase ErrInvalidState ID=%v\n", hash, index) - err = ErrInvalidState + log.Error().Err(err).Msgf("testdatabase is not in ready state=%v!", testDB.state) return } + // flag as dirty and block auto clean until testDB.state = dbStateDirty + testDB.blockAutoCleanDirtyUntil = time.Now().Add(pool.TestDatabaseMinimalLifetime) + pool.dbs[index] = testDB pool.dirty <- index if len(pool.dbs) < pool.PoolConfig.MaxPoolSize { + log.Trace().Msg("push workerTaskExtend") pool.tasksChan <- workerTaskExtend } // we try to ensure that InitialPoolSize count is staying ready - // thus, we try to move the oldest dirty dbs into cleaning - if len(pool.dbs) >= pool.PoolConfig.MaxPoolSize { - pool.tasksChan <- workerTaskCleanDirty + // thus, we try to move the oldest dirty dbs into recreating with the workerTaskAutoCleanDirty + if len(pool.dbs) >= pool.PoolConfig.MaxPoolSize && (len(pool.ready)+len(pool.recreating)) < pool.InitialPoolSize { + log.Trace().Msg("push workerTaskAutoCleanDirty") + pool.tasksChan <- workerTaskAutoCleanDirty } - // fmt.Printf("pool#%s: ready=%d, dirty=%d, waitingForCleaning=%d, dbs=%d initial=%d max=%d (GetTestDatabase)\n", hash, len(pool.ready), len(pool.dirty), len(pool.waitingForCleaning), len(pool.dbs), pool.PoolConfig.InitialPoolSize, pool.PoolConfig.MaxPoolSize) + pool.unsafeTraceLogStats(log) return testDB.TestDatabase, nil } -func (pool *HashPool) AddTestDatabase(ctx context.Context, templateDB db.Database) error { - return pool.extend(ctx) -} +func (pool *HashPool) workerTaskLoop(ctx context.Context, taskChan <-chan workerTask, MaxParallelTasks int) { -func (pool *HashPool) workerTaskLoop(ctx context.Context, taskChan <-chan string, poolMaxParallelTasks int) { + log := pool.getPoolLogger(ctx, "workerTaskLoop") + log.Debug().Msg("starting...") - handlers := map[string]func(ctx context.Context) error{ - workerTaskExtend: ignoreErrs(pool.extend, ErrPoolFull, context.Canceled), - workerTaskCleanDirty: ignoreErrs(pool.cleanDirty, context.Canceled), + handlers := map[workerTask]func(ctx context.Context) error{ + workerTaskExtend: ignoreErrs(pool.extend, ErrPoolFull, context.Canceled), + workerTaskAutoCleanDirty: ignoreErrs(pool.autoCleanDirty, context.Canceled), } // to limit the number of running goroutines. - var semaphore = make(chan struct{}, poolMaxParallelTasks) + var semaphore = make(chan struct{}, MaxParallelTasks) for task := range taskChan { handler, ok := handlers[task] if !ok { - fmt.Printf("invalid task: %s", task) + log.Error().Msgf("invalid task: %s", task) continue } select { case <-ctx.Done(): + log.Warn().Err(ctx.Err()).Msg("ctx done!") return case semaphore <- struct{}{}: } pool.wg.Add(1) - go func(task string) { + go func(task workerTask) { defer func() { pool.wg.Done() <-semaphore }() - // fmt.Println("task", task) + log.Debug().Msgf("task=%v", task) + if err := handler(ctx); err != nil { - fmt.Println("task", task, "failed:", err.Error()) + log.Error().Err(err).Msgf("task=%v FAILED!", task) } }(task) } } -func (pool *HashPool) controlLoop() { +func (pool *HashPool) controlLoop(ctx context.Context, cancel context.CancelFunc) { + + log := pool.getPoolLogger(ctx, "controlLoop") + log.Debug().Msg("starting...") - ctx, cancel := context.WithCancel(context.Background()) defer cancel() - workerTasksChan := make(chan string, len(pool.tasksChan)) + workerTasksChan := make(chan workerTask, len(pool.tasksChan)) pool.wg.Add(1) go func() { defer pool.wg.Done() - pool.workerTaskLoop(ctx, workerTasksChan, pool.PoolMaxParallelTasks) + pool.workerTaskLoop(ctx, workerTasksChan, pool.MaxParallelTasks) }() for task := range pool.tasksChan { if task == workerTaskStop { + log.Debug().Msg("stopping...") close(workerTasksChan) cancel() return @@ -244,17 +283,29 @@ func (pool *HashPool) controlLoop() { } // ReturnTestDatabase returns the given test DB directly to the pool, without cleaning (recreating it). -func (pool *HashPool) ReturnTestDatabase(ctx context.Context, hash string, id int) error { +func (pool *HashPool) ReturnTestDatabase(ctx context.Context, id int) error { + + log := pool.getPoolLogger(ctx, "ReturnTestDatabase").With().Int("id", id).Logger() + log.Debug().Msg("returning...") + pool.Lock() defer pool.Unlock() + if err := ctx.Err(); err != nil { + // client vanished + log.Warn().Err(err).Msg("bailout client vanished!") + return err + } + if id < 0 || id >= len(pool.dbs) { + log.Warn().Int("dbs", len(pool.dbs)).Msg("bailout invalid index!") return ErrInvalidIndex } // check if db is in the correct state testDB := pool.dbs[id] - if testDB.state == dbStateReady { + if testDB.state != dbStateDirty { + log.Warn().Int("dbs", len(pool.dbs)).Msgf("bailout invalid state=%v.", testDB.state) return nil } @@ -262,51 +313,181 @@ func (pool *HashPool) ReturnTestDatabase(ctx context.Context, hash string, id in testDB.state = dbStateReady pool.dbs[id] = testDB + // remove id from dirty and add it to ready channel + pool.excludeIDFromChannel(pool.dirty, id) pool.ready <- id + pool.unsafeTraceLogStats(log) + return nil +} + +func (pool *HashPool) excludeIDFromChannel(ch chan int, excludeID int) { + + // The testDB identified by overgiven id may still in a specific channel (typically dirty). We want to exclude it. + // We need to explicitly remove it from there by filtering the current channel to a tmp channel. + // We finally close the tmp channel and flush it onto the specific channel again. + // The id is now no longer in the channel. + filtered := make(chan int, pool.MaxPoolSize) + var id int + for loop := true; loop; { + select { + case id = <-ch: + if id != excludeID { + filtered <- id + } + default: + loop = false + break + } + } + + // filtered now has all filtered values without the above id, redirect the other ids back to the specific channel. + // close so we can range over it... + close(filtered) + + for id := range filtered { + ch <- id + } } -// RecreateTestDatabase recreates the test DB according to the template and returns it back to the pool. -func (pool *HashPool) RecreateTestDatabase(ctx context.Context, hash string, id int) error { +// RecreateTestDatabase prioritizes the test DB to be recreated next via the dirty worker. +func (pool *HashPool) RecreateTestDatabase(ctx context.Context, id int) error { + + log := pool.getPoolLogger(ctx, "RecreateTestDatabase").With().Int("id", id).Logger() + log.Debug().Msg("flag testdatabase for recreation...") pool.RLock() + if id < 0 || id >= len(pool.dbs) { + log.Warn().Int("dbs", len(pool.dbs)).Msg("bailout invalid index!") pool.RUnlock() return ErrInvalidIndex } - // check if db is in the correct state - testDB := pool.dbs[id] pool.RUnlock() - if testDB.state == dbStateReady { - return nil + if err := ctx.Err(); err != nil { + // client vanished + log.Warn().Err(err).Msg("bailout client vanished!") + return err } - // state is dirty -> we will now recreate it - if err := pool.recreateDB(ctx, &testDB); err != nil { + // exclude from the normal dirty channel, force recreation in a background worker... + pool.excludeIDFromChannel(pool.dirty, id) + + // directly spawn a new worker in the bg (with the same ctx as the typical workers) + // note that this runs unchained, meaning we do not care about errors that may happen via this bg task + //nolint:errcheck + go pool.recreateDatabaseGracefully(pool.workerContext, id) + + pool.unsafeTraceLogStats(log) + return nil +} + +// recreateDatabaseGracefully continuosly tries to recreate the testdatabase and will retry/block until it succeeds +func (pool *HashPool) recreateDatabaseGracefully(ctx context.Context, id int) error { + + log := pool.getPoolLogger(ctx, "recreateDatabaseGracefully").With().Int("id", id).Logger() + log.Debug().Msg("recreating...") + + if err := ctx.Err(); err != nil { + // pool closed in the meantime. + log.Error().Err(err).Msg("bailout pre locking ctx err") return err } pool.Lock() - defer pool.Unlock() - // change the state to 'ready' - testDB.state = dbStateReady + if state := pool.dbs[id].state; state != dbStateDirty { + // nothing to do + log.Error().Msgf("bailout not dbStateDirty state=%v", state) + pool.Unlock() + return nil + } + + testDB := pool.dbs[id] + + // set state recreating... + pool.dbs[id].state = dbStateRecreating pool.dbs[id] = testDB - pool.ready <- id + pool.Unlock() - return nil + pool.recreating <- struct{}{} + + defer func() { + <-pool.recreating + }() + + try := 0 + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + try++ + + log.Trace().Int("try", try).Msg("trying to recreate...") + err := pool.recreateDB(ctx, &testDB) + if err != nil { + // only still connected errors are worthy a retry + if errors.Is(err, ErrTestDBInUse) { + + backoff := time.Duration(try) * pool.PoolConfig.TestDatabaseRetryRecreateSleepMin + if backoff > pool.PoolConfig.TestDatabaseRetryRecreateSleepMax { + backoff = pool.PoolConfig.TestDatabaseRetryRecreateSleepMax + } + + log.Warn().Int("try", try).Dur("backoff", backoff).Msg("DB is still in use, will retry...") + time.Sleep(backoff) + } else { + + log.Error().Int("try", try).Err(err).Msg("bailout worker task DB error while cleanup!") + return err + } + } else { + goto MoveToReady + } + } + } + +MoveToReady: + pool.Lock() + defer pool.Unlock() + + if ctx.Err() != nil { + // pool closed in the meantime. + return ctx.Err() + } + + if pool.dbs[id].state == dbStateReady { + // oups, it has been cleaned by another worker already + // we won't add it to the 'ready' channel to avoid duplication + log.Warn().Msg("bailout DB has be cleaned by another worker as its already ready, skipping readd to ready channel!") + return nil + } + + // increase the generation of the testdb (as we just recreated it) and move into ready! + pool.dbs[id].generation++ + pool.dbs[id].state = dbStateReady + + pool.ready <- pool.dbs[id].ID + + log.Debug().Uint("generation", pool.dbs[id].generation).Msg("ready") + pool.unsafeTraceLogStats(log) + return nil } -// cleanDirty reads 'dirty' channel and cleans up a test DB with the received index. +// autoCleanDirty reads 'dirty' channel and cleans up a test DB with the received index. // When the DB is recreated according to a template, its index goes to the 'ready' channel. -// The function waits until there is a dirty DB... -func (pool *HashPool) cleanDirty(ctx context.Context) error { +// Note that we generally gurantee FIFO when it comes to auto-cleaning as long as no manual unlock/recreates happen. +func (pool *HashPool) autoCleanDirty(ctx context.Context) error { + + log := pool.getPoolLogger(ctx, "autoCleanDirty") + log.Trace().Msg("autocleaning...") ctx, task := trace.NewTask(ctx, "worker_clean_dirty") defer task.End() @@ -318,64 +499,57 @@ func (pool *HashPool) cleanDirty(ctx context.Context) error { return ctx.Err() default: // nothing to do + log.Trace().Msg("noop") return nil } + // got id... + log = log.With().Int("id", id).Logger() + log.Trace().Msg("checking cleaning prerequisites...") + regLock := trace.StartRegion(ctx, "worker_wait_for_rlock_hash_pool") pool.RLock() regLock.End() if id < 0 || id >= len(pool.dbs) { // sanity check, should never happen + log.Warn().Int("dbs", len(pool.dbs)).Msg("bailout invalid index!") pool.RUnlock() return ErrInvalidIndex } - testDB := pool.dbs[id] - pool.RUnlock() - - if testDB.state == dbStateReady { - // nothing to do - return nil - } - - reg := trace.StartRegion(ctx, "worker_db_operation") - err := pool.recreateDB(ctx, &testDB) - reg.End() - if err != nil { - // fmt.Printf("worker_clean_dirty: failed to clean up DB ID='%v': %v\n", id, err) + blockedUntil := time.Until(pool.dbs[id].blockAutoCleanDirtyUntil) + generation := pool.dbs[id].generation - // we guarantee FIFO, we must keeping trying to clean up **exactly this** test database! - if errors.Is(err, ErrTestDBInUse) { + log = log.With().Dur("blockedUntil", blockedUntil).Uint("generation", generation).Logger() - fmt.Printf("worker_clean_dirty: scheduling retry cleanup for ID='%v'...\n", id) - time.Sleep(250 * time.Millisecond) - fmt.Printf("integworker_clean_dirtyresql: push DB ID='%v' into retry.", id) - pool.dirty <- id - pool.tasksChan <- workerTaskCleanDirty - return nil - } + pool.RUnlock() - return err + // immediately pass to pool recreate + if blockedUntil <= 0 { + log.Trace().Msg("clean now (immediate)!") + return pool.recreateDatabaseGracefully(ctx, id) } - regLock = trace.StartRegion(ctx, "worker_wait_for_lock_hash_pool") - pool.Lock() - defer pool.Unlock() - regLock.End() + // else we need to wait until we are allowed to work with it! + // we block auto-cleaning until we are allowed to... + log.Warn().Msg("sleeping before being allowed to clean...") + time.Sleep(blockedUntil) - if testDB.state == dbStateReady { - // oups, it has been cleaned by another worker already - // we won't add it to the 'ready' channel to avoid duplication + // we need to check that the testDB.generation did not change since we slept + // (which would indicate that the database was already unlocked/recreated by someone else in the meantime) + pool.RLock() + + if pool.dbs[id].generation != generation || pool.dbs[id].state != dbStateDirty { + log.Error().Msgf("bailout old generation=%v vs new generation=%v state=%v", generation, pool.dbs[id].generation, pool.dbs[id].state) + pool.RUnlock() return nil } - testDB.state = dbStateReady - pool.dbs[id] = testDB - - pool.ready <- testDB.ID + pool.RUnlock() - return nil + log.Trace().Msg("clean now (after sleep has happenend)!") + return pool.recreateDatabaseGracefully(ctx, id) } func ignoreErrs(f func(ctx context.Context) error, errs ...error) func(context.Context) error { @@ -392,23 +566,27 @@ func ignoreErrs(f func(ctx context.Context) error, errs ...error) func(context.C func (pool *HashPool) extend(ctx context.Context) error { + log := pool.getPoolLogger(ctx, "extend") + log.Trace().Msg("extending...") + ctx, task := trace.NewTask(ctx, "worker_extend") defer task.End() reg := trace.StartRegion(ctx, "worker_wait_for_lock_hash_pool") pool.Lock() - defer pool.Unlock() reg.End() // get index of a next test DB - its ID index := len(pool.dbs) if index == cap(pool.dbs) { + log.Error().Int("dbs", len(pool.dbs)).Int("cap", cap(pool.dbs)).Err(ErrPoolFull).Msg("pool is full") + pool.Unlock() return ErrPoolFull } - // initalization of a new DB using template config + // initalization of a new DB using template config, it must start in state dirty! newTestDB := existingDB{ - state: dbStateReady, + state: dbStateDirty, TestDatabase: db.TestDatabase{ Database: db.Database{ TemplateHash: pool.templateDB.TemplateHash, @@ -420,33 +598,31 @@ func (pool *HashPool) extend(ctx context.Context) error { // set DB name newTestDB.Database.Config.Database = makeDBName(pool.TestDBNamePrefix, pool.templateDB.TemplateHash, index) - reg = trace.StartRegion(ctx, "worker_db_operation") - err := pool.recreateDB(ctx, &newTestDB) - reg.End() - - if err != nil { - return err - } - - // add new test DB to the pool + // add new test DB to the pool (currently it's dirty!) pool.dbs = append(pool.dbs, newTestDB) - pool.ready <- newTestDB.ID + log.Trace().Int("id", index).Msg("appended as dirty, recreating...") + pool.unsafeTraceLogStats(log) + pool.Unlock() - return nil + // forced recreate... + return pool.recreateDatabaseGracefully(ctx, index) } func (pool *HashPool) RemoveAll(ctx context.Context, removeFunc RemoveDBFunc) error { + log := pool.getPoolLogger(ctx, "RemoveAll") + // stop all workers pool.Stop() - // ! - // HashPool locked + // wait until all current "recreating" tasks are finished... + pool.Lock() defer pool.Unlock() if len(pool.dbs) == 0 { + log.Error().Msg("bailout no dbs.") return nil } @@ -455,19 +631,33 @@ func (pool *HashPool) RemoveAll(ctx context.Context, removeFunc RemoveDBFunc) er testDB := pool.dbs[id].TestDatabase if err := removeFunc(ctx, testDB); err != nil { + log.Error().Int("id", id).Err(err).Msg("removeFunc testdatabase err") return err } if len(pool.dbs) > 1 { pool.dbs = pool.dbs[:len(pool.dbs)-1] } + + pool.excludeIDFromChannel(pool.dirty, id) + pool.excludeIDFromChannel(pool.ready, id) + log.Debug().Int("id", id).Msg("testdatabase removed!") } // close all only if removal of all succeeded pool.dbs = nil close(pool.tasksChan) + pool.unsafeTraceLogStats(log) + return nil - // HashPool unlocked - // ! +} + +func (pool *HashPool) getPoolLogger(ctx context.Context, poolFunction string) zerolog.Logger { + return util.LogFromContext(ctx).With().Str("poolHash", pool.templateDB.TemplateHash).Str("poolFn", poolFunction).Logger() +} + +// unsafeTraceLogStats logs stats of this pool. Attention: pool should be read or write locked! +func (pool *HashPool) unsafeTraceLogStats(log zerolog.Logger) { + log.Trace().Int("ready", len(pool.ready)).Int("dirty", len(pool.dirty)).Int("recreating", len(pool.recreating)).Int("tasksChan", len(pool.tasksChan)).Int("dbs", len(pool.dbs)).Int("initial", pool.PoolConfig.InitialPoolSize).Int("max", pool.PoolConfig.MaxPoolSize).Msg("pool stats") } diff --git a/pkg/pool/pool_collection.go b/pkg/pool/pool_collection.go index 6aad7e2..82db735 100644 --- a/pkg/pool/pool_collection.go +++ b/pkg/pool/pool_collection.go @@ -13,14 +13,21 @@ import ( var ErrUnknownHash = errors.New("no database pool exists for this hash") -type PoolConfig struct { - MaxPoolSize int - InitialPoolSize int - TestDBNamePrefix string - PoolMaxParallelTasks int +// we explicitly want to access this struct via pool.PoolConfig, thus we disable revive for the next line +type PoolConfig struct { //nolint:revive + InitialPoolSize int // Initial number of ready DBs prepared in background + MaxPoolSize int // Maximal pool size that won't be exceeded + TestDBNamePrefix string // Test-Database prefix: DatabasePrefix_TestDBNamePrefix_HASH_ID + MaxParallelTasks int // Maximal number of pool tasks running in parallel. Must be a number greater or equal 1. + TestDatabaseRetryRecreateSleepMin time.Duration // Minimal time to wait after a test db recreate has failed (e.g. as client is still connected). Subsequent retries multiply this values until... + TestDatabaseRetryRecreateSleepMax time.Duration // ... the maximum possible sleep time between retries (e.g. 3 seconds) is reached. + TestDatabaseMinimalLifetime time.Duration // After a testdatabase transitions from ready to dirty, always block auto-recreation for this duration (except manual recreate). + + disableWorkerAutostart bool // test only private flag for starting without background worker task system } -type PoolCollection struct { +// we explicitly want to access this struct via pool.PoolCollection, thus we disable revive for the next line +type PoolCollection struct { //nolint:revive PoolConfig pools map[string]*HashPool // map[hash] @@ -51,7 +58,7 @@ func makeActualRecreateTestDBFunc(templateName string, userRecreateFunc Recreate type recreateTestDBFunc func(context.Context, *existingDB) error // InitHashPool creates a new pool with a given template hash and starts the cleanup workers. -func (p *PoolCollection) InitHashPool(ctx context.Context, templateDB db.Database, initDBFunc RecreateDBFunc) { +func (p *PoolCollection) InitHashPool(_ context.Context, templateDB db.Database, initDBFunc RecreateDBFunc) { p.mutex.Lock() defer p.mutex.Unlock() @@ -59,12 +66,25 @@ func (p *PoolCollection) InitHashPool(ctx context.Context, templateDB db.Databas // Create a new HashPool pool := NewHashPool(cfg, templateDB, initDBFunc) - pool.Start() + + if !cfg.disableWorkerAutostart { + pool.Start() + } // pool is ready p.pools[pool.templateDB.TemplateHash] = pool } +// Start is used to start all background workers +func (p *PoolCollection) Start() { + p.mutex.RLock() + defer p.mutex.RUnlock() + + for _, pool := range p.pools { + pool.Start() + } +} + // Stop is used to stop all background workers func (p *PoolCollection) Stop() { p.mutex.RLock() @@ -73,7 +93,6 @@ func (p *PoolCollection) Stop() { for _, pool := range p.pools { pool.Stop() } - } // GetTestDatabase picks up a ready to use test DB. It waits the given timeout until a DB is available. @@ -86,21 +105,7 @@ func (p *PoolCollection) GetTestDatabase(ctx context.Context, hash string, timeo return db, err } - return pool.GetTestDatabase(ctx, hash, timeout) -} - -// AddTestDatabase adds a new test DB to the pool and creates it according to the template. -// The new test DB is marked as 'Ready' and can be picked up with GetTestDatabase. -// If the pool size has already reached MAX, ErrPoolFull is returned. -func (p *PoolCollection) AddTestDatabase(ctx context.Context, templateDB db.Database) error { - hash := templateDB.TemplateHash - - pool, err := p.getPool(ctx, hash) - if err != nil { - return err - } - - return pool.AddTestDatabase(ctx, templateDB) + return pool.GetTestDatabase(ctx, timeout) } // ReturnTestDatabase returns the given test DB directly to the pool, without cleaning (recreating it). @@ -110,7 +115,7 @@ func (p *PoolCollection) ReturnTestDatabase(ctx context.Context, hash string, id return err } - return pool.ReturnTestDatabase(ctx, hash, id) + return pool.ReturnTestDatabase(ctx, id) } // RecreateTestDatabase recreates the test DB according to the template and returns it back to the pool. @@ -120,7 +125,7 @@ func (p *PoolCollection) RecreateTestDatabase(ctx context.Context, hash string, return err } - return pool.RecreateTestDatabase(ctx, hash, id) + return pool.RecreateTestDatabase(ctx, id) } // RemoveAllWithHash removes a pool with a given template hash. @@ -201,3 +206,18 @@ func (p *PoolCollection) getPoolLockCollection(ctx context.Context, hash string) return pool, unlock, err } + +// extend is only used for internal testing! +// it adds a new test DB to the pool and creates it according to the template. +// The new test DB is marked as 'Ready' and can be picked up with GetTestDatabase. +// If the pool size has already reached MAX, ErrPoolFull is returned. +func (p *PoolCollection) extend(ctx context.Context, templateDB db.Database) error { + hash := templateDB.TemplateHash + + pool, err := p.getPool(ctx, hash) + if err != nil { + return err + } + + return pool.extend(ctx) +} diff --git a/pkg/pool/pool_collection_test.go b/pkg/pool/pool_collection_internal_test.go similarity index 75% rename from pkg/pool/pool_collection_test.go rename to pkg/pool/pool_collection_internal_test.go index c89b42f..a843d0c 100644 --- a/pkg/pool/pool_collection_test.go +++ b/pkg/pool/pool_collection_internal_test.go @@ -1,4 +1,4 @@ -package pool_test +package pool import ( "context" @@ -7,7 +7,6 @@ import ( "time" "github.com/allaboutapps/integresql/pkg/db" - "github.com/allaboutapps/integresql/pkg/pool" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -16,12 +15,13 @@ func TestPoolAddGet(t *testing.T) { t.Parallel() ctx := context.Background() - cfg := pool.PoolConfig{ - MaxPoolSize: 2, - PoolMaxParallelTasks: 4, - TestDBNamePrefix: "prefix_", + cfg := PoolConfig{ + MaxPoolSize: 2, + MaxParallelTasks: 4, + TestDBNamePrefix: "prefix_", + disableWorkerAutostart: true, // no extend / cleanDirty tasks should run automatically! } - p := pool.NewPoolCollection(cfg) + p := NewPoolCollection(cfg) hash1 := "h1" hash2 := "h2" @@ -37,16 +37,17 @@ func TestPoolAddGet(t *testing.T) { return nil } p.InitHashPool(ctx, templateDB, initFunc) + t.Cleanup(func() { p.Stop() }) - // get from empty + // get from empty (just initialized) _, err := p.GetTestDatabase(ctx, hash1, 0) - assert.Error(t, err, pool.ErrTimeout) + assert.Error(t, err, ErrTimeout) // add a new one - assert.NoError(t, p.AddTestDatabase(ctx, templateDB)) + assert.NoError(t, p.extend(ctx, templateDB)) // get it - testDB, err := p.GetTestDatabase(ctx, hash1, 100*time.Millisecond) + testDB, err := p.GetTestDatabase(ctx, hash1, 1*time.Second) assert.NoError(t, err) assert.Equal(t, "prefix_h1_000", testDB.Database.Config.Database) assert.Equal(t, "ich", testDB.Database.Config.Username) @@ -55,19 +56,19 @@ func TestPoolAddGet(t *testing.T) { templateDB2 := templateDB templateDB2.TemplateHash = hash2 p.InitHashPool(ctx, templateDB2, initFunc) - assert.NoError(t, p.AddTestDatabase(ctx, templateDB2)) - assert.NoError(t, p.AddTestDatabase(ctx, templateDB2)) - assert.ErrorIs(t, p.AddTestDatabase(ctx, templateDB2), pool.ErrPoolFull) + assert.NoError(t, p.extend(ctx, templateDB2)) + assert.NoError(t, p.extend(ctx, templateDB2)) + assert.ErrorIs(t, p.extend(ctx, templateDB2), ErrPoolFull) // get from empty h1 - _, err = p.GetTestDatabase(ctx, hash1, 0) - assert.Error(t, err, pool.ErrTimeout) + _, err = p.GetTestDatabase(ctx, hash1, 100*time.Millisecond) + assert.ErrorIs(t, err, ErrTimeout) // get from h2 - testDB1, err := p.GetTestDatabase(ctx, hash2, 0) + testDB1, err := p.GetTestDatabase(ctx, hash2, 1*time.Second) assert.NoError(t, err) assert.Equal(t, hash2, testDB1.TemplateHash) - testDB2, err := p.GetTestDatabase(ctx, hash2, 0) + testDB2, err := p.GetTestDatabase(ctx, hash2, 1*time.Second) assert.NoError(t, err) assert.Equal(t, hash2, testDB2.TemplateHash) assert.NotEqual(t, testDB1.ID, testDB2.ID) @@ -91,13 +92,13 @@ func TestPoolAddGetConcurrent(t *testing.T) { } maxPoolSize := 15 - cfg := pool.PoolConfig{ - MaxPoolSize: maxPoolSize, - InitialPoolSize: maxPoolSize, - PoolMaxParallelTasks: 4, - TestDBNamePrefix: "", + cfg := PoolConfig{ + MaxPoolSize: maxPoolSize, + InitialPoolSize: maxPoolSize, + MaxParallelTasks: 4, + TestDBNamePrefix: "", } - p := pool.NewPoolCollection(cfg) + p := NewPoolCollection(cfg) t.Cleanup(func() { p.Stop() }) var wg sync.WaitGroup @@ -146,12 +147,12 @@ func TestPoolAddGetReturnConcurrent(t *testing.T) { return nil } - cfg := pool.PoolConfig{ - MaxPoolSize: 40, - PoolMaxParallelTasks: 4, - TestDBNamePrefix: "", + cfg := PoolConfig{ + MaxPoolSize: 40, + MaxParallelTasks: 4, + TestDBNamePrefix: "", } - p := pool.NewPoolCollection(cfg) + p := NewPoolCollection(cfg) t.Cleanup(func() { p.Stop() }) p.InitHashPool(ctx, templateDB1, initFunc) @@ -161,8 +162,8 @@ func TestPoolAddGetReturnConcurrent(t *testing.T) { // add DBs sequentially for i := 0; i < cfg.MaxPoolSize/4; i++ { - assert.NoError(t, p.AddTestDatabase(ctx, templateDB1)) - assert.NoError(t, p.AddTestDatabase(ctx, templateDB2)) + assert.NoError(t, p.extend(ctx, templateDB1)) + assert.NoError(t, p.extend(ctx, templateDB2)) } // stop the workers to prevent auto cleaning in background @@ -209,11 +210,11 @@ func TestPoolRemoveAll(t *testing.T) { return nil } - cfg := pool.PoolConfig{ - MaxPoolSize: 6, - PoolMaxParallelTasks: 4, + cfg := PoolConfig{ + MaxPoolSize: 6, + MaxParallelTasks: 4, } - p := pool.NewPoolCollection(cfg) + p := NewPoolCollection(cfg) t.Cleanup(func() { p.Stop() }) p.InitHashPool(ctx, templateDB1, initFunc) @@ -221,8 +222,8 @@ func TestPoolRemoveAll(t *testing.T) { // add DBs sequentially for i := 0; i < cfg.MaxPoolSize; i++ { - assert.NoError(t, p.AddTestDatabase(ctx, templateDB1)) - assert.NoError(t, p.AddTestDatabase(ctx, templateDB2)) + assert.NoError(t, p.extend(ctx, templateDB1)) + assert.NoError(t, p.extend(ctx, templateDB2)) } // remove all @@ -230,14 +231,14 @@ func TestPoolRemoveAll(t *testing.T) { // try to get _, err := p.GetTestDatabase(ctx, hash1, 0) - assert.Error(t, err, pool.ErrTimeout) + assert.Error(t, err, ErrTimeout) _, err = p.GetTestDatabase(ctx, hash2, 0) - assert.Error(t, err, pool.ErrTimeout) + assert.Error(t, err, ErrTimeout) // start using pool again p.InitHashPool(ctx, templateDB1, initFunc) - assert.NoError(t, p.AddTestDatabase(ctx, templateDB1)) - testDB, err := p.GetTestDatabase(ctx, hash1, 0) + assert.NoError(t, p.extend(ctx, templateDB1)) + testDB, err := p.GetTestDatabase(ctx, hash1, 1*time.Second) assert.NoError(t, err) assert.Equal(t, 0, testDB.ID) } @@ -260,19 +261,19 @@ func TestPoolReuseDirty(t *testing.T) { } maxPoolSize := 40 - cfg := pool.PoolConfig{ - MaxPoolSize: maxPoolSize, - InitialPoolSize: maxPoolSize, - PoolMaxParallelTasks: 1, - TestDBNamePrefix: "test_", + cfg := PoolConfig{ + MaxPoolSize: maxPoolSize, + InitialPoolSize: maxPoolSize, + MaxParallelTasks: 1, + TestDBNamePrefix: "test_", } - p := pool.NewPoolCollection(cfg) + p := NewPoolCollection(cfg) p.InitHashPool(ctx, templateDB1, initFunc) t.Cleanup(func() { p.Stop() }) getDirty := func(seenIDMap *sync.Map) { - newTestDB1, err := p.GetTestDatabase(ctx, templateDB1.TemplateHash, 1*time.Second) + newTestDB1, err := p.GetTestDatabase(ctx, templateDB1.TemplateHash, 3*time.Second) assert.NoError(t, err) seenIDMap.Store(newTestDB1.ID, true) } @@ -319,26 +320,23 @@ func TestPoolReturnTestDatabase(t *testing.T) { return nil } - cfg := pool.PoolConfig{ - MaxPoolSize: 10, - PoolMaxParallelTasks: 3, + cfg := PoolConfig{ + MaxPoolSize: 10, + MaxParallelTasks: 3, + disableWorkerAutostart: true, // no extend / cleanDirty tasks should run automatically! } - p := pool.NewPoolCollection(cfg) - t.Cleanup(func() { p.Stop() }) + p := NewPoolCollection(cfg) p.InitHashPool(ctx, templateDB1, initFunc) // add just one test DB - require.NoError(t, p.AddTestDatabase(ctx, templateDB1)) - - // stop the workers to prevent auto cleaning in background - p.Stop() + require.NoError(t, p.extend(ctx, templateDB1)) testDB1, err := p.GetTestDatabase(ctx, templateDB1.TemplateHash, time.Millisecond) assert.NoError(t, err) // assert that workers are stopped and no new DB showed up _, err = p.GetTestDatabase(ctx, templateDB1.TemplateHash, time.Millisecond) - assert.ErrorIs(t, err, pool.ErrTimeout) + assert.ErrorIs(t, err, ErrTimeout) // return and get the same one assert.NoError(t, p.ReturnTestDatabase(ctx, hash1, testDB1.ID)) diff --git a/pkg/templates/template.go b/pkg/templates/template.go index 1911234..79a3e97 100644 --- a/pkg/templates/template.go +++ b/pkg/templates/template.go @@ -41,7 +41,7 @@ func NewTemplate(hash string, config TemplateConfig) *Template { return t } -func (t *Template) GetConfig(ctx context.Context) TemplateConfig { +func (t *Template) GetConfig(_ context.Context) TemplateConfig { t.mutex.RLock() defer t.mutex.RUnlock() @@ -49,7 +49,7 @@ func (t *Template) GetConfig(ctx context.Context) TemplateConfig { } // GetState locks the template and checks its state. -func (t *Template) GetState(ctx context.Context) TemplateState { +func (t *Template) GetState(_ context.Context) TemplateState { t.mutex.RLock() defer t.mutex.RUnlock() @@ -94,18 +94,18 @@ func (t *Template) WaitUntilFinalized(ctx context.Context, timeout time.Duration // GetStateWithLock gets the current state leaving the template locked. // REMEMBER to unlock it when you no longer need it locked. -func (t *Template) GetStateWithLock(ctx context.Context) (TemplateState, lockedTemplate) { +func (t *Template) GetStateWithLock(_ context.Context) (TemplateState, LockedTemplate) { t.mutex.Lock() - return t.state, lockedTemplate{t: t} + return t.state, LockedTemplate{t: t} } -type lockedTemplate struct { +type LockedTemplate struct { t *Template } // Unlock releases the locked template. -func (l *lockedTemplate) Unlock() { +func (l *LockedTemplate) Unlock() { if l.t != nil { l.t.mutex.Unlock() l.t = nil @@ -113,7 +113,7 @@ func (l *lockedTemplate) Unlock() { } // SetState sets a new state of the locked template (without acquiring the lock again). -func (l lockedTemplate) SetState(ctx context.Context, newState TemplateState) { +func (l LockedTemplate) SetState(_ context.Context, newState TemplateState) { if l.t.state == newState { return } diff --git a/pkg/templates/template_collection.go b/pkg/templates/template_collection.go index 1d4f31f..4769a91 100644 --- a/pkg/templates/template_collection.go +++ b/pkg/templates/template_collection.go @@ -81,7 +81,7 @@ func (tc *Collection) Get(ctx context.Context, hash string) (template *Template, } // RemoveUnsafe removes the template and can be called ONLY IF THE COLLECTION IS LOCKED. -func (tc *Collection) RemoveUnsafe(ctx context.Context, hash string) { +func (tc *Collection) RemoveUnsafe(_ context.Context, hash string) { delete(tc.templates, hash) } diff --git a/pkg/templates/template_test.go b/pkg/templates/template_test.go index ef46708..128c5dd 100644 --- a/pkg/templates/template_test.go +++ b/pkg/templates/template_test.go @@ -2,7 +2,6 @@ package templates_test import ( "context" - "errors" "fmt" "sync" "testing" @@ -28,7 +27,7 @@ func TestTemplateGetSetState(t *testing.T) { assert.Equal(t, templates.TemplateStateDiscarded, state) } -func TestTemplateWaitForReady(t *testing.T) { +func TestForReady(t *testing.T) { ctx := context.Background() goroutineNum := 10 @@ -48,7 +47,7 @@ func TestTemplateWaitForReady(t *testing.T) { timeout := 1 * time.Second state := t1.WaitUntilFinalized(ctx, timeout) if state != templates.TemplateStateFinalized { - errsChan <- errors.New(fmt.Sprintf("expected state %v (finalized), but is %v", templates.TemplateStateFinalized, state)) + errsChan <- fmt.Errorf("expected state %v (finalized), but is %v", templates.TemplateStateFinalized, state) } }() } @@ -58,16 +57,16 @@ func TestTemplateWaitForReady(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - timeout := 3 * time.Millisecond + timeout := 30 * time.Millisecond state := t1.WaitUntilFinalized(ctx, timeout) if state != templates.TemplateStateInit { - errsChan <- errors.New(fmt.Sprintf("expected state %v (init), but is %v", templates.TemplateStateInit, state)) + errsChan <- fmt.Errorf("expected state %v (init), but is %v", templates.TemplateStateInit, state) } }() } // now set state - time.Sleep(5 * time.Millisecond) + time.Sleep(50 * time.Millisecond) t1.SetState(ctx, templates.TemplateStateFinalized) wg.Wait() diff --git a/pkg/util/context.go b/pkg/util/context.go new file mode 100644 index 0000000..00ce454 --- /dev/null +++ b/pkg/util/context.go @@ -0,0 +1,58 @@ +package util + +import ( + "context" + "errors" +) + +type contextKey string + +const ( + CTXKeyUser contextKey = "user" + CTXKeyAccessToken contextKey = "access_token" + CTXKeyRequestID contextKey = "request_id" + CTXKeyDisableLogger contextKey = "disable_logger" + CTXKeyCacheControl contextKey = "cache_control" +) + +// RequestIDFromContext returns the ID of the (HTTP) request, returning an error if it is not present. +func RequestIDFromContext(ctx context.Context) (string, error) { + val := ctx.Value(CTXKeyRequestID) + if val == nil { + return "", errors.New("No request ID present in context") + } + + id, ok := val.(string) + if !ok { + return "", errors.New("Request ID in context is not a string") + } + + return id, nil +} + +// ShouldDisableLogger checks whether the logger instance should be disabled for the provided context. +// `util.LogFromContext` will use this function to check whether it should return a default logger if +// none has been set by our logging middleware before, or fall back to the disabled logger, suppressing +// all output. Use `ctx = util.DisableLogger(ctx, true)` to disable logging for the given context. +func ShouldDisableLogger(ctx context.Context) bool { + s := ctx.Value(CTXKeyDisableLogger) + if s == nil { + return false + } + + shouldDisable, ok := s.(bool) + if !ok { + return false + } + + return shouldDisable +} + +// DisableLogger toggles the indication whether `util.LogFromContext` should return a disabled logger +// for a context if none has been set by our logging middleware before. Whilst the usecase for a disabled +// logger are relatively minimal (we almost always want to have some log output, even if the context +// was not directly derived from a HTTP request), this functionality was provideds so you can switch back +// to the old zerolog behavior if so desired. +func DisableLogger(ctx context.Context, shouldDisable bool) context.Context { + return context.WithValue(ctx, CTXKeyDisableLogger, shouldDisable) +} diff --git a/pkg/util/log.go b/pkg/util/log.go new file mode 100644 index 0000000..c1f4ceb --- /dev/null +++ b/pkg/util/log.go @@ -0,0 +1,42 @@ +package util + +import ( + "context" + + "github.com/labstack/echo/v4" + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" +) + +// LogFromContext returns a request-specific zerolog instance using the provided context. +// The returned logger will have the request ID as well as some other value predefined. +// If no logger is associated with the context provided, the global zerolog instance +// will be returned instead - this function will _always_ return a valid (enabled) logger. +// Should you ever need to force a disabled logger for a context, use `util.DisableLogger(ctx, true)` +// and pass the context returned to other code/`LogFromContext`. +func LogFromContext(ctx context.Context) *zerolog.Logger { + l := log.Ctx(ctx) + if l.GetLevel() == zerolog.Disabled { + if ShouldDisableLogger(ctx) { + return l + } + l = &log.Logger + } + return l +} + +// LogFromEchoContext returns a request-specific zerolog instance using the echo.Context of the request. +// The returned logger will have the request ID as well as some other value predefined. +func LogFromEchoContext(c echo.Context) *zerolog.Logger { + return LogFromContext(c.Request().Context()) +} + +func LogLevelFromString(s string) zerolog.Level { + l, err := zerolog.ParseLevel(s) + if err != nil { + log.Error().Err(err).Msgf("Failed to parse log level, defaulting to %s", zerolog.DebugLevel) + return zerolog.DebugLevel + } + + return l +} diff --git a/pkg/util/log_test.go b/pkg/util/log_test.go new file mode 100644 index 0000000..c318b74 --- /dev/null +++ b/pkg/util/log_test.go @@ -0,0 +1,20 @@ +package util_test + +import ( + "testing" + + "github.com/allaboutapps/integresql/pkg/util" + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" +) + +func TestLogLevelFromString(t *testing.T) { + res := util.LogLevelFromString("panic") + assert.Equal(t, zerolog.PanicLevel, res) + + res = util.LogLevelFromString("warn") + assert.Equal(t, zerolog.WarnLevel, res) + + res = util.LogLevelFromString("foo") + assert.Equal(t, zerolog.DebugLevel, res) +} diff --git a/pkg/util/retry.go b/pkg/util/retry.go index 1f629c7..ac5fb18 100644 --- a/pkg/util/retry.go +++ b/pkg/util/retry.go @@ -17,5 +17,5 @@ func Retry(attempts int, sleep time.Duration, f func() error) error { time.Sleep(sleep) } - return fmt.Errorf("failing after %d attempts, lat error: %v", attempts, err) + return fmt.Errorf("failing after %d attempts, lat error: %w", attempts, err) } diff --git a/tests/testclient/client.go b/tests/testclient/client.go index 686c204..503a241 100644 --- a/tests/testclient/client.go +++ b/tests/testclient/client.go @@ -8,6 +8,7 @@ import ( "context" "database/sql" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -16,6 +17,8 @@ import ( "github.com/allaboutapps/integresql/pkg/manager" "github.com/allaboutapps/integresql/pkg/util" + + // Import postgres driver for database/sql package _ "github.com/lib/pq" ) @@ -124,7 +127,7 @@ func (c *Client) SetupTemplate(ctx context.Context, hash string, init func(conn } return c.FinalizeTemplate(ctx, hash) - } else if err == manager.ErrTemplateAlreadyInitialized { + } else if errors.Is(err, manager.ErrTemplateAlreadyInitialized) { return nil } else { return err