Skip to content

Commit

Permalink
Implement --page-size flag and paginate DNS list commands (#20)
Browse files Browse the repository at this point in the history
* Implement --page-size for dns zone list

* Replicate logic for dns record-set list

* Fix typo

* Improve utils functions descriptions

* Fix test case

* Change mocked server error to be 500 instead of 502
  • Loading branch information
joaopalet authored Nov 23, 2023
1 parent 7c4949a commit 43b4a3d
Show file tree
Hide file tree
Showing 5 changed files with 510 additions and 46 deletions.
67 changes: 55 additions & 12 deletions internal/cmd/dns/record-set/list/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ const (
activeFlag = "is-active"
orderByNameFlag = "order-by-name"
limitFlag = "limit"
pageSizeFlag = "page-size"
)

type flagModel struct {
Expand All @@ -30,6 +31,7 @@ type flagModel struct {
Active *bool
OrderByName *string
Limit *int64
PageSize int64
}

func NewCmd() *cobra.Command {
Expand All @@ -51,23 +53,16 @@ func NewCmd() *cobra.Command {
return fmt.Errorf("authentication failed, please run \"stackit auth login\" or \"stackit auth activate-service-account\"")
}

// Call API
req := buildRequest(ctx, model, apiClient)
resp, err := req.Execute()
// Fetch record sets
recordSets, err := fetchRecordSets(ctx, model, apiClient)
if err != nil {
return fmt.Errorf("get DNS record sets: %w", err)
return err
}
recordSets := *resp.RrSets
if len(recordSets) == 0 {
cmd.Printf("No record-sets found for zone with ID %s\n", model.ZoneId)
cmd.Printf("No record sets found for zone %s in project with ID %s\n", model.ZoneId, model.ProjectId)
return nil
}

// Truncate output
if model.Limit != nil && len(recordSets) > int(*model.Limit) {
recordSets = recordSets[:*model.Limit]
}

// Show output as table
table := tables.NewTable()
table.SetHeader("ID", "Name", "Type", "State")
Expand All @@ -94,6 +89,7 @@ func configureFlags(cmd *cobra.Command) {
cmd.Flags().Var(flags.EnumBoolFlag(), activeFlag, fmt.Sprintf("Filter by active status, one of %q", activeFlagOptions))
cmd.Flags().Var(flags.EnumFlag(true, orderByNameFlagOptions...), orderByNameFlag, fmt.Sprintf("Order by name, one of %q", orderByNameFlagOptions))
cmd.Flags().Int64(limitFlag, 0, "Maximum number of entries to list")
cmd.Flags().Int64(pageSizeFlag, 100, "Number of items fetched in each API call. Does not affect the number of items in the command output")

err := utils.MarkFlagsRequired(cmd, zoneIdFlag)
cobra.CheckErr(err)
Expand All @@ -110,17 +106,26 @@ func parseFlags(cmd *cobra.Command) (*flagModel, error) {
return nil, fmt.Errorf("limit must be greater than 0")
}

pageSize, err := utils.FlagWithDefaultToInt64Value(cmd, pageSizeFlag)
if err != nil {
return nil, fmt.Errorf("parse %s flag: %w", pageSizeFlag, err)
}
if pageSize < 1 {
return nil, fmt.Errorf("page size must be greater than 0")
}

return &flagModel{
ProjectId: projectId,
ZoneId: utils.FlagToStringValue(cmd, zoneIdFlag),
NameLike: utils.FlagToStringPointer(cmd, nameLikeFlag),
Active: utils.FlagToBoolPointer(cmd, activeFlag),
OrderByName: utils.FlagToStringPointer(cmd, orderByNameFlag),
Limit: utils.FlagToInt64Pointer(cmd, limitFlag),
PageSize: pageSize,
}, nil
}

func buildRequest(ctx context.Context, model *flagModel, apiClient *dns.APIClient) dns.ApiGetRecordSetsRequest {
func buildRequest(ctx context.Context, model *flagModel, apiClient dnsClient, page int) dns.ApiGetRecordSetsRequest {
req := apiClient.GetRecordSets(ctx, model.ProjectId, model.ZoneId)
if model.NameLike != nil {
req = req.NameLike(*model.NameLike)
Expand All @@ -131,5 +136,43 @@ func buildRequest(ctx context.Context, model *flagModel, apiClient *dns.APIClien
if model.OrderByName != nil {
req = req.OrderByName(strings.ToUpper(*model.OrderByName))
}
req = req.PageSize(int32(model.PageSize))
req = req.Page(int32(page))
return req
}

type dnsClient interface {
GetRecordSets(ctx context.Context, projectId, zoneId string) dns.ApiGetRecordSetsRequest
}

func fetchRecordSets(ctx context.Context, model *flagModel, apiClient dnsClient) ([]dns.RecordSet, error) {
if model.Limit != nil && *model.Limit < model.PageSize {
model.PageSize = *model.Limit
}
page := 1
recordSets := []dns.RecordSet{}
for {
// Call API
req := buildRequest(ctx, model, apiClient, page)
resp, err := req.Execute()
if err != nil {
return nil, fmt.Errorf("get DNS record sets: %w", err)
}
respRecordSets := *resp.RrSets
if len(respRecordSets) == 0 {
break
}
recordSets = append(recordSets, respRecordSets...)
// Stop if no more pages
if len(respRecordSets) < int(model.PageSize) {
break
}
// Stop and truncate if limit is reached
if model.Limit != nil && len(recordSets) >= int(*model.Limit) {
recordSets = recordSets[:*model.Limit]
break
}
page++
}
return recordSets, nil
}
202 changes: 194 additions & 8 deletions internal/cmd/dns/record-set/list/list_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,19 @@ package list

import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strconv"
"testing"

"github.com/stackitcloud/stackit-cli/internal/pkg/globalflags"
"github.com/stackitcloud/stackit-cli/internal/pkg/utils"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/google/uuid"
"github.com/spf13/cobra"
"github.com/stackitcloud/stackit-cli/internal/pkg/globalflags"
"github.com/stackitcloud/stackit-cli/internal/pkg/utils"
sdkConfig "github.com/stackitcloud/stackit-sdk-go/core/config"
"github.com/stackitcloud/stackit-sdk-go/services/dns"
)

Expand All @@ -30,7 +34,6 @@ func fixtureFlagValues(mods ...func(flagValues map[string]string)) map[string]st
nameLikeFlag: "some-pattern",
activeFlag: "true",
orderByNameFlag: "asc",
limitFlag: "10",
}
for _, mod := range mods {
mod(flagValues)
Expand All @@ -45,7 +48,7 @@ func fixtureFlagModel(mods ...func(model *flagModel)) *flagModel {
NameLike: utils.Ptr("some-pattern"),
Active: utils.Ptr(true),
OrderByName: utils.Ptr("asc"),
Limit: utils.Ptr(int64(10)),
PageSize: 100,
}
for _, mod := range mods {
mod(model)
Expand All @@ -58,6 +61,7 @@ func fixtureRequest(mods ...func(request *dns.ApiGetRecordSetsRequest)) dns.ApiG
request = request.NameLike("some-pattern")
request = request.ActiveEq(true)
request = request.OrderByName("ASC")
request = request.PageSize(100)
for _, mod := range mods {
mod(&request)
}
Expand Down Expand Up @@ -92,6 +96,7 @@ func TestParseFlags(t *testing.T) {
expectedModel: &flagModel{
ProjectId: testProjectId,
ZoneId: testZoneId,
PageSize: 100, // Default value
},
},
{
Expand Down Expand Up @@ -239,26 +244,36 @@ func TestBuildRequest(t *testing.T) {
tests := []struct {
description string
model *flagModel
page int
expectedRequest dns.ApiGetRecordSetsRequest
}{
{
description: "base",
model: fixtureFlagModel(),
expectedRequest: fixtureRequest(),
page: 1,
expectedRequest: fixtureRequest().Page(1),
},
{
description: "base 2",
model: fixtureFlagModel(),
page: 10,
expectedRequest: fixtureRequest().Page(10),
},
{
description: "required fields only",
model: &flagModel{
ProjectId: testProjectId,
ZoneId: testZoneId,
PageSize: 10,
},
expectedRequest: testClient.GetRecordSets(testCtx, testProjectId, testZoneId),
page: 1,
expectedRequest: testClient.GetRecordSets(testCtx, testProjectId, testZoneId).Page(1).PageSize(10),
},
}

for _, tt := range tests {
t.Run(tt.description, func(t *testing.T) {
request := buildRequest(testCtx, tt.model, testClient)
request := buildRequest(testCtx, tt.model, testClient, tt.page)

diff := cmp.Diff(request, tt.expectedRequest,
cmp.AllowUnexported(tt.expectedRequest),
Expand All @@ -270,3 +285,174 @@ func TestBuildRequest(t *testing.T) {
})
}
}

func TestFetchRecordSets(t *testing.T) {
tests := []struct {
description string
model *flagModel
totalItems int
apiCallFails bool
expectedNumAPICalls int
expectedNumItems int
}{
{
description: "no limit and pageSize>totalItems",
model: fixtureFlagModel(),
totalItems: 10,
expectedNumAPICalls: 1,
apiCallFails: false,
expectedNumItems: 10,
},
{
description: "no limit and pageSize<totalItems",
model: fixtureFlagModel(),
totalItems: 320,
expectedNumAPICalls: 4,
apiCallFails: false,
expectedNumItems: 320,
},
{
description: "no limit and pageSize<totalItems 2",
model: fixtureFlagModel(),
totalItems: 200,
expectedNumAPICalls: 3, // Last call will return no items
apiCallFails: false,
expectedNumItems: 200,
},
{
description: "no limit and pageSize=totalItems",
model: fixtureFlagModel(),
totalItems: 100,
expectedNumAPICalls: 2, // Last call will return no items
apiCallFails: false,
expectedNumItems: 100,
},
{
description: "limit<pageSize",
model: fixtureFlagModel(func(model *flagModel) {
model.Limit = utils.Ptr(int64(10))
}),
totalItems: 100,
expectedNumAPICalls: 1,
apiCallFails: false,
expectedNumItems: 10,
},
{
description: "limit>totalItems and pageSize>totalItems",
model: fixtureFlagModel(func(model *flagModel) {
model.Limit = utils.Ptr(int64(200))
model.PageSize = 300
}),
totalItems: 50,
expectedNumAPICalls: 1,
apiCallFails: false,
expectedNumItems: 50,
},
{
description: "limit>totalItems and pageSize<totalItems",
model: fixtureFlagModel(func(model *flagModel) {
model.Limit = utils.Ptr(int64(200))
model.PageSize = 30
}),
totalItems: 50,
expectedNumAPICalls: 2,
apiCallFails: false,
expectedNumItems: 50,
},
{
description: "request fails",
model: fixtureFlagModel(),
totalItems: 100,
apiCallFails: true,
},
}

for _, tt := range tests {
t.Run(tt.description, func(t *testing.T) {
numAPICalls := 0
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
numAPICalls++

w.Header().Set("Content-Type", "application/json")
if tt.apiCallFails {
w.WriteHeader(http.StatusInternalServerError)
_, err := w.Write([]byte("{\"message\": \"Something bad happened\""))
if err != nil {
t.Errorf("Failed to write bad response: %v", err)
}
return
}

query := r.URL.Query()
pageStr := query.Get("page")
if pageStr == "" {
t.Errorf("Expected query param page to be set")
}
page, err := strconv.Atoi(pageStr)
if err != nil {
t.Errorf("Failed to parse query param page: %v", err)
}
pageSizeStr := query.Get("pageSize")
if pageSizeStr == "" {
t.Errorf("Expected query param pageSize to be set")
}
pageSize, err := strconv.Atoi(pageSizeStr)
if err != nil {
t.Errorf("Failed to parse query param pageSize: %v", err)
}

offset := (page - 1) * pageSize

var numItemsToReturn int
if offset >= tt.totalItems {
numItemsToReturn = 0 // Total items reached
} else if offset+pageSize < tt.totalItems {
numItemsToReturn = pageSize // Full intermediate page
} else {
numItemsToReturn = tt.totalItems - offset // Last page
}

recordSets := make([]dns.RecordSet, numItemsToReturn)
mockedResp := dns.RecordSetsResponse{
RrSets: &recordSets,
}

mockedRespBytes, err := json.Marshal(mockedResp)
if err != nil {
t.Fatalf("Failed to marshal mocked response: %v", err)
}

_, err = w.Write(mockedRespBytes)
if err != nil {
t.Errorf("Failed to write response: %v", err)
}
})
mockedServer := httptest.NewServer(handler)
defer mockedServer.Close()
client, err := dns.NewAPIClient(
sdkConfig.WithEndpoint(mockedServer.URL),
sdkConfig.WithoutAuthentication(),
)
if err != nil {
t.Fatalf("Failed to initialize client: %v", err)
}

recordSets, err := fetchRecordSets(testCtx, tt.model, client)
if err != nil {
if !tt.apiCallFails {
t.Fatalf("did not fail on invalid input")
}
return
}
if err == nil && tt.apiCallFails {
t.Fatalf("did not fail on invalid input")
}
if numAPICalls != tt.expectedNumAPICalls {
t.Fatalf("Expected %d API calls, got %d", tt.expectedNumAPICalls, numAPICalls)
}
if len(recordSets) != tt.expectedNumItems {
t.Fatalf("Expected %d recordSets, got %d", tt.totalItems, len(recordSets))
}
})
}
}
Loading

0 comments on commit 43b4a3d

Please sign in to comment.