From 6f3ad2ac6918192b9a0b9f8400839cb191da77b2 Mon Sep 17 00:00:00 2001 From: FingerLeader Date: Thu, 1 Aug 2024 15:42:07 +0800 Subject: [PATCH 1/3] add public endpoint disabled Signed-off-by: FingerLeader --- internal/cli/serverless/create.go | 14 ++++ internal/cli/serverless/update.go | 27 +++++-- internal/flag/flag.go | 127 +++++++++++++++--------------- 3 files changed, 99 insertions(+), 69 deletions(-) diff --git a/internal/cli/serverless/create.go b/internal/cli/serverless/create.go index a984525d..9fddbf1b 100644 --- a/internal/cli/serverless/create.go +++ b/internal/cli/serverless/create.go @@ -124,6 +124,7 @@ func CreateCmd(h *internal.Helper) *cobra.Command { var projectID string var spendingLimitMonthly int32 var encryption bool + var publicEndpointDisabled bool if opts.interactive { cmd.Annotations[telemetry.InteractiveMode] = "true" if !h.IOStreams.CanPrompt { @@ -243,6 +244,10 @@ func CreateCmd(h *internal.Helper) *cobra.Command { if err != nil { return errors.Trace(err) } + publicEndpointDisabled, err = cmd.Flags().GetBool(flag.PublicEndpointDisabled) + if err != nil { + return errors.Trace(err) + } // check clusterName err = checkClusterName(clusterName) if err != nil { @@ -273,6 +278,14 @@ func CreateCmd(h *internal.Helper) *cobra.Command { } } + if publicEndpointDisabled { + v1Cluster.Endpoints = &serverlessModel.TidbCloudOpenApiserverlessv1beta1ClusterEndpoints{ + Public: &serverlessModel.EndpointsPublic{ + Disabled: publicEndpointDisabled, + }, + } + } + if h.IOStreams.CanPrompt { err := CreateAndSpinnerWait(ctx, d, v1Cluster, h) if err != nil { @@ -294,6 +307,7 @@ func CreateCmd(h *internal.Helper) *cobra.Command { createCmd.Flags().StringP(flag.ProjectID, flag.ProjectIDShort, "", "The ID of the project, in which the cluster will be created. (default: \"default project\")") createCmd.Flags().Int32(flag.SpendingLimitMonthly, 0, "Maximum monthly spending limit in USD cents. (optional)") createCmd.Flags().Bool(flag.Encryption, false, "Whether Enhanced Encryption at Rest is enabled. (optional)") + createCmd.Flags().Bool(flag.PublicEndpointDisabled, false, "Whether the public endpoint is disabled. (optional)") return createCmd } diff --git a/internal/cli/serverless/update.go b/internal/cli/serverless/update.go index 4e099fad..b6cd0f62 100644 --- a/internal/cli/serverless/update.go +++ b/internal/cli/serverless/update.go @@ -26,6 +26,7 @@ import ( "tidbcloud-cli/internal/ui" "tidbcloud-cli/internal/util" serverlessApi "tidbcloud-cli/pkg/tidbcloud/v1beta1/serverless/client/serverless_service" + serverlessModel "tidbcloud-cli/pkg/tidbcloud/v1beta1/serverless/models" "github.com/charmbracelet/bubbles/textinput" tea "github.com/charmbracelet/bubbletea" @@ -50,9 +51,10 @@ func (c UpdateOpts) NonInteractiveFlags() []string { type mutableField string const ( - DisplayName mutableField = "displayName" - Annotations mutableField = "annotations" - Labels mutableField = "labels" + DisplayName mutableField = "displayName" + Annotations mutableField = "annotations" + Labels mutableField = "labels" + PublicEndpointDisabled mutableField = "endpoints.public.disabled" ) var mutableFields = []string{ @@ -94,8 +96,8 @@ func UpdateCmd(h *internal.Helper) *cobra.Command { if err != nil { return err } - cmd.MarkFlagsMutuallyExclusive(flag.DisplayName, flag.ServerlessAnnotations, flag.ServerlessLabels) - cmd.MarkFlagsOneRequired(flag.DisplayName, flag.ServerlessAnnotations, flag.ServerlessLabels) + cmd.MarkFlagsMutuallyExclusive(flag.DisplayName, flag.ServerlessAnnotations, flag.ServerlessLabels, flag.PublicEndpointDisabled) + cmd.MarkFlagsOneRequired(flag.DisplayName, flag.ServerlessAnnotations, flag.ServerlessLabels, flag.PublicEndpointDisabled) } return nil }, @@ -109,6 +111,7 @@ func UpdateCmd(h *internal.Helper) *cobra.Command { var clusterID string var fieldName string var displayName, labels, annotations string + var publicEndpointDisabled bool if opts.interactive { cmd.Annotations[telemetry.InteractiveMode] = "true" if !h.IOStreams.CanPrompt { @@ -168,6 +171,10 @@ func UpdateCmd(h *internal.Helper) *cobra.Command { if err != nil { return errors.Trace(err) } + publicEndpointDisabled, err = cmd.Flags().GetBool(flag.PublicEndpointDisabled) + if err != nil { + return errors.Trace(err) + } } body := &serverlessApi.ServerlessServicePartialUpdateClusterBody{ @@ -193,8 +200,15 @@ func UpdateCmd(h *internal.Helper) *cobra.Command { body.Cluster.Annotations = annotationsMap fieldName = string(Annotations) } + if cmd.Flags().Changed(flag.PublicEndpointDisabled) { + body.Cluster.Endpoints = &serverlessModel.TidbCloudOpenApiserverlessv1beta1ClusterEndpoints{ + Public: &serverlessModel.EndpointsPublic{ + Disabled: publicEndpointDisabled, + }, + } + fieldName = string(PublicEndpointDisabled) + } body.UpdateMask = &fieldName - params := serverlessApi.NewServerlessServicePartialUpdateClusterParams().WithClusterClusterID(clusterID). WithBody(*body).WithContext(ctx) _, err = d.PartialUpdateCluster(params) @@ -210,6 +224,7 @@ func UpdateCmd(h *internal.Helper) *cobra.Command { updateCmd.Flags().StringP(flag.DisplayName, flag.DisplayNameShort, "", "The new displayName of the cluster to be updated.") updateCmd.Flags().String(flag.ServerlessLabels, "", "The labels of the cluster to be added or updated.\nInteractive example: {\"label1\":\"value1\",\"label2\":\"value2\"}.\nNonInteractive example: \"{\\\"label1\\\":\\\"value1\\\",\\\"label2\\\":\\\"value2\\\"}\".") updateCmd.Flags().String(flag.ServerlessAnnotations, "", "The annotations of the cluster to be added or updated.\nInteractive example: {\"annotation1\":\"value1\",\"annotation2\":\"value2\"}.\nNonInteractive example: \"{\\\"annotation1\\\":\\\"value1\\\",\\\"annotation2\\\":\\\"value2\\\"}\".") + updateCmd.Flags().Bool(flag.PublicEndpointDisabled, false, "Disable the public endpoint of the cluster.") return updateCmd } diff --git a/internal/flag/flag.go b/internal/flag/flag.go index b590a6ae..9eedc35f 100644 --- a/internal/flag/flag.go +++ b/internal/flag/flag.go @@ -15,69 +15,70 @@ package flag const ( - ClusterID string = "cluster-id" - ClusterIDShort string = "c" - LocalConcurrency string = "local.concurrency" - CSVBackslashEscape string = "csv.backslash-escape" - CSVDelimiter string = "csv.delimiter" - CSVSeparator string = "csv.separator" - CSVTrimLastSeparator string = "csv.trim-last-separator" - CSVNullValue string = "csv.null-value" - CSVSkipHeader string = "csv.skip-header" - DisplayName string = "display-name" - DisplayNameShort string = "n" - ClusterType string = "cluster-type" - BranchID string = "branch-id" - BranchIDShort string = "b" - Debug string = "debug" - DebugShort string = "D" - LocalFilePath string = "local.file-path" - Force string = "force" - ImportID string = "import-id" - NoColor string = "no-color" - Output string = "output" - OutputShort string = "o" - Password string = "password" - ProjectID string = "project-id" - ProjectIDShort string = "p" - ProfileName string = "profile-name" - Profile string = "profile" - ProfileShort string = "P" - PublicKey string = "public-key" - PrivateKey string = "private-key" - Query string = "query" - QueryShort string = "q" - Region string = "region" - RegionShort string = "r" - LocalTargetDatabase string = "local.target-database" - LocalTargetTable string = "local.target-table" - User string = "user" - UserShort string = "u" - SpendingLimitMonthly string = "spending-limit-monthly" - ServerlessLabels string = "labels" - ServerlessAnnotations string = "annotations" - Monthly string = "monthly" - BackupID string = "backup-id" - BackupTime string = "backup-time" - S3URI string = "s3.uri" - S3AccessKeyID string = "s3.access-key-id" - S3SecretAccessKey string = "s3.secret-access-key" - TargetType string = "target-type" - FileType string = "file-type" - ExportID string = "export-id" - ExportIDShort string = "e" - OutputPath string = "output-path" - Encryption string = "encryption" - Compression string = "compression" - SourceType string = "source-type" - UserRole string = "role" - AddRole string = "add-role" - DeleteRole string = "delete-role" - Concurrency string = "concurrency" - SQL string = "sql" - TableWhere string = "where" - TableFilter string = "filter" - ParentID string = "parent-id" + ClusterID string = "cluster-id" + ClusterIDShort string = "c" + LocalConcurrency string = "local.concurrency" + CSVBackslashEscape string = "csv.backslash-escape" + CSVDelimiter string = "csv.delimiter" + CSVSeparator string = "csv.separator" + CSVTrimLastSeparator string = "csv.trim-last-separator" + CSVNullValue string = "csv.null-value" + CSVSkipHeader string = "csv.skip-header" + DisplayName string = "display-name" + DisplayNameShort string = "n" + ClusterType string = "cluster-type" + BranchID string = "branch-id" + BranchIDShort string = "b" + Debug string = "debug" + DebugShort string = "D" + LocalFilePath string = "local.file-path" + Force string = "force" + ImportID string = "import-id" + NoColor string = "no-color" + Output string = "output" + OutputShort string = "o" + Password string = "password" + ProjectID string = "project-id" + ProjectIDShort string = "p" + ProfileName string = "profile-name" + Profile string = "profile" + ProfileShort string = "P" + PublicKey string = "public-key" + PrivateKey string = "private-key" + Query string = "query" + QueryShort string = "q" + Region string = "region" + RegionShort string = "r" + LocalTargetDatabase string = "local.target-database" + LocalTargetTable string = "local.target-table" + User string = "user" + UserShort string = "u" + SpendingLimitMonthly string = "spending-limit-monthly" + ServerlessLabels string = "labels" + ServerlessAnnotations string = "annotations" + Monthly string = "monthly" + BackupID string = "backup-id" + BackupTime string = "backup-time" + S3URI string = "s3.uri" + S3AccessKeyID string = "s3.access-key-id" + S3SecretAccessKey string = "s3.secret-access-key" + TargetType string = "target-type" + FileType string = "file-type" + ExportID string = "export-id" + ExportIDShort string = "e" + OutputPath string = "output-path" + Encryption string = "encryption" + Compression string = "compression" + SourceType string = "source-type" + UserRole string = "role" + AddRole string = "add-role" + DeleteRole string = "delete-role" + Concurrency string = "concurrency" + SQL string = "sql" + TableWhere string = "where" + TableFilter string = "filter" + ParentID string = "parent-id" + PublicEndpointDisabled string = "public-endpoint.disabled" ) const OutputHelp = "Output format, one of [\"human\" \"json\"]. For the complete result, please use json format." From 63c9bfdd6cc6f3598c56fccfa1e0a6e97233f4f2 Mon Sep 17 00:00:00 2001 From: FingerLeader Date: Thu, 1 Aug 2024 17:54:29 +0800 Subject: [PATCH 2/3] add public endpoint disabled in interactive mode Signed-off-by: FingerLeader --- internal/cli/serverless/update.go | 42 +++++++++++++++++++------------ internal/service/cloud/logic.go | 27 ++++++++++++++++++++ 2 files changed, 53 insertions(+), 16 deletions(-) diff --git a/internal/cli/serverless/update.go b/internal/cli/serverless/update.go index b6cd0f62..9634fa0e 100644 --- a/internal/cli/serverless/update.go +++ b/internal/cli/serverless/update.go @@ -61,6 +61,7 @@ var mutableFields = []string{ string(DisplayName), string(Labels), string(Annotations), + string(PublicEndpointDisabled), } func UpdateCmd(h *internal.Helper) *cobra.Command { @@ -135,22 +136,30 @@ func UpdateCmd(h *internal.Helper) *cobra.Command { return err } - // variables for input - inputModel, err := GetUpdateClusterInput() - if err != nil { - return err - } - fieldValue := inputModel.(ui.TextInputModel).Inputs[0].Value() + if fieldName == string(PublicEndpointDisabled) { + publicEndpointDisabled, err = cloud.GetSelectedBool("Disable the public endpoint of the cluster?") + if err != nil { + return err + } + } else { + // variables for input + inputModel, err := GetUpdateClusterInput() + if err != nil { + return err + } + + fieldValue := inputModel.(ui.TextInputModel).Inputs[0].Value() - switch fieldName { - case string(DisplayName): - displayName = fieldValue - case string(Annotations): - annotations = fieldValue - case string(Labels): - labels = fieldValue - default: - return errors.Errorf("invalid field %s", fieldName) + switch fieldName { + case string(DisplayName): + displayName = fieldValue + case string(Annotations): + annotations = fieldValue + case string(Labels): + labels = fieldValue + default: + return errors.Errorf("invalid field %s", fieldName) + } } } else { @@ -200,7 +209,8 @@ func UpdateCmd(h *internal.Helper) *cobra.Command { body.Cluster.Annotations = annotationsMap fieldName = string(Annotations) } - if cmd.Flags().Changed(flag.PublicEndpointDisabled) { + // if filedName is PublicEndpointDisabled, means this field is changed in Interactive mode + if cmd.Flags().Changed(flag.PublicEndpointDisabled) || fieldName == string(PublicEndpointDisabled) { body.Cluster.Endpoints = &serverlessModel.TidbCloudOpenApiserverlessv1beta1ClusterEndpoints{ Public: &serverlessModel.EndpointsPublic{ Disabled: publicEndpointDisabled, diff --git a/internal/service/cloud/logic.go b/internal/service/cloud/logic.go index cf871ff6..d7bec579 100644 --- a/internal/service/cloud/logic.go +++ b/internal/service/cloud/logic.go @@ -236,6 +236,33 @@ func GetSelectedField(mutableFields []string) (string, error) { return field.(string), nil } +func GetSelectedBool(notice string) (bool, error) { + items := []interface{}{ + "true", + "false", + } + + model, err := ui.InitialSelectModel(items, notice) + if err != nil { + return false, errors.Trace(err) + } + + model.EnableFilter() + p := tea.NewProgram(model) + bModel, err := p.Run() + if err != nil { + return false, errors.Trace(err) + } + if m, _ := bModel.(ui.SelectModel); m.Interrupted { + return false, util.InterruptError + } + value := bModel.(ui.SelectModel).GetSelectedItem() + if value == nil { + return false, errors.New("no value selected") + } + return value.(string) == "true", nil +} + func GetSpendingLimitField(mutableFields []string) (string, error) { var items = make([]interface{}, 0, len(mutableFields)) for _, item := range mutableFields { From ccb5f76bfe519f4ae5aa5ba69cc3213911a95f28 Mon Sep 17 00:00:00 2001 From: FingerLeader Date: Fri, 2 Aug 2024 14:12:22 +0800 Subject: [PATCH 3/3] fix Signed-off-by: FingerLeader --- internal/cli/serverless/create.go | 1 + internal/cli/serverless/update.go | 13 +++++++++---- internal/flag/flag.go | 2 +- internal/service/cloud/logic.go | 1 - 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/internal/cli/serverless/create.go b/internal/cli/serverless/create.go index 9fddbf1b..5de3d969 100644 --- a/internal/cli/serverless/create.go +++ b/internal/cli/serverless/create.go @@ -62,6 +62,7 @@ func (c CreateOpts) NonInteractiveFlags() []string { flag.ProjectID, flag.SpendingLimitMonthly, flag.Encryption, + flag.PublicEndpointDisabled, } } diff --git a/internal/cli/serverless/update.go b/internal/cli/serverless/update.go index 9634fa0e..9ef2c02f 100644 --- a/internal/cli/serverless/update.go +++ b/internal/cli/serverless/update.go @@ -45,6 +45,7 @@ func (c UpdateOpts) NonInteractiveFlags() []string { flag.DisplayName, flag.ServerlessAnnotations, flag.ServerlessLabels, + flag.PublicEndpointDisabled, } } @@ -57,11 +58,15 @@ const ( PublicEndpointDisabled mutableField = "endpoints.public.disabled" ) +const ( + PublicEndpointDisabledHumanReadable = "disable public endpoint" +) + var mutableFields = []string{ string(DisplayName), string(Labels), string(Annotations), - string(PublicEndpointDisabled), + string(PublicEndpointDisabledHumanReadable), } func UpdateCmd(h *internal.Helper) *cobra.Command { @@ -136,7 +141,7 @@ func UpdateCmd(h *internal.Helper) *cobra.Command { return err } - if fieldName == string(PublicEndpointDisabled) { + if fieldName == string(PublicEndpointDisabledHumanReadable) { publicEndpointDisabled, err = cloud.GetSelectedBool("Disable the public endpoint of the cluster?") if err != nil { return err @@ -209,8 +214,8 @@ func UpdateCmd(h *internal.Helper) *cobra.Command { body.Cluster.Annotations = annotationsMap fieldName = string(Annotations) } - // if filedName is PublicEndpointDisabled, means this field is changed in Interactive mode - if cmd.Flags().Changed(flag.PublicEndpointDisabled) || fieldName == string(PublicEndpointDisabled) { + // if fieldName is PublicEndpointDisabled, means this field is changed in Interactive mode + if cmd.Flags().Changed(flag.PublicEndpointDisabled) || fieldName == string(PublicEndpointDisabledHumanReadable) { body.Cluster.Endpoints = &serverlessModel.TidbCloudOpenApiserverlessv1beta1ClusterEndpoints{ Public: &serverlessModel.EndpointsPublic{ Disabled: publicEndpointDisabled, diff --git a/internal/flag/flag.go b/internal/flag/flag.go index 9eedc35f..4c192837 100644 --- a/internal/flag/flag.go +++ b/internal/flag/flag.go @@ -78,7 +78,7 @@ const ( TableWhere string = "where" TableFilter string = "filter" ParentID string = "parent-id" - PublicEndpointDisabled string = "public-endpoint.disabled" + PublicEndpointDisabled string = "disable-public-endpoint" ) const OutputHelp = "Output format, one of [\"human\" \"json\"]. For the complete result, please use json format." diff --git a/internal/service/cloud/logic.go b/internal/service/cloud/logic.go index d7bec579..7c1e01ca 100644 --- a/internal/service/cloud/logic.go +++ b/internal/service/cloud/logic.go @@ -247,7 +247,6 @@ func GetSelectedBool(notice string) (bool, error) { return false, errors.Trace(err) } - model.EnableFilter() p := tea.NewProgram(model) bModel, err := p.Run() if err != nil {