diff --git a/aws/ec2info_test.go b/aws/ec2info_test.go index 2a9efefe6..497f50d17 100644 --- a/aws/ec2info_test.go +++ b/aws/ec2info_test.go @@ -133,64 +133,64 @@ func TestNewEc2Info(t *testing.T) { } func TestGetRegion(t *testing.T) { - oldReg, ok := os.LookupEnv("AWS_REGION") - if ok { - defer os.Setenv("AWS_REGION", oldReg) - } - oldDefReg, ok := os.LookupEnv("AWS_DEFAULT_REGION") - if ok { - defer os.Setenv("AWS_REGION", oldDefReg) - } - - os.Setenv("AWS_REGION", "kalamazoo") + // unset AWS region env vars for clean tests + os.Unsetenv("AWS_REGION") os.Unsetenv("AWS_DEFAULT_REGION") - region, err := getRegion() - require.NoError(t, err) - assert.Empty(t, region) - os.Setenv("AWS_DEFAULT_REGION", "kalamazoo") - os.Unsetenv("AWS_REGION") - region, err = getRegion() - require.NoError(t, err) - assert.Empty(t, region) + t.Run("with AWS_REGION set", func(t *testing.T) { + t.Setenv("AWS_REGION", "kalamazoo") + region, err := getRegion() + require.NoError(t, err) + assert.Empty(t, region) + }) - os.Unsetenv("AWS_DEFAULT_REGION") - metaClient := NewDummyEc2Meta() - region, err = getRegion(metaClient) - require.NoError(t, err) - assert.Equal(t, "unknown", region) + t.Run("with AWS_DEFAULT_REGION set", func(t *testing.T) { + t.Setenv("AWS_DEFAULT_REGION", "kalamazoo") + region, err := getRegion() + require.NoError(t, err) + assert.Empty(t, region) + }) - ec2meta := MockEC2Meta(nil, nil, "us-east-1") + t.Run("with no AWS_REGION, AWS_DEFAULT_REGION set", func(t *testing.T) { + metaClient := NewDummyEc2Meta() + region, err := getRegion(metaClient) + require.NoError(t, err) + assert.Equal(t, "unknown", region) + }) - region, err = getRegion(ec2meta) - require.NoError(t, err) - assert.Equal(t, "us-east-1", region) + t.Run("infer from EC2 metadata", func(t *testing.T) { + ec2meta := MockEC2Meta(nil, nil, "us-east-1") + region, err := getRegion(ec2meta) + require.NoError(t, err) + assert.Equal(t, "us-east-1", region) + }) } func TestGetClientOptions(t *testing.T) { - oldVar, ok := os.LookupEnv("AWS_TIMEOUT") - if ok { - defer os.Setenv("AWS_TIMEOUT", oldVar) - } - co := GetClientOptions() assert.Equal(t, ClientOptions{Timeout: 500 * time.Millisecond}, co) - os.Setenv("AWS_TIMEOUT", "42") - // reset the Once - coInit = sync.Once{} - co = GetClientOptions() - assert.Equal(t, ClientOptions{Timeout: 42 * time.Millisecond}, co) - - os.Setenv("AWS_TIMEOUT", "123") - // without resetting the Once, expect to be reused - co = GetClientOptions() - assert.Equal(t, ClientOptions{Timeout: 42 * time.Millisecond}, co) - - os.Setenv("AWS_TIMEOUT", "foo") - // reset the Once - coInit = sync.Once{} - assert.Panics(t, func() { - GetClientOptions() + t.Run("valid AWS_TIMEOUT, first call", func(t *testing.T) { + t.Setenv("AWS_TIMEOUT", "42") + // reset the Once + coInit = sync.Once{} + co = GetClientOptions() + assert.Equal(t, ClientOptions{Timeout: 42 * time.Millisecond}, co) + }) + + t.Run("valid AWS_TIMEOUT, non-first call", func(t *testing.T) { + t.Setenv("AWS_TIMEOUT", "123") + // without resetting the Once, expect to be reused + co = GetClientOptions() + assert.Equal(t, ClientOptions{Timeout: 42 * time.Millisecond}, co) + }) + + t.Run("invalid AWS_TIMEOUT", func(t *testing.T) { + t.Setenv("AWS_TIMEOUT", "foo") + // reset the Once + coInit = sync.Once{} + assert.Panics(t, func() { + GetClientOptions() + }) }) } diff --git a/context_test.go b/context_test.go index 8fd4da71f..aa3a44687 100644 --- a/context_test.go +++ b/context_test.go @@ -21,7 +21,7 @@ func TestEnvMapifiesEnvironment(t *testing.T) { func TestEnvGetsUpdatedEnvironment(t *testing.T) { c := &tmplctx{} assert.Empty(t, c.Env()["FOO"]) - require.NoError(t, os.Setenv("FOO", "foo")) + t.Setenv("FOO", "foo") assert.Equal(t, c.Env()["FOO"], "foo") } @@ -42,8 +42,7 @@ func TestCreateContext(t *testing.T) { ".": {URL: ub}, }, } - os.Setenv("foo", "foo: bar") - defer os.Unsetenv("foo") + t.Setenv("foo", "foo: bar") c, err = createTmplContext(ctx, []string{"foo"}, d) require.NoError(t, err) assert.IsType(t, &tmplctx{}, c) @@ -51,8 +50,7 @@ func TestCreateContext(t *testing.T) { ds := ((*tctx)["foo"]).(map[string]interface{}) assert.Equal(t, "bar", ds["foo"]) - os.Setenv("bar", "bar: baz") - defer os.Unsetenv("bar") + t.Setenv("bar", "bar: baz") c, err = createTmplContext(ctx, []string{"."}, d) require.NoError(t, err) assert.IsType(t, map[string]interface{}{}, c) diff --git a/data/datasource_blob_test.go b/data/datasource_blob_test.go index c79780d0b..6be8ea003 100644 --- a/data/datasource_blob_test.go +++ b/data/datasource_blob_test.go @@ -5,7 +5,6 @@ import ( "context" "net/http/httptest" "net/url" - "os" "testing" "github.com/johannesboyne/gofakes3" @@ -64,50 +63,48 @@ func TestReadBlob(t *testing.T) { ts, u := setupTestBucket(t) defer ts.Close() - os.Setenv("AWS_ANON", "true") - defer os.Unsetenv("AWS_ANON") + t.Run("no authentication", func(t *testing.T) { + t.Setenv("AWS_ANON", "true") - d, err := NewData([]string{"-d", "data=s3://mybucket/file1?region=us-east-1&disableSSL=true&s3ForcePathStyle=true&type=text/plain&endpoint=" + u.Host}, nil) - require.NoError(t, err) + d, err := NewData([]string{"-d", "data=s3://mybucket/file1?region=us-east-1&disableSSL=true&s3ForcePathStyle=true&type=text/plain&endpoint=" + u.Host}, nil) + require.NoError(t, err) - var expected interface{} - expected = "hello" - out, err := d.Datasource("data") - require.NoError(t, err) - assert.Equal(t, expected, out) + expected := "hello" + out, err := d.Datasource("data") + require.NoError(t, err) + assert.Equal(t, expected, out) + }) - os.Unsetenv("AWS_ANON") + t.Run("with authentication", func(t *testing.T) { + t.Setenv("AWS_ACCESS_KEY_ID", "fake") + t.Setenv("AWS_SECRET_ACCESS_KEY", "fake") + t.Setenv("AWS_S3_ENDPOINT", u.Host) - os.Setenv("AWS_ACCESS_KEY_ID", "fake") - os.Setenv("AWS_SECRET_ACCESS_KEY", "fake") - defer os.Unsetenv("AWS_ACCESS_KEY_ID") - defer os.Unsetenv("AWS_SECRET_ACCESS_KEY") - os.Setenv("AWS_S3_ENDPOINT", u.Host) - defer os.Unsetenv("AWS_S3_ENDPOINT") + d, err := NewData([]string{"-d", "data=s3://mybucket/file2?region=us-east-1&disableSSL=true&s3ForcePathStyle=true"}, nil) + require.NoError(t, err) - d, err = NewData([]string{"-d", "data=s3://mybucket/file2?region=us-east-1&disableSSL=true&s3ForcePathStyle=true"}, nil) - require.NoError(t, err) + var expected interface{} + expected = map[string]interface{}{"value": "goodbye world"} + out, err := d.Datasource("data") + require.NoError(t, err) + assert.Equal(t, expected, out) - expected = map[string]interface{}{"value": "goodbye world"} - out, err = d.Datasource("data") - require.NoError(t, err) - assert.Equal(t, expected, out) + d, err = NewData([]string{"-d", "data=s3://mybucket/?region=us-east-1&disableSSL=true&s3ForcePathStyle=true"}, nil) + require.NoError(t, err) - d, err = NewData([]string{"-d", "data=s3://mybucket/?region=us-east-1&disableSSL=true&s3ForcePathStyle=true"}, nil) - require.NoError(t, err) + expected = []interface{}{"dir1/", "file1", "file2", "file3"} + out, err = d.Datasource("data") + require.NoError(t, err) + assert.EqualValues(t, expected, out) - expected = []interface{}{"dir1/", "file1", "file2", "file3"} - out, err = d.Datasource("data") - require.NoError(t, err) - assert.EqualValues(t, expected, out) - - d, err = NewData([]string{"-d", "data=s3://mybucket/dir1/?region=us-east-1&disableSSL=true&s3ForcePathStyle=true"}, nil) - require.NoError(t, err) + d, err = NewData([]string{"-d", "data=s3://mybucket/dir1/?region=us-east-1&disableSSL=true&s3ForcePathStyle=true"}, nil) + require.NoError(t, err) - expected = []interface{}{"file1", "file2"} - out, err = d.Datasource("data") - require.NoError(t, err) - assert.EqualValues(t, expected, out) + expected = []interface{}{"file1", "file2"} + out, err = d.Datasource("data") + require.NoError(t, err) + assert.EqualValues(t, expected, out) + }) } func TestBlobURL(t *testing.T) { diff --git a/data/datasource_env_test.go b/data/datasource_env_test.go index 7a10b3473..6512578c1 100644 --- a/data/datasource_env_test.go +++ b/data/datasource_env_test.go @@ -3,7 +3,6 @@ package data import ( "context" "net/url" - "os" "testing" "github.com/stretchr/testify/assert" @@ -19,10 +18,8 @@ func TestReadEnv(t *testing.T) { ctx := context.Background() content := []byte(`hello world`) - os.Setenv("HELLO_WORLD", "hello world") - defer os.Unsetenv("HELLO_WORLD") - os.Setenv("HELLO_UNIVERSE", "hello universe") - defer os.Unsetenv("HELLO_UNIVERSE") + t.Setenv("HELLO_WORLD", "hello world") + t.Setenv("HELLO_UNIVERSE", "hello universe") source := &Source{Alias: "foo", URL: mustParseURL("env:HELLO_WORLD")} diff --git a/data/datasource_git_test.go b/data/datasource_git_test.go index 3f74a620a..3b187ecce 100644 --- a/data/datasource_git_test.go +++ b/data/datasource_git_test.go @@ -484,15 +484,13 @@ func TestGitAuth(t *testing.T) { assert.NilError(t, err) assert.DeepEqual(t, &http.BasicAuth{Username: "user", Password: "swordfish"}, a) - os.Setenv("GIT_HTTP_PASSWORD", "swordfish") - defer os.Unsetenv("GIT_HTTP_PASSWORD") + t.Setenv("GIT_HTTP_PASSWORD", "swordfish") a, err = g.auth(mustParseURL("git+https://user@example.com/foo")) assert.NilError(t, err) assert.DeepEqual(t, &http.BasicAuth{Username: "user", Password: "swordfish"}, a) os.Unsetenv("GIT_HTTP_PASSWORD") - os.Setenv("GIT_HTTP_TOKEN", "mytoken") - defer os.Unsetenv("GIT_HTTP_TOKEN") + t.Setenv("GIT_HTTP_TOKEN", "mytoken") a, err = g.auth(mustParseURL("git+https://user@example.com/foo")) assert.NilError(t, err) assert.DeepEqual(t, &http.TokenAuth{Token: "mytoken"}, a) @@ -508,25 +506,26 @@ func TestGitAuth(t *testing.T) { assert.Equal(t, "git", sa.User) } - key := string(testdata.PEMBytes["ed25519"]) - os.Setenv("GIT_SSH_KEY", key) - defer os.Unsetenv("GIT_SSH_KEY") - a, err = g.auth(mustParseURL("git+ssh://git@example.com/foo")) - assert.NilError(t, err) - ka, ok := a.(*ssh.PublicKeys) - assert.Equal(t, true, ok) - assert.Equal(t, "git", ka.User) - os.Unsetenv("GIT_SSH_KEY") - - key = base64.StdEncoding.EncodeToString(testdata.PEMBytes["ed25519"]) - os.Setenv("GIT_SSH_KEY", key) - defer os.Unsetenv("GIT_SSH_KEY") - a, err = g.auth(mustParseURL("git+ssh://git@example.com/foo")) - assert.NilError(t, err) - ka, ok = a.(*ssh.PublicKeys) - assert.Equal(t, true, ok) - assert.Equal(t, "git", ka.User) - os.Unsetenv("GIT_SSH_KEY") + t.Run("plain string ed25519", func(t *testing.T) { + key := string(testdata.PEMBytes["ed25519"]) + t.Setenv("GIT_SSH_KEY", key) + a, err = g.auth(mustParseURL("git+ssh://git@example.com/foo")) + assert.NilError(t, err) + ka, ok := a.(*ssh.PublicKeys) + assert.Equal(t, true, ok) + assert.Equal(t, "git", ka.User) + }) + + t.Run("base64 ed25519", func(t *testing.T) { + key := base64.StdEncoding.EncodeToString(testdata.PEMBytes["ed25519"]) + t.Setenv("GIT_SSH_KEY", key) + a, err = g.auth(mustParseURL("git+ssh://git@example.com/foo")) + assert.NilError(t, err) + ka, ok := a.(*ssh.PublicKeys) + assert.Equal(t, true, ok) + assert.Equal(t, "git", ka.User) + os.Unsetenv("GIT_SSH_KEY") + }) } func TestRefFromURL(t *testing.T) { diff --git a/internal/cmd/config_test.go b/internal/cmd/config_test.go index f4c032c87..0cac99838 100644 --- a/internal/cmd/config_test.go +++ b/internal/cmd/config_test.go @@ -6,7 +6,6 @@ import ( "fmt" "io/fs" "net/url" - "os" "testing" "testing/fstest" "time" @@ -167,32 +166,38 @@ func TestPickConfigFile(t *testing.T) { cmd := &cobra.Command{} cmd.Flags().String("config", defaultConfigFile, "foo") - cf, req := pickConfigFile(cmd) - assert.False(t, req) - assert.Equal(t, defaultConfigFile, cf) - - os.Setenv("GOMPLATE_CONFIG", "foo.yaml") - defer os.Unsetenv("GOMPLATE_CONFIG") - cf, req = pickConfigFile(cmd) - assert.True(t, req) - assert.Equal(t, "foo.yaml", cf) - - cmd.ParseFlags([]string{"--config", "config.file"}) - cf, req = pickConfigFile(cmd) - assert.True(t, req) - assert.Equal(t, "config.file", cf) - - os.Setenv("GOMPLATE_CONFIG", "ignored.yaml") - cf, req = pickConfigFile(cmd) - assert.True(t, req) - assert.Equal(t, "config.file", cf) + t.Run("default", func(t *testing.T) { + cf, req := pickConfigFile(cmd) + assert.False(t, req) + assert.Equal(t, defaultConfigFile, cf) + }) + + t.Run("GOMPLATE_CONFIG env var", func(t *testing.T) { + t.Setenv("GOMPLATE_CONFIG", "foo.yaml") + cf, req := pickConfigFile(cmd) + assert.True(t, req) + assert.Equal(t, "foo.yaml", cf) + }) + + t.Run("--config flag", func(t *testing.T) { + cmd.ParseFlags([]string{"--config", "config.file"}) + cf, req := pickConfigFile(cmd) + assert.True(t, req) + assert.Equal(t, "config.file", cf) + + t.Setenv("GOMPLATE_CONFIG", "ignored.yaml") + cf, req = pickConfigFile(cmd) + assert.True(t, req) + assert.Equal(t, "config.file", cf) + }) } func TestApplyEnvVars(t *testing.T) { - os.Setenv("GOMPLATE_PLUGIN_TIMEOUT", "bogus") - _, err := applyEnvVars(context.Background(), &config.Config{}) - os.Unsetenv("GOMPLATE_PLUGIN_TIMEOUT") - assert.Error(t, err) + t.Run("invalid GOMPLATE_PLUGIN_TIMEOUT", func(t *testing.T) { + t.Setenv("GOMPLATE_PLUGIN_TIMEOUT", "bogus") + _, err := applyEnvVars(context.Background(), &config.Config{}) + assert.Error(t, err) + }) data := []struct { input, expected *config.Config @@ -274,10 +279,9 @@ func TestApplyEnvVars(t *testing.T) { for i, d := range data { d := d t.Run(fmt.Sprintf("applyEnvVars_%s_%s/%d", d.env, d.value, i), func(t *testing.T) { - os.Setenv(d.env, d.value) + t.Setenv(d.env, d.value) actual, err := applyEnvVars(context.Background(), d.input) - os.Unsetenv(d.env) require.NoError(t, err) assert.EqualValues(t, d.expected, actual) }) diff --git a/internal/cmd/logger_test.go b/internal/cmd/logger_test.go index 72f8f292e..ea595c1e4 100644 --- a/internal/cmd/logger_test.go +++ b/internal/cmd/logger_test.go @@ -14,13 +14,12 @@ import ( func TestLogFormat(t *testing.T) { os.Unsetenv("GOMPLATE_LOG_FORMAT") - defer os.Unsetenv("GOMPLATE_LOG_FORMAT") assert.Equal(t, "json", logFormat(nil)) // os.Stdout isn't a terminal when this runs as a unit test... assert.Equal(t, "json", logFormat(os.Stdout)) - os.Setenv("GOMPLATE_LOG_FORMAT", "simple") + t.Setenv("GOMPLATE_LOG_FORMAT", "simple") assert.Equal(t, "simple", logFormat(os.Stdout)) assert.Equal(t, "simple", logFormat(&bytes.Buffer{})) } diff --git a/libkv/consul_test.go b/libkv/consul_test.go index 74d78d01e..5f128aa73 100644 --- a/libkv/consul_test.go +++ b/libkv/consul_test.go @@ -14,68 +14,83 @@ import ( ) func TestConsulURL(t *testing.T) { - defer os.Unsetenv("CONSUL_HTTP_SSL") - os.Setenv("CONSUL_HTTP_SSL", "true") - - u, _ := url.Parse("consul://") - expected := &url.URL{Host: "localhost:8500", Scheme: "https"} - actual, err := consulURL(u) - require.NoError(t, err) - assert.Equal(t, expected, actual) - - u, _ = url.Parse("consul+http://myconsul.server") - expected = &url.URL{Host: "myconsul.server", Scheme: "http"} - actual, err = consulURL(u) - require.NoError(t, err) - assert.Equal(t, expected, actual) - - os.Setenv("CONSUL_HTTP_SSL", "false") - u, _ = url.Parse("consul+https://myconsul.server:1234") - expected = &url.URL{Host: "myconsul.server:1234", Scheme: "https"} - actual, err = consulURL(u) - require.NoError(t, err) - assert.Equal(t, expected, actual) - os.Unsetenv("CONSUL_HTTP_SSL") - u, _ = url.Parse("consul://myconsul.server:2345") - expected = &url.URL{Host: "myconsul.server:2345", Scheme: "http"} - actual, err = consulURL(u) - require.NoError(t, err) - assert.Equal(t, expected, actual) - - u, _ = url.Parse("consul://myconsul.server:3456/foo/bar/baz") - expected = &url.URL{Host: "myconsul.server:3456", Scheme: "http"} - actual, err = consulURL(u) - require.NoError(t, err) - assert.Equal(t, expected, actual) - - defer os.Unsetenv("CONSUL_HTTP_ADDR") - os.Setenv("CONSUL_HTTP_ADDR", "https://foo:8500") - - // given URL takes precedence over env var - expected = &url.URL{Host: "myconsul.server:3456", Scheme: "http"} - actual, err = consulURL(u) - require.NoError(t, err) - assert.Equal(t, expected, actual) - - u, _ = url.Parse("consul://") - - defer os.Unsetenv("CONSUL_HTTP_SSL") - os.Setenv("CONSUL_HTTP_SSL", "true") - - // TLS enabled, HTTP_ADDR is set, URL has no host and ambiguous scheme - expected = &url.URL{Host: "foo:8500", Scheme: "https"} - actual, err = consulURL(u) - require.NoError(t, err) - assert.Equal(t, expected, actual) - - defer os.Unsetenv("CONSUL_HTTP_ADDR") - os.Setenv("CONSUL_HTTP_ADDR", "localhost:8501") - expected = &url.URL{Host: "localhost:8501", Scheme: "https"} - actual, err = consulURL(u) - require.NoError(t, err) - assert.Equal(t, expected, actual) + t.Run("consul scheme, CONSUL_HTTP_SSL set to true", func(t *testing.T) { + t.Setenv("CONSUL_HTTP_SSL", "true") + + u, _ := url.Parse("consul://") + expected := &url.URL{Host: "localhost:8500", Scheme: "https"} + actual, err := consulURL(u) + require.NoError(t, err) + assert.Equal(t, expected, actual) + }) + + t.Run("consul+http scheme", func(t *testing.T) { + u, _ := url.Parse("consul+http://myconsul.server") + expected := &url.URL{Host: "myconsul.server", Scheme: "http"} + actual, err := consulURL(u) + require.NoError(t, err) + assert.Equal(t, expected, actual) + }) + + t.Run("consul+https scheme, CONSUL_HTTP_SSL set to false", func(t *testing.T) { + t.Setenv("CONSUL_HTTP_SSL", "false") + + u, _ := url.Parse("consul+https://myconsul.server:1234") + expected := &url.URL{Host: "myconsul.server:1234", Scheme: "https"} + actual, err := consulURL(u) + require.NoError(t, err) + assert.Equal(t, expected, actual) + }) + + t.Run("consul scheme, CONSUL_HTTP_SSL unset", func(t *testing.T) { + u, _ := url.Parse("consul://myconsul.server:2345") + expected := &url.URL{Host: "myconsul.server:2345", Scheme: "http"} + actual, err := consulURL(u) + require.NoError(t, err) + assert.Equal(t, expected, actual) + }) + + t.Run("consul scheme, ignore path", func(t *testing.T) { + u, _ := url.Parse("consul://myconsul.server:3456/foo/bar/baz") + expected := &url.URL{Host: "myconsul.server:3456", Scheme: "http"} + actual, err := consulURL(u) + require.NoError(t, err) + assert.Equal(t, expected, actual) + }) + + t.Run("given URL takes precedence over env var", func(t *testing.T) { + t.Setenv("CONSUL_HTTP_ADDR", "https://foo:8500") + + u, _ := url.Parse("consul://myconsul.server:3456/foo/bar/baz") + expected := &url.URL{Host: "myconsul.server:3456", Scheme: "http"} + actual, err := consulURL(u) + require.NoError(t, err) + assert.Equal(t, expected, actual) + }) + + t.Run("TLS enabled, HTTP_ADDR is set, URL has no host and ambiguous scheme", func(t *testing.T) { + t.Setenv("CONSUL_HTTP_ADDR", "https://foo:8500") + t.Setenv("CONSUL_HTTP_SSL", "true") + + u, _ := url.Parse("consul://") + expected := &url.URL{Host: "foo:8500", Scheme: "https"} + actual, err := consulURL(u) + require.NoError(t, err) + assert.Equal(t, expected, actual) + }) + + t.Run("TLS enabled, HTTP_ADDR is set without scheme, URL has no host and ambiguous scheme", func(t *testing.T) { + t.Setenv("CONSUL_HTTP_ADDR", "localhost:8501") + t.Setenv("CONSUL_HTTP_SSL", "true") + + u, _ := url.Parse("consul://") + expected := &url.URL{Host: "localhost:8501", Scheme: "https"} + actual, err := consulURL(u) + require.NoError(t, err) + assert.Equal(t, expected, actual) + }) } func TestConsulAddrFromEnv(t *testing.T) { @@ -106,55 +121,56 @@ func TestSetupTLS(t *testing.T) { KeyFile: "keyfile", } - defer os.Unsetenv("CONSUL_TLS_SERVER_NAME") - defer os.Unsetenv("CONSUL_CACERT") - defer os.Unsetenv("CONSUL_CAPATH") - defer os.Unsetenv("CONSUL_CLIENT_CERT") - defer os.Unsetenv("CONSUL_CLIENT_KEY") - os.Setenv("CONSUL_TLS_SERVER_NAME", expected.Address) - os.Setenv("CONSUL_CACERT", expected.CAFile) - os.Setenv("CONSUL_CAPATH", expected.CAPath) - os.Setenv("CONSUL_CLIENT_CERT", expected.CertFile) - os.Setenv("CONSUL_CLIENT_KEY", expected.KeyFile) - - assert.Equal(t, expected, setupTLS()) + t.Setenv("CONSUL_TLS_SERVER_NAME", expected.Address) + t.Setenv("CONSUL_CACERT", expected.CAFile) + t.Setenv("CONSUL_CAPATH", expected.CAPath) + t.Setenv("CONSUL_CLIENT_CERT", expected.CertFile) + t.Setenv("CONSUL_CLIENT_KEY", expected.KeyFile) - expected.InsecureSkipVerify = false - defer os.Unsetenv("CONSUL_HTTP_SSL_VERIFY") - os.Setenv("CONSUL_HTTP_SSL_VERIFY", "true") assert.Equal(t, expected, setupTLS()) - expected.InsecureSkipVerify = true - os.Setenv("CONSUL_HTTP_SSL_VERIFY", "false") - assert.Equal(t, expected, setupTLS()) + t.Run("CONSUL_HTTP_SSL_VERIFY is true", func(t *testing.T) { + expected.InsecureSkipVerify = false + t.Setenv("CONSUL_HTTP_SSL_VERIFY", "true") + assert.Equal(t, expected, setupTLS()) + }) + + t.Run("CONSUL_HTTP_SSL_VERIFY is false", func(t *testing.T) { + expected.InsecureSkipVerify = true + t.Setenv("CONSUL_HTTP_SSL_VERIFY", "false") + assert.Equal(t, expected, setupTLS()) + }) } func TestConsulConfig(t *testing.T) { - expectedConfig := &store.Config{} - - actualConfig, err := consulConfig(false) - require.NoError(t, err) - - assert.Equal(t, expectedConfig, actualConfig) - - defer os.Unsetenv("CONSUL_TIMEOUT") - os.Setenv("CONSUL_TIMEOUT", "10") - expectedConfig = &store.Config{ - ConnectionTimeout: 10 * time.Second, - } - - actualConfig, err = consulConfig(false) - require.NoError(t, err) - assert.Equal(t, expectedConfig, actualConfig) - - os.Unsetenv("CONSUL_TIMEOUT") - expectedConfig = &store.Config{ - TLS: &tls.Config{MinVersion: tls.VersionTLS13}, - } - - actualConfig, err = consulConfig(true) - require.NoError(t, err) - assert.NotNil(t, actualConfig.TLS) - actualConfig.TLS = &tls.Config{MinVersion: tls.VersionTLS13} - assert.Equal(t, expectedConfig, actualConfig) + t.Run("default ", func(t *testing.T) { + expectedConfig := &store.Config{} + + actualConfig, err := consulConfig(false) + require.NoError(t, err) + + assert.Equal(t, expectedConfig, actualConfig) + }) + + t.Run("with CONSUL_TIMEOUT", func(t *testing.T) { + t.Setenv("CONSUL_TIMEOUT", "10") + expectedConfig := &store.Config{ + ConnectionTimeout: 10 * time.Second, + } + + actualConfig, err := consulConfig(false) + require.NoError(t, err) + assert.Equal(t, expectedConfig, actualConfig) + }) + + t.Run("with TLS", func(t *testing.T) { + expectedConfig := &store.Config{ + TLS: &tls.Config{MinVersion: tls.VersionTLS13}, + } + actualConfig, err := consulConfig(true) + require.NoError(t, err) + assert.NotNil(t, actualConfig.TLS) + actualConfig.TLS = &tls.Config{MinVersion: tls.VersionTLS13} + assert.Equal(t, expectedConfig, actualConfig) + }) } diff --git a/render_test.go b/render_test.go index 21b0be744..684367faf 100644 --- a/render_test.go +++ b/render_test.go @@ -37,8 +37,7 @@ func TestRenderTemplate(t *testing.T) { hu, _ := url.Parse("stdin:") wu, _ := url.Parse("env:WORLD") - os.Setenv("WORLD", "world") - defer os.Unsetenv("WORLD") + t.Setenv("WORLD", "world") tr = NewRenderer(Options{ Context: map[string]Datasource{ diff --git a/vault/auth_test.go b/vault/auth_test.go index 552db5720..fa8f29ccc 100644 --- a/vault/auth_test.go +++ b/vault/auth_test.go @@ -1,7 +1,6 @@ package vault import ( - "os" "testing" "github.com/stretchr/testify/assert" @@ -11,8 +10,7 @@ import ( func TestLogin(t *testing.T) { server, v := MockServer(404, "Not Found") defer server.Close() - os.Setenv("VAULT_TOKEN", "foo") - defer os.Unsetenv("VAULT_TOKEN") + t.Setenv("VAULT_TOKEN", "foo") v.Login() assert.Equal(t, "foo", v.client.Token()) } @@ -20,8 +18,7 @@ func TestLogin(t *testing.T) { func TestTokenLogin(t *testing.T) { server, v := MockServer(404, "Not Found") defer server.Close() - os.Setenv("VAULT_TOKEN", "foo") - defer os.Unsetenv("VAULT_TOKEN") + t.Setenv("VAULT_TOKEN", "foo") token, err := v.TokenLogin() require.NoError(t, err) diff --git a/vault/vault_test.go b/vault/vault_test.go index 754e2d4c8..1cc5e31d6 100644 --- a/vault/vault_test.go +++ b/vault/vault_test.go @@ -15,8 +15,7 @@ func TestNew(t *testing.T) { require.NoError(t, err) assert.Equal(t, "https://127.0.0.1:8200", v.client.Address()) - os.Setenv("VAULT_ADDR", "http://example.com:1234") - defer os.Unsetenv("VAULT_ADDR") + t.Setenv("VAULT_ADDR", "http://example.com:1234") v, err = New(nil) require.NoError(t, err) assert.Equal(t, "http://example.com:1234", v.client.Address())