From 43b4a3dbfe6abfd47c453251adcf4558783bb1a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Palet?= Date: Thu, 23 Nov 2023 18:33:20 +0100 Subject: [PATCH] Implement `--page-size` flag and paginate DNS list commands (#20) * Implement --page-size for dns zone list * Replicate logic for dns record-set list * Fix typo * Improve utils functions descriptions * Fix test case * Change mocked server error to be 500 instead of 502 --- internal/cmd/dns/record-set/list/list.go | 67 ++++-- internal/cmd/dns/record-set/list/list_test.go | 202 +++++++++++++++++- internal/cmd/dns/zone/list/list.go | 65 +++++- internal/cmd/dns/zone/list/list_test.go | 202 +++++++++++++++++- internal/pkg/utils/utils.go | 20 +- 5 files changed, 510 insertions(+), 46 deletions(-) diff --git a/internal/cmd/dns/record-set/list/list.go b/internal/cmd/dns/record-set/list/list.go index 0d8082a0..3479b213 100644 --- a/internal/cmd/dns/record-set/list/list.go +++ b/internal/cmd/dns/record-set/list/list.go @@ -21,6 +21,7 @@ const ( activeFlag = "is-active" orderByNameFlag = "order-by-name" limitFlag = "limit" + pageSizeFlag = "page-size" ) type flagModel struct { @@ -30,6 +31,7 @@ type flagModel struct { Active *bool OrderByName *string Limit *int64 + PageSize int64 } func NewCmd() *cobra.Command { @@ -51,23 +53,16 @@ func NewCmd() *cobra.Command { return fmt.Errorf("authentication failed, please run \"stackit auth login\" or \"stackit auth activate-service-account\"") } - // Call API - req := buildRequest(ctx, model, apiClient) - resp, err := req.Execute() + // Fetch record sets + recordSets, err := fetchRecordSets(ctx, model, apiClient) if err != nil { - return fmt.Errorf("get DNS record sets: %w", err) + return err } - recordSets := *resp.RrSets if len(recordSets) == 0 { - cmd.Printf("No record-sets found for zone with ID %s\n", model.ZoneId) + cmd.Printf("No record sets found for zone %s in project with ID %s\n", model.ZoneId, model.ProjectId) return nil } - // Truncate output - if model.Limit != nil && len(recordSets) > int(*model.Limit) { - recordSets = recordSets[:*model.Limit] - } - // Show output as table table := tables.NewTable() table.SetHeader("ID", "Name", "Type", "State") @@ -94,6 +89,7 @@ func configureFlags(cmd *cobra.Command) { cmd.Flags().Var(flags.EnumBoolFlag(), activeFlag, fmt.Sprintf("Filter by active status, one of %q", activeFlagOptions)) cmd.Flags().Var(flags.EnumFlag(true, orderByNameFlagOptions...), orderByNameFlag, fmt.Sprintf("Order by name, one of %q", orderByNameFlagOptions)) cmd.Flags().Int64(limitFlag, 0, "Maximum number of entries to list") + cmd.Flags().Int64(pageSizeFlag, 100, "Number of items fetched in each API call. Does not affect the number of items in the command output") err := utils.MarkFlagsRequired(cmd, zoneIdFlag) cobra.CheckErr(err) @@ -110,6 +106,14 @@ func parseFlags(cmd *cobra.Command) (*flagModel, error) { return nil, fmt.Errorf("limit must be greater than 0") } + pageSize, err := utils.FlagWithDefaultToInt64Value(cmd, pageSizeFlag) + if err != nil { + return nil, fmt.Errorf("parse %s flag: %w", pageSizeFlag, err) + } + if pageSize < 1 { + return nil, fmt.Errorf("page size must be greater than 0") + } + return &flagModel{ ProjectId: projectId, ZoneId: utils.FlagToStringValue(cmd, zoneIdFlag), @@ -117,10 +121,11 @@ func parseFlags(cmd *cobra.Command) (*flagModel, error) { Active: utils.FlagToBoolPointer(cmd, activeFlag), OrderByName: utils.FlagToStringPointer(cmd, orderByNameFlag), Limit: utils.FlagToInt64Pointer(cmd, limitFlag), + PageSize: pageSize, }, nil } -func buildRequest(ctx context.Context, model *flagModel, apiClient *dns.APIClient) dns.ApiGetRecordSetsRequest { +func buildRequest(ctx context.Context, model *flagModel, apiClient dnsClient, page int) dns.ApiGetRecordSetsRequest { req := apiClient.GetRecordSets(ctx, model.ProjectId, model.ZoneId) if model.NameLike != nil { req = req.NameLike(*model.NameLike) @@ -131,5 +136,43 @@ func buildRequest(ctx context.Context, model *flagModel, apiClient *dns.APIClien if model.OrderByName != nil { req = req.OrderByName(strings.ToUpper(*model.OrderByName)) } + req = req.PageSize(int32(model.PageSize)) + req = req.Page(int32(page)) return req } + +type dnsClient interface { + GetRecordSets(ctx context.Context, projectId, zoneId string) dns.ApiGetRecordSetsRequest +} + +func fetchRecordSets(ctx context.Context, model *flagModel, apiClient dnsClient) ([]dns.RecordSet, error) { + if model.Limit != nil && *model.Limit < model.PageSize { + model.PageSize = *model.Limit + } + page := 1 + recordSets := []dns.RecordSet{} + for { + // Call API + req := buildRequest(ctx, model, apiClient, page) + resp, err := req.Execute() + if err != nil { + return nil, fmt.Errorf("get DNS record sets: %w", err) + } + respRecordSets := *resp.RrSets + if len(respRecordSets) == 0 { + break + } + recordSets = append(recordSets, respRecordSets...) + // Stop if no more pages + if len(respRecordSets) < int(model.PageSize) { + break + } + // Stop and truncate if limit is reached + if model.Limit != nil && len(recordSets) >= int(*model.Limit) { + recordSets = recordSets[:*model.Limit] + break + } + page++ + } + return recordSets, nil +} diff --git a/internal/cmd/dns/record-set/list/list_test.go b/internal/cmd/dns/record-set/list/list_test.go index b0d9655c..6ea97537 100644 --- a/internal/cmd/dns/record-set/list/list_test.go +++ b/internal/cmd/dns/record-set/list/list_test.go @@ -2,15 +2,19 @@ package list import ( "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strconv" "testing" - "github.com/stackitcloud/stackit-cli/internal/pkg/globalflags" - "github.com/stackitcloud/stackit-cli/internal/pkg/utils" - "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/google/uuid" "github.com/spf13/cobra" + "github.com/stackitcloud/stackit-cli/internal/pkg/globalflags" + "github.com/stackitcloud/stackit-cli/internal/pkg/utils" + sdkConfig "github.com/stackitcloud/stackit-sdk-go/core/config" "github.com/stackitcloud/stackit-sdk-go/services/dns" ) @@ -30,7 +34,6 @@ func fixtureFlagValues(mods ...func(flagValues map[string]string)) map[string]st nameLikeFlag: "some-pattern", activeFlag: "true", orderByNameFlag: "asc", - limitFlag: "10", } for _, mod := range mods { mod(flagValues) @@ -45,7 +48,7 @@ func fixtureFlagModel(mods ...func(model *flagModel)) *flagModel { NameLike: utils.Ptr("some-pattern"), Active: utils.Ptr(true), OrderByName: utils.Ptr("asc"), - Limit: utils.Ptr(int64(10)), + PageSize: 100, } for _, mod := range mods { mod(model) @@ -58,6 +61,7 @@ func fixtureRequest(mods ...func(request *dns.ApiGetRecordSetsRequest)) dns.ApiG request = request.NameLike("some-pattern") request = request.ActiveEq(true) request = request.OrderByName("ASC") + request = request.PageSize(100) for _, mod := range mods { mod(&request) } @@ -92,6 +96,7 @@ func TestParseFlags(t *testing.T) { expectedModel: &flagModel{ ProjectId: testProjectId, ZoneId: testZoneId, + PageSize: 100, // Default value }, }, { @@ -239,26 +244,36 @@ func TestBuildRequest(t *testing.T) { tests := []struct { description string model *flagModel + page int expectedRequest dns.ApiGetRecordSetsRequest }{ { description: "base", model: fixtureFlagModel(), - expectedRequest: fixtureRequest(), + page: 1, + expectedRequest: fixtureRequest().Page(1), + }, + { + description: "base 2", + model: fixtureFlagModel(), + page: 10, + expectedRequest: fixtureRequest().Page(10), }, { description: "required fields only", model: &flagModel{ ProjectId: testProjectId, ZoneId: testZoneId, + PageSize: 10, }, - expectedRequest: testClient.GetRecordSets(testCtx, testProjectId, testZoneId), + page: 1, + expectedRequest: testClient.GetRecordSets(testCtx, testProjectId, testZoneId).Page(1).PageSize(10), }, } for _, tt := range tests { t.Run(tt.description, func(t *testing.T) { - request := buildRequest(testCtx, tt.model, testClient) + request := buildRequest(testCtx, tt.model, testClient, tt.page) diff := cmp.Diff(request, tt.expectedRequest, cmp.AllowUnexported(tt.expectedRequest), @@ -270,3 +285,174 @@ func TestBuildRequest(t *testing.T) { }) } } + +func TestFetchRecordSets(t *testing.T) { + tests := []struct { + description string + model *flagModel + totalItems int + apiCallFails bool + expectedNumAPICalls int + expectedNumItems int + }{ + { + description: "no limit and pageSize>totalItems", + model: fixtureFlagModel(), + totalItems: 10, + expectedNumAPICalls: 1, + apiCallFails: false, + expectedNumItems: 10, + }, + { + description: "no limit and pageSizetotalItems and pageSize>totalItems", + model: fixtureFlagModel(func(model *flagModel) { + model.Limit = utils.Ptr(int64(200)) + model.PageSize = 300 + }), + totalItems: 50, + expectedNumAPICalls: 1, + apiCallFails: false, + expectedNumItems: 50, + }, + { + description: "limit>totalItems and pageSize= tt.totalItems { + numItemsToReturn = 0 // Total items reached + } else if offset+pageSize < tt.totalItems { + numItemsToReturn = pageSize // Full intermediate page + } else { + numItemsToReturn = tt.totalItems - offset // Last page + } + + recordSets := make([]dns.RecordSet, numItemsToReturn) + mockedResp := dns.RecordSetsResponse{ + RrSets: &recordSets, + } + + mockedRespBytes, err := json.Marshal(mockedResp) + if err != nil { + t.Fatalf("Failed to marshal mocked response: %v", err) + } + + _, err = w.Write(mockedRespBytes) + if err != nil { + t.Errorf("Failed to write response: %v", err) + } + }) + mockedServer := httptest.NewServer(handler) + defer mockedServer.Close() + client, err := dns.NewAPIClient( + sdkConfig.WithEndpoint(mockedServer.URL), + sdkConfig.WithoutAuthentication(), + ) + if err != nil { + t.Fatalf("Failed to initialize client: %v", err) + } + + recordSets, err := fetchRecordSets(testCtx, tt.model, client) + if err != nil { + if !tt.apiCallFails { + t.Fatalf("did not fail on invalid input") + } + return + } + if err == nil && tt.apiCallFails { + t.Fatalf("did not fail on invalid input") + } + if numAPICalls != tt.expectedNumAPICalls { + t.Fatalf("Expected %d API calls, got %d", tt.expectedNumAPICalls, numAPICalls) + } + if len(recordSets) != tt.expectedNumItems { + t.Fatalf("Expected %d recordSets, got %d", tt.totalItems, len(recordSets)) + } + }) + } +} diff --git a/internal/cmd/dns/zone/list/list.go b/internal/cmd/dns/zone/list/list.go index 60bd2133..55d0d0ad 100644 --- a/internal/cmd/dns/zone/list/list.go +++ b/internal/cmd/dns/zone/list/list.go @@ -20,6 +20,7 @@ const ( activeFlag = "is-active" orderByNameFlag = "order-by-name" limitFlag = "limit" + pageSizeFlag = "page-size" ) type flagModel struct { @@ -28,6 +29,7 @@ type flagModel struct { Active *bool OrderByName *string Limit *int64 + PageSize int64 } func NewCmd() *cobra.Command { @@ -49,23 +51,16 @@ func NewCmd() *cobra.Command { return fmt.Errorf("authentication failed, please run \"stackit auth login\" or \"stackit auth activate-service-account\"") } - // Call API - req := buildRequest(ctx, model, apiClient) - resp, err := req.Execute() + // Fetch zones + zones, err := fetchZones(ctx, model, apiClient) if err != nil { - return fmt.Errorf("get DNS zones: %w", err) + return err } - zones := *resp.Zones if len(zones) == 0 { cmd.Printf("No zones found for project with ID %s\n", model.ProjectId) return nil } - // Truncate output - if model.Limit != nil && len(zones) > int(*model.Limit) { - zones = zones[:*model.Limit] - } - // Show output as table table := tables.NewTable() table.SetHeader("ID", "NAME", "DNS_NAME", "STATE") @@ -90,6 +85,7 @@ func configureFlags(cmd *cobra.Command) { cmd.Flags().Var(flags.EnumBoolFlag(), activeFlag, fmt.Sprintf("Filter by active status, one of %q", activeFlagOptions)) cmd.Flags().Var(flags.EnumFlag(true, orderByNameFlagOptions...), orderByNameFlag, fmt.Sprintf("Order by name, one of %q", orderByNameFlagOptions)) cmd.Flags().Int64(limitFlag, 0, "Maximum number of entries to list") + cmd.Flags().Int64(pageSizeFlag, 100, "Number of items fetched in each API call. Does not affect the number of items in the command output") } func parseFlags(cmd *cobra.Command) (*flagModel, error) { @@ -103,16 +99,25 @@ func parseFlags(cmd *cobra.Command) (*flagModel, error) { return nil, fmt.Errorf("limit must be greater than 0") } + pageSize, err := utils.FlagWithDefaultToInt64Value(cmd, pageSizeFlag) + if err != nil { + return nil, fmt.Errorf("parse %s flag: %w", pageSizeFlag, err) + } + if pageSize < 1 { + return nil, fmt.Errorf("page size must be greater than 0") + } + return &flagModel{ ProjectId: projectId, NameLike: utils.FlagToStringPointer(cmd, nameLikeFlag), Active: utils.FlagToBoolPointer(cmd, activeFlag), OrderByName: utils.FlagToStringPointer(cmd, orderByNameFlag), Limit: limit, + PageSize: pageSize, }, nil } -func buildRequest(ctx context.Context, model *flagModel, apiClient *dns.APIClient) dns.ApiGetZonesRequest { +func buildRequest(ctx context.Context, model *flagModel, apiClient dnsClient, page int) dns.ApiGetZonesRequest { req := apiClient.GetZones(ctx, model.ProjectId) if model.NameLike != nil { req = req.NameLike(*model.NameLike) @@ -123,5 +128,43 @@ func buildRequest(ctx context.Context, model *flagModel, apiClient *dns.APIClien if model.OrderByName != nil { req = req.OrderByName(strings.ToUpper(*model.OrderByName)) } + req = req.PageSize(int32(model.PageSize)) + req = req.Page(int32(page)) return req } + +type dnsClient interface { + GetZones(ctx context.Context, projectId string) dns.ApiGetZonesRequest +} + +func fetchZones(ctx context.Context, model *flagModel, apiClient dnsClient) ([]dns.Zone, error) { + if model.Limit != nil && *model.Limit < model.PageSize { + model.PageSize = *model.Limit + } + page := 1 + zones := []dns.Zone{} + for { + // Call API + req := buildRequest(ctx, model, apiClient, page) + resp, err := req.Execute() + if err != nil { + return nil, fmt.Errorf("get DNS zones: %w", err) + } + respZones := *resp.Zones + if len(respZones) == 0 { + break + } + zones = append(zones, respZones...) + // Stop if no more pages + if len(respZones) < int(model.PageSize) { + break + } + // Stop and truncate if limit is reached + if model.Limit != nil && len(zones) >= int(*model.Limit) { + zones = zones[:*model.Limit] + break + } + page++ + } + return zones, nil +} diff --git a/internal/cmd/dns/zone/list/list_test.go b/internal/cmd/dns/zone/list/list_test.go index e54432b6..003ac320 100644 --- a/internal/cmd/dns/zone/list/list_test.go +++ b/internal/cmd/dns/zone/list/list_test.go @@ -2,15 +2,19 @@ package list import ( "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strconv" "testing" - "github.com/stackitcloud/stackit-cli/internal/pkg/globalflags" - "github.com/stackitcloud/stackit-cli/internal/pkg/utils" - "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/google/uuid" "github.com/spf13/cobra" + "github.com/stackitcloud/stackit-cli/internal/pkg/globalflags" + "github.com/stackitcloud/stackit-cli/internal/pkg/utils" + sdkConfig "github.com/stackitcloud/stackit-sdk-go/core/config" "github.com/stackitcloud/stackit-sdk-go/services/dns" ) @@ -28,7 +32,6 @@ func fixtureFlagValues(mods ...func(flagValues map[string]string)) map[string]st nameLikeFlag: "some-pattern", activeFlag: "true", orderByNameFlag: "asc", - limitFlag: "10", } for _, mod := range mods { mod(flagValues) @@ -42,7 +45,7 @@ func fixtureFlagModel(mods ...func(model *flagModel)) *flagModel { NameLike: utils.Ptr("some-pattern"), Active: utils.Ptr(true), OrderByName: utils.Ptr("asc"), - Limit: utils.Ptr(int64(10)), + PageSize: 100, } for _, mod := range mods { mod(model) @@ -55,6 +58,7 @@ func fixtureRequest(mods ...func(request *dns.ApiGetZonesRequest)) dns.ApiGetZon request = request.NameLike("some-pattern") request = request.ActiveEq(true) request = request.OrderByName("ASC") + request = request.PageSize(100) for _, mod := range mods { mod(&request) } @@ -87,6 +91,7 @@ func TestParseFlags(t *testing.T) { isValid: true, expectedModel: &flagModel{ ProjectId: testProjectId, + PageSize: 100, // Default value }, }, { @@ -234,25 +239,35 @@ func TestBuildRequest(t *testing.T) { tests := []struct { description string model *flagModel + page int expectedRequest dns.ApiGetZonesRequest }{ { description: "base", model: fixtureFlagModel(), - expectedRequest: fixtureRequest(), + page: 1, + expectedRequest: fixtureRequest().Page(1), + }, + { + description: "base 2", + model: fixtureFlagModel(), + page: 10, + expectedRequest: fixtureRequest().Page(10), }, { description: "required fields only", model: &flagModel{ ProjectId: testProjectId, + PageSize: 100, }, - expectedRequest: testClient.GetZones(testCtx, testProjectId), + page: 1, + expectedRequest: testClient.GetZones(testCtx, testProjectId).Page(1).PageSize(100), }, } for _, tt := range tests { t.Run(tt.description, func(t *testing.T) { - request := buildRequest(testCtx, tt.model, testClient) + request := buildRequest(testCtx, tt.model, testClient, tt.page) diff := cmp.Diff(request, tt.expectedRequest, cmp.AllowUnexported(tt.expectedRequest), @@ -264,3 +279,174 @@ func TestBuildRequest(t *testing.T) { }) } } + +func TestFetchZones(t *testing.T) { + tests := []struct { + description string + model *flagModel + totalItems int + apiCallFails bool + expectedNumAPICalls int + expectedNumItems int + }{ + { + description: "no limit and pageSize>totalItems", + model: fixtureFlagModel(), + totalItems: 10, + expectedNumAPICalls: 1, + apiCallFails: false, + expectedNumItems: 10, + }, + { + description: "no limit and pageSizetotalItems and pageSize>totalItems", + model: fixtureFlagModel(func(model *flagModel) { + model.Limit = utils.Ptr(int64(200)) + model.PageSize = 300 + }), + totalItems: 50, + expectedNumAPICalls: 1, + apiCallFails: false, + expectedNumItems: 50, + }, + { + description: "limit>totalItems and pageSize= tt.totalItems { + numItemsToReturn = 0 // Total items reached + } else if offset+pageSize < tt.totalItems { + numItemsToReturn = pageSize // Full intermediate page + } else { + numItemsToReturn = tt.totalItems - offset // Last page + } + + zones := make([]dns.Zone, numItemsToReturn) + mockedResp := dns.ZonesResponse{ + Zones: &zones, + } + + mockedRespBytes, err := json.Marshal(mockedResp) + if err != nil { + t.Fatalf("Failed to marshal mocked response: %v", err) + } + + _, err = w.Write(mockedRespBytes) + if err != nil { + t.Errorf("Failed to write response: %v", err) + } + }) + mockedServer := httptest.NewServer(handler) + defer mockedServer.Close() + client, err := dns.NewAPIClient( + sdkConfig.WithEndpoint(mockedServer.URL), + sdkConfig.WithoutAuthentication(), + ) + if err != nil { + t.Fatalf("Failed to initialize client: %v", err) + } + + zones, err := fetchZones(testCtx, tt.model, client) + if err != nil { + if !tt.apiCallFails { + t.Fatalf("did not fail on invalid input") + } + return + } + if err == nil && tt.apiCallFails { + t.Fatalf("did not fail on invalid input") + } + if numAPICalls != tt.expectedNumAPICalls { + t.Fatalf("Expected %d API calls, got %d", tt.expectedNumAPICalls, numAPICalls) + } + if len(zones) != tt.expectedNumItems { + t.Fatalf("Expected %d zones, got %d", tt.totalItems, len(zones)) + } + }) + } +} diff --git a/internal/pkg/utils/utils.go b/internal/pkg/utils/utils.go index 0f177442..fc0d5489 100644 --- a/internal/pkg/utils/utils.go +++ b/internal/pkg/utils/utils.go @@ -4,8 +4,8 @@ import ( "github.com/spf13/cobra" ) +// Returns the flag's value as a string. // Returns "" if the flag is not set, if its value can not be converted to string, or if the flag does not exist. -// Otherwise, returns the flag's value as a string func FlagToStringValue(cmd *cobra.Command, flag string) string { value, err := cmd.Flags().GetString(flag) if err != nil { @@ -17,8 +17,8 @@ func FlagToStringValue(cmd *cobra.Command, flag string) string { return "" } +// Returns the flag's value as a bool. // Returns "false" if its value can not be converted to bool, or if the flag does not exist. -// Otherwise, returns flag's value as a bool func FlagToBoolValue(cmd *cobra.Command, flag string) bool { value, err := cmd.Flags().GetBool(flag) if err != nil { @@ -27,8 +27,8 @@ func FlagToBoolValue(cmd *cobra.Command, flag string) bool { return value } +// Returns the flag's value as a []string. // Returns nil if the flag is not set, if its value can not be converted to []string, or if the flag does not exist. -// Otherwise, returns the flag's value. func FlagToStringSliceValue(cmd *cobra.Command, flag string) []string { value, err := cmd.Flags().GetStringSlice(flag) if err != nil { @@ -40,8 +40,8 @@ func FlagToStringSliceValue(cmd *cobra.Command, flag string) []string { return nil } +// Returns a pointer to the flag's value. // Returns nil if the flag is not set, if its value can not be converted to int64, or if the flag does not exist. -// Otherwise, returns a pointer to the flag's value. func FlagToInt64Pointer(cmd *cobra.Command, flag string) *int64 { value, err := cmd.Flags().GetInt64(flag) if err != nil { @@ -53,8 +53,8 @@ func FlagToInt64Pointer(cmd *cobra.Command, flag string) *int64 { return nil } +// Returns a pointer to the flag's value. // Returns nil if the flag is not set, if its value can not be converted to string, or if the flag does not exist. -// Otherwise, returns a pointer to the flag's value. func FlagToStringPointer(cmd *cobra.Command, flag string) *string { value, err := cmd.Flags().GetString(flag) if err != nil { @@ -66,8 +66,8 @@ func FlagToStringPointer(cmd *cobra.Command, flag string) *string { return nil } +// Returns a pointer to the flag's value. // Returns nil if the flag is not set, if its value can not be converted to []string, or if the flag does not exist. -// Otherwise, returns a pointer to the flag's value. func FlagToStringSlicePointer(cmd *cobra.Command, flag string) *[]string { value, err := cmd.Flags().GetStringSlice(flag) if err != nil { @@ -79,8 +79,8 @@ func FlagToStringSlicePointer(cmd *cobra.Command, flag string) *[]string { return nil } +// Returns a pointer to the flag's value. // Returns nil if the flag is not set, if its value can not be converted to bool, or if the flag does not exist. -// Otherwise, returns a pointer to the flag's value. func FlagToBoolPointer(cmd *cobra.Command, flag string) *bool { value, err := cmd.Flags().GetBool(flag) if err != nil { @@ -92,6 +92,12 @@ func FlagToBoolPointer(cmd *cobra.Command, flag string) *bool { return nil } +// Returns the int64 value set on the flag. If no value is set, returns the flag's default value. +// An error is returned if the flag value can not be converted to int64 or if the flag does not exist. +func FlagWithDefaultToInt64Value(cmd *cobra.Command, flag string) (int64, error) { + return cmd.Flags().GetInt64(flag) +} + // Marks all given flags as required, causing the command to report an error if invoked without them. func MarkFlagsRequired(cmd *cobra.Command, flags ...string) error { for _, flag := range flags {