Skip to content

Commit

Permalink
test: use t.Setenv to set env vars in tests (#1940)
Browse files Browse the repository at this point in the history
* test: use `t.Setenv` to set env vars in tests

This commit replaces `os.Setenv` with `t.Setenv` in tests. The
environment variable is automatically restored to its original value
when the test and all its subtests complete.

Reference: https://pkg.go.dev/testing#T.Setenv
Signed-off-by: Eng Zer Jun <[email protected]>

* minor adjustments

Signed-off-by: Dave Henderson <[email protected]>

---------

Signed-off-by: Eng Zer Jun <[email protected]>
Signed-off-by: Dave Henderson <[email protected]>
Co-authored-by: Dave Henderson <[email protected]>
  • Loading branch information
Juneezee and hairyhenderson authored Dec 19, 2023
1 parent e6835bc commit 483af65
Show file tree
Hide file tree
Showing 11 changed files with 264 additions and 259 deletions.
96 changes: 48 additions & 48 deletions aws/ec2info_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,64 +133,64 @@ func TestNewEc2Info(t *testing.T) {
}

func TestGetRegion(t *testing.T) {
oldReg, ok := os.LookupEnv("AWS_REGION")
if ok {
defer os.Setenv("AWS_REGION", oldReg)
}
oldDefReg, ok := os.LookupEnv("AWS_DEFAULT_REGION")
if ok {
defer os.Setenv("AWS_REGION", oldDefReg)
}

os.Setenv("AWS_REGION", "kalamazoo")
// unset AWS region env vars for clean tests
os.Unsetenv("AWS_REGION")
os.Unsetenv("AWS_DEFAULT_REGION")
region, err := getRegion()
require.NoError(t, err)
assert.Empty(t, region)

os.Setenv("AWS_DEFAULT_REGION", "kalamazoo")
os.Unsetenv("AWS_REGION")
region, err = getRegion()
require.NoError(t, err)
assert.Empty(t, region)
t.Run("with AWS_REGION set", func(t *testing.T) {
t.Setenv("AWS_REGION", "kalamazoo")
region, err := getRegion()
require.NoError(t, err)
assert.Empty(t, region)
})

os.Unsetenv("AWS_DEFAULT_REGION")
metaClient := NewDummyEc2Meta()
region, err = getRegion(metaClient)
require.NoError(t, err)
assert.Equal(t, "unknown", region)
t.Run("with AWS_DEFAULT_REGION set", func(t *testing.T) {
t.Setenv("AWS_DEFAULT_REGION", "kalamazoo")
region, err := getRegion()
require.NoError(t, err)
assert.Empty(t, region)
})

ec2meta := MockEC2Meta(nil, nil, "us-east-1")
t.Run("with no AWS_REGION, AWS_DEFAULT_REGION set", func(t *testing.T) {
metaClient := NewDummyEc2Meta()
region, err := getRegion(metaClient)
require.NoError(t, err)
assert.Equal(t, "unknown", region)
})

region, err = getRegion(ec2meta)
require.NoError(t, err)
assert.Equal(t, "us-east-1", region)
t.Run("infer from EC2 metadata", func(t *testing.T) {
ec2meta := MockEC2Meta(nil, nil, "us-east-1")
region, err := getRegion(ec2meta)
require.NoError(t, err)
assert.Equal(t, "us-east-1", region)
})
}

func TestGetClientOptions(t *testing.T) {
oldVar, ok := os.LookupEnv("AWS_TIMEOUT")
if ok {
defer os.Setenv("AWS_TIMEOUT", oldVar)
}

co := GetClientOptions()
assert.Equal(t, ClientOptions{Timeout: 500 * time.Millisecond}, co)

os.Setenv("AWS_TIMEOUT", "42")
// reset the Once
coInit = sync.Once{}
co = GetClientOptions()
assert.Equal(t, ClientOptions{Timeout: 42 * time.Millisecond}, co)

os.Setenv("AWS_TIMEOUT", "123")
// without resetting the Once, expect to be reused
co = GetClientOptions()
assert.Equal(t, ClientOptions{Timeout: 42 * time.Millisecond}, co)

os.Setenv("AWS_TIMEOUT", "foo")
// reset the Once
coInit = sync.Once{}
assert.Panics(t, func() {
GetClientOptions()
t.Run("valid AWS_TIMEOUT, first call", func(t *testing.T) {
t.Setenv("AWS_TIMEOUT", "42")
// reset the Once
coInit = sync.Once{}
co = GetClientOptions()
assert.Equal(t, ClientOptions{Timeout: 42 * time.Millisecond}, co)
})

t.Run("valid AWS_TIMEOUT, non-first call", func(t *testing.T) {
t.Setenv("AWS_TIMEOUT", "123")
// without resetting the Once, expect to be reused
co = GetClientOptions()
assert.Equal(t, ClientOptions{Timeout: 42 * time.Millisecond}, co)
})

t.Run("invalid AWS_TIMEOUT", func(t *testing.T) {
t.Setenv("AWS_TIMEOUT", "foo")
// reset the Once
coInit = sync.Once{}
assert.Panics(t, func() {
GetClientOptions()
})
})
}
8 changes: 3 additions & 5 deletions context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func TestEnvMapifiesEnvironment(t *testing.T) {
func TestEnvGetsUpdatedEnvironment(t *testing.T) {
c := &tmplctx{}
assert.Empty(t, c.Env()["FOO"])
require.NoError(t, os.Setenv("FOO", "foo"))
t.Setenv("FOO", "foo")
assert.Equal(t, c.Env()["FOO"], "foo")
}

Expand All @@ -42,17 +42,15 @@ func TestCreateContext(t *testing.T) {
".": {URL: ub},
},
}
os.Setenv("foo", "foo: bar")
defer os.Unsetenv("foo")
t.Setenv("foo", "foo: bar")
c, err = createTmplContext(ctx, []string{"foo"}, d)
require.NoError(t, err)
assert.IsType(t, &tmplctx{}, c)
tctx := c.(*tmplctx)
ds := ((*tctx)["foo"]).(map[string]interface{})
assert.Equal(t, "bar", ds["foo"])

os.Setenv("bar", "bar: baz")
defer os.Unsetenv("bar")
t.Setenv("bar", "bar: baz")
c, err = createTmplContext(ctx, []string{"."}, d)
require.NoError(t, err)
assert.IsType(t, map[string]interface{}{}, c)
Expand Down
69 changes: 33 additions & 36 deletions data/datasource_blob_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"context"
"net/http/httptest"
"net/url"
"os"
"testing"

"github.com/johannesboyne/gofakes3"
Expand Down Expand Up @@ -64,50 +63,48 @@ func TestReadBlob(t *testing.T) {
ts, u := setupTestBucket(t)
defer ts.Close()

os.Setenv("AWS_ANON", "true")
defer os.Unsetenv("AWS_ANON")
t.Run("no authentication", func(t *testing.T) {
t.Setenv("AWS_ANON", "true")

d, err := NewData([]string{"-d", "data=s3://mybucket/file1?region=us-east-1&disableSSL=true&s3ForcePathStyle=true&type=text/plain&endpoint=" + u.Host}, nil)
require.NoError(t, err)
d, err := NewData([]string{"-d", "data=s3://mybucket/file1?region=us-east-1&disableSSL=true&s3ForcePathStyle=true&type=text/plain&endpoint=" + u.Host}, nil)
require.NoError(t, err)

var expected interface{}
expected = "hello"
out, err := d.Datasource("data")
require.NoError(t, err)
assert.Equal(t, expected, out)
expected := "hello"
out, err := d.Datasource("data")
require.NoError(t, err)
assert.Equal(t, expected, out)
})

os.Unsetenv("AWS_ANON")
t.Run("with authentication", func(t *testing.T) {
t.Setenv("AWS_ACCESS_KEY_ID", "fake")
t.Setenv("AWS_SECRET_ACCESS_KEY", "fake")
t.Setenv("AWS_S3_ENDPOINT", u.Host)

os.Setenv("AWS_ACCESS_KEY_ID", "fake")
os.Setenv("AWS_SECRET_ACCESS_KEY", "fake")
defer os.Unsetenv("AWS_ACCESS_KEY_ID")
defer os.Unsetenv("AWS_SECRET_ACCESS_KEY")
os.Setenv("AWS_S3_ENDPOINT", u.Host)
defer os.Unsetenv("AWS_S3_ENDPOINT")
d, err := NewData([]string{"-d", "data=s3://mybucket/file2?region=us-east-1&disableSSL=true&s3ForcePathStyle=true"}, nil)
require.NoError(t, err)

d, err = NewData([]string{"-d", "data=s3://mybucket/file2?region=us-east-1&disableSSL=true&s3ForcePathStyle=true"}, nil)
require.NoError(t, err)
var expected interface{}
expected = map[string]interface{}{"value": "goodbye world"}
out, err := d.Datasource("data")
require.NoError(t, err)
assert.Equal(t, expected, out)

expected = map[string]interface{}{"value": "goodbye world"}
out, err = d.Datasource("data")
require.NoError(t, err)
assert.Equal(t, expected, out)
d, err = NewData([]string{"-d", "data=s3://mybucket/?region=us-east-1&disableSSL=true&s3ForcePathStyle=true"}, nil)
require.NoError(t, err)

d, err = NewData([]string{"-d", "data=s3://mybucket/?region=us-east-1&disableSSL=true&s3ForcePathStyle=true"}, nil)
require.NoError(t, err)
expected = []interface{}{"dir1/", "file1", "file2", "file3"}
out, err = d.Datasource("data")
require.NoError(t, err)
assert.EqualValues(t, expected, out)

expected = []interface{}{"dir1/", "file1", "file2", "file3"}
out, err = d.Datasource("data")
require.NoError(t, err)
assert.EqualValues(t, expected, out)

d, err = NewData([]string{"-d", "data=s3://mybucket/dir1/?region=us-east-1&disableSSL=true&s3ForcePathStyle=true"}, nil)
require.NoError(t, err)
d, err = NewData([]string{"-d", "data=s3://mybucket/dir1/?region=us-east-1&disableSSL=true&s3ForcePathStyle=true"}, nil)
require.NoError(t, err)

expected = []interface{}{"file1", "file2"}
out, err = d.Datasource("data")
require.NoError(t, err)
assert.EqualValues(t, expected, out)
expected = []interface{}{"file1", "file2"}
out, err = d.Datasource("data")
require.NoError(t, err)
assert.EqualValues(t, expected, out)
})
}

func TestBlobURL(t *testing.T) {
Expand Down
7 changes: 2 additions & 5 deletions data/datasource_env_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package data
import (
"context"
"net/url"
"os"
"testing"

"github.com/stretchr/testify/assert"
Expand All @@ -19,10 +18,8 @@ func TestReadEnv(t *testing.T) {
ctx := context.Background()

content := []byte(`hello world`)
os.Setenv("HELLO_WORLD", "hello world")
defer os.Unsetenv("HELLO_WORLD")
os.Setenv("HELLO_UNIVERSE", "hello universe")
defer os.Unsetenv("HELLO_UNIVERSE")
t.Setenv("HELLO_WORLD", "hello world")
t.Setenv("HELLO_UNIVERSE", "hello universe")

source := &Source{Alias: "foo", URL: mustParseURL("env:HELLO_WORLD")}

Expand Down
45 changes: 22 additions & 23 deletions data/datasource_git_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -484,15 +484,13 @@ func TestGitAuth(t *testing.T) {
assert.NilError(t, err)
assert.DeepEqual(t, &http.BasicAuth{Username: "user", Password: "swordfish"}, a)

os.Setenv("GIT_HTTP_PASSWORD", "swordfish")
defer os.Unsetenv("GIT_HTTP_PASSWORD")
t.Setenv("GIT_HTTP_PASSWORD", "swordfish")
a, err = g.auth(mustParseURL("git+https://[email protected]/foo"))
assert.NilError(t, err)
assert.DeepEqual(t, &http.BasicAuth{Username: "user", Password: "swordfish"}, a)
os.Unsetenv("GIT_HTTP_PASSWORD")

os.Setenv("GIT_HTTP_TOKEN", "mytoken")
defer os.Unsetenv("GIT_HTTP_TOKEN")
t.Setenv("GIT_HTTP_TOKEN", "mytoken")
a, err = g.auth(mustParseURL("git+https://[email protected]/foo"))
assert.NilError(t, err)
assert.DeepEqual(t, &http.TokenAuth{Token: "mytoken"}, a)
Expand All @@ -508,25 +506,26 @@ func TestGitAuth(t *testing.T) {
assert.Equal(t, "git", sa.User)
}

key := string(testdata.PEMBytes["ed25519"])
os.Setenv("GIT_SSH_KEY", key)
defer os.Unsetenv("GIT_SSH_KEY")
a, err = g.auth(mustParseURL("git+ssh://[email protected]/foo"))
assert.NilError(t, err)
ka, ok := a.(*ssh.PublicKeys)
assert.Equal(t, true, ok)
assert.Equal(t, "git", ka.User)
os.Unsetenv("GIT_SSH_KEY")

key = base64.StdEncoding.EncodeToString(testdata.PEMBytes["ed25519"])
os.Setenv("GIT_SSH_KEY", key)
defer os.Unsetenv("GIT_SSH_KEY")
a, err = g.auth(mustParseURL("git+ssh://[email protected]/foo"))
assert.NilError(t, err)
ka, ok = a.(*ssh.PublicKeys)
assert.Equal(t, true, ok)
assert.Equal(t, "git", ka.User)
os.Unsetenv("GIT_SSH_KEY")
t.Run("plain string ed25519", func(t *testing.T) {
key := string(testdata.PEMBytes["ed25519"])
t.Setenv("GIT_SSH_KEY", key)
a, err = g.auth(mustParseURL("git+ssh://[email protected]/foo"))
assert.NilError(t, err)
ka, ok := a.(*ssh.PublicKeys)
assert.Equal(t, true, ok)
assert.Equal(t, "git", ka.User)
})

t.Run("base64 ed25519", func(t *testing.T) {
key := base64.StdEncoding.EncodeToString(testdata.PEMBytes["ed25519"])
t.Setenv("GIT_SSH_KEY", key)
a, err = g.auth(mustParseURL("git+ssh://[email protected]/foo"))
assert.NilError(t, err)
ka, ok := a.(*ssh.PublicKeys)
assert.Equal(t, true, ok)
assert.Equal(t, "git", ka.User)
os.Unsetenv("GIT_SSH_KEY")
})
}

func TestRefFromURL(t *testing.T) {
Expand Down
Loading

0 comments on commit 483af65

Please sign in to comment.