Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: use t.Setenv to set env vars in tests #1940

Merged
merged 2 commits into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 48 additions & 48 deletions aws/ec2info_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,64 +133,64 @@ func TestNewEc2Info(t *testing.T) {
}

func TestGetRegion(t *testing.T) {
oldReg, ok := os.LookupEnv("AWS_REGION")
if ok {
defer os.Setenv("AWS_REGION", oldReg)
}
oldDefReg, ok := os.LookupEnv("AWS_DEFAULT_REGION")
if ok {
defer os.Setenv("AWS_REGION", oldDefReg)
}

os.Setenv("AWS_REGION", "kalamazoo")
// unset AWS region env vars for clean tests
os.Unsetenv("AWS_REGION")
os.Unsetenv("AWS_DEFAULT_REGION")
region, err := getRegion()
require.NoError(t, err)
assert.Empty(t, region)

os.Setenv("AWS_DEFAULT_REGION", "kalamazoo")
os.Unsetenv("AWS_REGION")
region, err = getRegion()
require.NoError(t, err)
assert.Empty(t, region)
t.Run("with AWS_REGION set", func(t *testing.T) {
t.Setenv("AWS_REGION", "kalamazoo")
region, err := getRegion()
require.NoError(t, err)
assert.Empty(t, region)
})

os.Unsetenv("AWS_DEFAULT_REGION")
metaClient := NewDummyEc2Meta()
region, err = getRegion(metaClient)
require.NoError(t, err)
assert.Equal(t, "unknown", region)
t.Run("with AWS_DEFAULT_REGION set", func(t *testing.T) {
t.Setenv("AWS_DEFAULT_REGION", "kalamazoo")
region, err := getRegion()
require.NoError(t, err)
assert.Empty(t, region)
})

ec2meta := MockEC2Meta(nil, nil, "us-east-1")
t.Run("with no AWS_REGION, AWS_DEFAULT_REGION set", func(t *testing.T) {
metaClient := NewDummyEc2Meta()
region, err := getRegion(metaClient)
require.NoError(t, err)
assert.Equal(t, "unknown", region)
})

region, err = getRegion(ec2meta)
require.NoError(t, err)
assert.Equal(t, "us-east-1", region)
t.Run("infer from EC2 metadata", func(t *testing.T) {
ec2meta := MockEC2Meta(nil, nil, "us-east-1")
region, err := getRegion(ec2meta)
require.NoError(t, err)
assert.Equal(t, "us-east-1", region)
})
}

func TestGetClientOptions(t *testing.T) {
oldVar, ok := os.LookupEnv("AWS_TIMEOUT")
if ok {
defer os.Setenv("AWS_TIMEOUT", oldVar)
}

co := GetClientOptions()
assert.Equal(t, ClientOptions{Timeout: 500 * time.Millisecond}, co)

os.Setenv("AWS_TIMEOUT", "42")
// reset the Once
coInit = sync.Once{}
co = GetClientOptions()
assert.Equal(t, ClientOptions{Timeout: 42 * time.Millisecond}, co)

os.Setenv("AWS_TIMEOUT", "123")
// without resetting the Once, expect to be reused
co = GetClientOptions()
assert.Equal(t, ClientOptions{Timeout: 42 * time.Millisecond}, co)

os.Setenv("AWS_TIMEOUT", "foo")
// reset the Once
coInit = sync.Once{}
assert.Panics(t, func() {
GetClientOptions()
t.Run("valid AWS_TIMEOUT, first call", func(t *testing.T) {
t.Setenv("AWS_TIMEOUT", "42")
// reset the Once
coInit = sync.Once{}
co = GetClientOptions()
assert.Equal(t, ClientOptions{Timeout: 42 * time.Millisecond}, co)
})

t.Run("valid AWS_TIMEOUT, non-first call", func(t *testing.T) {
t.Setenv("AWS_TIMEOUT", "123")
// without resetting the Once, expect to be reused
co = GetClientOptions()
assert.Equal(t, ClientOptions{Timeout: 42 * time.Millisecond}, co)
})

t.Run("invalid AWS_TIMEOUT", func(t *testing.T) {
t.Setenv("AWS_TIMEOUT", "foo")
// reset the Once
coInit = sync.Once{}
assert.Panics(t, func() {
GetClientOptions()
})
})
}
8 changes: 3 additions & 5 deletions context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func TestEnvMapifiesEnvironment(t *testing.T) {
func TestEnvGetsUpdatedEnvironment(t *testing.T) {
c := &tmplctx{}
assert.Empty(t, c.Env()["FOO"])
require.NoError(t, os.Setenv("FOO", "foo"))
t.Setenv("FOO", "foo")
assert.Equal(t, c.Env()["FOO"], "foo")
}

Expand All @@ -42,17 +42,15 @@ func TestCreateContext(t *testing.T) {
".": {URL: ub},
},
}
os.Setenv("foo", "foo: bar")
defer os.Unsetenv("foo")
t.Setenv("foo", "foo: bar")
c, err = createTmplContext(ctx, []string{"foo"}, d)
require.NoError(t, err)
assert.IsType(t, &tmplctx{}, c)
tctx := c.(*tmplctx)
ds := ((*tctx)["foo"]).(map[string]interface{})
assert.Equal(t, "bar", ds["foo"])

os.Setenv("bar", "bar: baz")
defer os.Unsetenv("bar")
t.Setenv("bar", "bar: baz")
c, err = createTmplContext(ctx, []string{"."}, d)
require.NoError(t, err)
assert.IsType(t, map[string]interface{}{}, c)
Expand Down
69 changes: 33 additions & 36 deletions data/datasource_blob_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"context"
"net/http/httptest"
"net/url"
"os"
"testing"

"github.com/johannesboyne/gofakes3"
Expand Down Expand Up @@ -64,50 +63,48 @@ func TestReadBlob(t *testing.T) {
ts, u := setupTestBucket(t)
defer ts.Close()

os.Setenv("AWS_ANON", "true")
defer os.Unsetenv("AWS_ANON")
t.Run("no authentication", func(t *testing.T) {
t.Setenv("AWS_ANON", "true")

d, err := NewData([]string{"-d", "data=s3://mybucket/file1?region=us-east-1&disableSSL=true&s3ForcePathStyle=true&type=text/plain&endpoint=" + u.Host}, nil)
require.NoError(t, err)
d, err := NewData([]string{"-d", "data=s3://mybucket/file1?region=us-east-1&disableSSL=true&s3ForcePathStyle=true&type=text/plain&endpoint=" + u.Host}, nil)
require.NoError(t, err)

var expected interface{}
expected = "hello"
out, err := d.Datasource("data")
require.NoError(t, err)
assert.Equal(t, expected, out)
expected := "hello"
out, err := d.Datasource("data")
require.NoError(t, err)
assert.Equal(t, expected, out)
})

os.Unsetenv("AWS_ANON")
t.Run("with authentication", func(t *testing.T) {
t.Setenv("AWS_ACCESS_KEY_ID", "fake")
t.Setenv("AWS_SECRET_ACCESS_KEY", "fake")
t.Setenv("AWS_S3_ENDPOINT", u.Host)

os.Setenv("AWS_ACCESS_KEY_ID", "fake")
os.Setenv("AWS_SECRET_ACCESS_KEY", "fake")
defer os.Unsetenv("AWS_ACCESS_KEY_ID")
defer os.Unsetenv("AWS_SECRET_ACCESS_KEY")
os.Setenv("AWS_S3_ENDPOINT", u.Host)
defer os.Unsetenv("AWS_S3_ENDPOINT")
d, err := NewData([]string{"-d", "data=s3://mybucket/file2?region=us-east-1&disableSSL=true&s3ForcePathStyle=true"}, nil)
require.NoError(t, err)

d, err = NewData([]string{"-d", "data=s3://mybucket/file2?region=us-east-1&disableSSL=true&s3ForcePathStyle=true"}, nil)
require.NoError(t, err)
var expected interface{}
expected = map[string]interface{}{"value": "goodbye world"}
out, err := d.Datasource("data")
require.NoError(t, err)
assert.Equal(t, expected, out)

expected = map[string]interface{}{"value": "goodbye world"}
out, err = d.Datasource("data")
require.NoError(t, err)
assert.Equal(t, expected, out)
d, err = NewData([]string{"-d", "data=s3://mybucket/?region=us-east-1&disableSSL=true&s3ForcePathStyle=true"}, nil)
require.NoError(t, err)

d, err = NewData([]string{"-d", "data=s3://mybucket/?region=us-east-1&disableSSL=true&s3ForcePathStyle=true"}, nil)
require.NoError(t, err)
expected = []interface{}{"dir1/", "file1", "file2", "file3"}
out, err = d.Datasource("data")
require.NoError(t, err)
assert.EqualValues(t, expected, out)

expected = []interface{}{"dir1/", "file1", "file2", "file3"}
out, err = d.Datasource("data")
require.NoError(t, err)
assert.EqualValues(t, expected, out)

d, err = NewData([]string{"-d", "data=s3://mybucket/dir1/?region=us-east-1&disableSSL=true&s3ForcePathStyle=true"}, nil)
require.NoError(t, err)
d, err = NewData([]string{"-d", "data=s3://mybucket/dir1/?region=us-east-1&disableSSL=true&s3ForcePathStyle=true"}, nil)
require.NoError(t, err)

expected = []interface{}{"file1", "file2"}
out, err = d.Datasource("data")
require.NoError(t, err)
assert.EqualValues(t, expected, out)
expected = []interface{}{"file1", "file2"}
out, err = d.Datasource("data")
require.NoError(t, err)
assert.EqualValues(t, expected, out)
})
}

func TestBlobURL(t *testing.T) {
Expand Down
7 changes: 2 additions & 5 deletions data/datasource_env_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package data
import (
"context"
"net/url"
"os"
"testing"

"github.com/stretchr/testify/assert"
Expand All @@ -19,10 +18,8 @@ func TestReadEnv(t *testing.T) {
ctx := context.Background()

content := []byte(`hello world`)
os.Setenv("HELLO_WORLD", "hello world")
defer os.Unsetenv("HELLO_WORLD")
os.Setenv("HELLO_UNIVERSE", "hello universe")
defer os.Unsetenv("HELLO_UNIVERSE")
t.Setenv("HELLO_WORLD", "hello world")
t.Setenv("HELLO_UNIVERSE", "hello universe")

source := &Source{Alias: "foo", URL: mustParseURL("env:HELLO_WORLD")}

Expand Down
45 changes: 22 additions & 23 deletions data/datasource_git_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -484,15 +484,13 @@ func TestGitAuth(t *testing.T) {
assert.NilError(t, err)
assert.DeepEqual(t, &http.BasicAuth{Username: "user", Password: "swordfish"}, a)

os.Setenv("GIT_HTTP_PASSWORD", "swordfish")
defer os.Unsetenv("GIT_HTTP_PASSWORD")
t.Setenv("GIT_HTTP_PASSWORD", "swordfish")
a, err = g.auth(mustParseURL("git+https://[email protected]/foo"))
assert.NilError(t, err)
assert.DeepEqual(t, &http.BasicAuth{Username: "user", Password: "swordfish"}, a)
os.Unsetenv("GIT_HTTP_PASSWORD")

os.Setenv("GIT_HTTP_TOKEN", "mytoken")
defer os.Unsetenv("GIT_HTTP_TOKEN")
t.Setenv("GIT_HTTP_TOKEN", "mytoken")
a, err = g.auth(mustParseURL("git+https://[email protected]/foo"))
assert.NilError(t, err)
assert.DeepEqual(t, &http.TokenAuth{Token: "mytoken"}, a)
Expand All @@ -508,25 +506,26 @@ func TestGitAuth(t *testing.T) {
assert.Equal(t, "git", sa.User)
}

key := string(testdata.PEMBytes["ed25519"])
os.Setenv("GIT_SSH_KEY", key)
defer os.Unsetenv("GIT_SSH_KEY")
a, err = g.auth(mustParseURL("git+ssh://[email protected]/foo"))
assert.NilError(t, err)
ka, ok := a.(*ssh.PublicKeys)
assert.Equal(t, true, ok)
assert.Equal(t, "git", ka.User)
os.Unsetenv("GIT_SSH_KEY")

key = base64.StdEncoding.EncodeToString(testdata.PEMBytes["ed25519"])
os.Setenv("GIT_SSH_KEY", key)
defer os.Unsetenv("GIT_SSH_KEY")
a, err = g.auth(mustParseURL("git+ssh://[email protected]/foo"))
assert.NilError(t, err)
ka, ok = a.(*ssh.PublicKeys)
assert.Equal(t, true, ok)
assert.Equal(t, "git", ka.User)
os.Unsetenv("GIT_SSH_KEY")
t.Run("plain string ed25519", func(t *testing.T) {
key := string(testdata.PEMBytes["ed25519"])
t.Setenv("GIT_SSH_KEY", key)
a, err = g.auth(mustParseURL("git+ssh://[email protected]/foo"))
assert.NilError(t, err)
ka, ok := a.(*ssh.PublicKeys)
assert.Equal(t, true, ok)
assert.Equal(t, "git", ka.User)
})

t.Run("base64 ed25519", func(t *testing.T) {
key := base64.StdEncoding.EncodeToString(testdata.PEMBytes["ed25519"])
t.Setenv("GIT_SSH_KEY", key)
a, err = g.auth(mustParseURL("git+ssh://[email protected]/foo"))
assert.NilError(t, err)
ka, ok := a.(*ssh.PublicKeys)
assert.Equal(t, true, ok)
assert.Equal(t, "git", ka.User)
os.Unsetenv("GIT_SSH_KEY")
})
}

func TestRefFromURL(t *testing.T) {
Expand Down
Loading
Loading