From 2d0685295666300f242e327972e9e42d5bd9e4c4 Mon Sep 17 00:00:00 2001 From: FingerLeader <43462394+FingerLeader@users.noreply.github.com> Date: Fri, 2 Aug 2024 14:28:02 +0800 Subject: [PATCH] add public endpoint disabled (#215) --- internal/cli/serverless/create.go | 15 ++++ internal/cli/serverless/update.go | 74 +++++++++++------ internal/flag/flag.go | 127 +++++++++++++++--------------- internal/service/cloud/logic.go | 26 ++++++ 4 files changed, 157 insertions(+), 85 deletions(-) diff --git a/internal/cli/serverless/create.go b/internal/cli/serverless/create.go index a984525d..5de3d969 100644 --- a/internal/cli/serverless/create.go +++ b/internal/cli/serverless/create.go @@ -62,6 +62,7 @@ func (c CreateOpts) NonInteractiveFlags() []string { flag.ProjectID, flag.SpendingLimitMonthly, flag.Encryption, + flag.PublicEndpointDisabled, } } @@ -124,6 +125,7 @@ func CreateCmd(h *internal.Helper) *cobra.Command { var projectID string var spendingLimitMonthly int32 var encryption bool + var publicEndpointDisabled bool if opts.interactive { cmd.Annotations[telemetry.InteractiveMode] = "true" if !h.IOStreams.CanPrompt { @@ -243,6 +245,10 @@ func CreateCmd(h *internal.Helper) *cobra.Command { if err != nil { return errors.Trace(err) } + publicEndpointDisabled, err = cmd.Flags().GetBool(flag.PublicEndpointDisabled) + if err != nil { + return errors.Trace(err) + } // check clusterName err = checkClusterName(clusterName) if err != nil { @@ -273,6 +279,14 @@ func CreateCmd(h *internal.Helper) *cobra.Command { } } + if publicEndpointDisabled { + v1Cluster.Endpoints = &serverlessModel.TidbCloudOpenApiserverlessv1beta1ClusterEndpoints{ + Public: &serverlessModel.EndpointsPublic{ + Disabled: publicEndpointDisabled, + }, + } + } + if h.IOStreams.CanPrompt { err := CreateAndSpinnerWait(ctx, d, v1Cluster, h) if err != nil { @@ -294,6 +308,7 @@ func CreateCmd(h *internal.Helper) *cobra.Command { createCmd.Flags().StringP(flag.ProjectID, flag.ProjectIDShort, "", "The ID of the project, in which the cluster will be created. (default: \"default project\")") createCmd.Flags().Int32(flag.SpendingLimitMonthly, 0, "Maximum monthly spending limit in USD cents. (optional)") createCmd.Flags().Bool(flag.Encryption, false, "Whether Enhanced Encryption at Rest is enabled. (optional)") + createCmd.Flags().Bool(flag.PublicEndpointDisabled, false, "Whether the public endpoint is disabled. (optional)") return createCmd } diff --git a/internal/cli/serverless/update.go b/internal/cli/serverless/update.go index 4e099fad..9ef2c02f 100644 --- a/internal/cli/serverless/update.go +++ b/internal/cli/serverless/update.go @@ -26,6 +26,7 @@ import ( "tidbcloud-cli/internal/ui" "tidbcloud-cli/internal/util" serverlessApi "tidbcloud-cli/pkg/tidbcloud/v1beta1/serverless/client/serverless_service" + serverlessModel "tidbcloud-cli/pkg/tidbcloud/v1beta1/serverless/models" "github.com/charmbracelet/bubbles/textinput" tea "github.com/charmbracelet/bubbletea" @@ -44,21 +45,28 @@ func (c UpdateOpts) NonInteractiveFlags() []string { flag.DisplayName, flag.ServerlessAnnotations, flag.ServerlessLabels, + flag.PublicEndpointDisabled, } } type mutableField string const ( - DisplayName mutableField = "displayName" - Annotations mutableField = "annotations" - Labels mutableField = "labels" + DisplayName mutableField = "displayName" + Annotations mutableField = "annotations" + Labels mutableField = "labels" + PublicEndpointDisabled mutableField = "endpoints.public.disabled" +) + +const ( + PublicEndpointDisabledHumanReadable = "disable public endpoint" ) var mutableFields = []string{ string(DisplayName), string(Labels), string(Annotations), + string(PublicEndpointDisabledHumanReadable), } func UpdateCmd(h *internal.Helper) *cobra.Command { @@ -94,8 +102,8 @@ func UpdateCmd(h *internal.Helper) *cobra.Command { if err != nil { return err } - cmd.MarkFlagsMutuallyExclusive(flag.DisplayName, flag.ServerlessAnnotations, flag.ServerlessLabels) - cmd.MarkFlagsOneRequired(flag.DisplayName, flag.ServerlessAnnotations, flag.ServerlessLabels) + cmd.MarkFlagsMutuallyExclusive(flag.DisplayName, flag.ServerlessAnnotations, flag.ServerlessLabels, flag.PublicEndpointDisabled) + cmd.MarkFlagsOneRequired(flag.DisplayName, flag.ServerlessAnnotations, flag.ServerlessLabels, flag.PublicEndpointDisabled) } return nil }, @@ -109,6 +117,7 @@ func UpdateCmd(h *internal.Helper) *cobra.Command { var clusterID string var fieldName string var displayName, labels, annotations string + var publicEndpointDisabled bool if opts.interactive { cmd.Annotations[telemetry.InteractiveMode] = "true" if !h.IOStreams.CanPrompt { @@ -132,22 +141,30 @@ func UpdateCmd(h *internal.Helper) *cobra.Command { return err } - // variables for input - inputModel, err := GetUpdateClusterInput() - if err != nil { - return err - } - fieldValue := inputModel.(ui.TextInputModel).Inputs[0].Value() - - switch fieldName { - case string(DisplayName): - displayName = fieldValue - case string(Annotations): - annotations = fieldValue - case string(Labels): - labels = fieldValue - default: - return errors.Errorf("invalid field %s", fieldName) + if fieldName == string(PublicEndpointDisabledHumanReadable) { + publicEndpointDisabled, err = cloud.GetSelectedBool("Disable the public endpoint of the cluster?") + if err != nil { + return err + } + } else { + // variables for input + inputModel, err := GetUpdateClusterInput() + if err != nil { + return err + } + + fieldValue := inputModel.(ui.TextInputModel).Inputs[0].Value() + + switch fieldName { + case string(DisplayName): + displayName = fieldValue + case string(Annotations): + annotations = fieldValue + case string(Labels): + labels = fieldValue + default: + return errors.Errorf("invalid field %s", fieldName) + } } } else { @@ -168,6 +185,10 @@ func UpdateCmd(h *internal.Helper) *cobra.Command { if err != nil { return errors.Trace(err) } + publicEndpointDisabled, err = cmd.Flags().GetBool(flag.PublicEndpointDisabled) + if err != nil { + return errors.Trace(err) + } } body := &serverlessApi.ServerlessServicePartialUpdateClusterBody{ @@ -193,8 +214,16 @@ func UpdateCmd(h *internal.Helper) *cobra.Command { body.Cluster.Annotations = annotationsMap fieldName = string(Annotations) } + // if fieldName is PublicEndpointDisabled, means this field is changed in Interactive mode + if cmd.Flags().Changed(flag.PublicEndpointDisabled) || fieldName == string(PublicEndpointDisabledHumanReadable) { + body.Cluster.Endpoints = &serverlessModel.TidbCloudOpenApiserverlessv1beta1ClusterEndpoints{ + Public: &serverlessModel.EndpointsPublic{ + Disabled: publicEndpointDisabled, + }, + } + fieldName = string(PublicEndpointDisabled) + } body.UpdateMask = &fieldName - params := serverlessApi.NewServerlessServicePartialUpdateClusterParams().WithClusterClusterID(clusterID). WithBody(*body).WithContext(ctx) _, err = d.PartialUpdateCluster(params) @@ -210,6 +239,7 @@ func UpdateCmd(h *internal.Helper) *cobra.Command { updateCmd.Flags().StringP(flag.DisplayName, flag.DisplayNameShort, "", "The new displayName of the cluster to be updated.") updateCmd.Flags().String(flag.ServerlessLabels, "", "The labels of the cluster to be added or updated.\nInteractive example: {\"label1\":\"value1\",\"label2\":\"value2\"}.\nNonInteractive example: \"{\\\"label1\\\":\\\"value1\\\",\\\"label2\\\":\\\"value2\\\"}\".") updateCmd.Flags().String(flag.ServerlessAnnotations, "", "The annotations of the cluster to be added or updated.\nInteractive example: {\"annotation1\":\"value1\",\"annotation2\":\"value2\"}.\nNonInteractive example: \"{\\\"annotation1\\\":\\\"value1\\\",\\\"annotation2\\\":\\\"value2\\\"}\".") + updateCmd.Flags().Bool(flag.PublicEndpointDisabled, false, "Disable the public endpoint of the cluster.") return updateCmd } diff --git a/internal/flag/flag.go b/internal/flag/flag.go index b590a6ae..4c192837 100644 --- a/internal/flag/flag.go +++ b/internal/flag/flag.go @@ -15,69 +15,70 @@ package flag const ( - ClusterID string = "cluster-id" - ClusterIDShort string = "c" - LocalConcurrency string = "local.concurrency" - CSVBackslashEscape string = "csv.backslash-escape" - CSVDelimiter string = "csv.delimiter" - CSVSeparator string = "csv.separator" - CSVTrimLastSeparator string = "csv.trim-last-separator" - CSVNullValue string = "csv.null-value" - CSVSkipHeader string = "csv.skip-header" - DisplayName string = "display-name" - DisplayNameShort string = "n" - ClusterType string = "cluster-type" - BranchID string = "branch-id" - BranchIDShort string = "b" - Debug string = "debug" - DebugShort string = "D" - LocalFilePath string = "local.file-path" - Force string = "force" - ImportID string = "import-id" - NoColor string = "no-color" - Output string = "output" - OutputShort string = "o" - Password string = "password" - ProjectID string = "project-id" - ProjectIDShort string = "p" - ProfileName string = "profile-name" - Profile string = "profile" - ProfileShort string = "P" - PublicKey string = "public-key" - PrivateKey string = "private-key" - Query string = "query" - QueryShort string = "q" - Region string = "region" - RegionShort string = "r" - LocalTargetDatabase string = "local.target-database" - LocalTargetTable string = "local.target-table" - User string = "user" - UserShort string = "u" - SpendingLimitMonthly string = "spending-limit-monthly" - ServerlessLabels string = "labels" - ServerlessAnnotations string = "annotations" - Monthly string = "monthly" - BackupID string = "backup-id" - BackupTime string = "backup-time" - S3URI string = "s3.uri" - S3AccessKeyID string = "s3.access-key-id" - S3SecretAccessKey string = "s3.secret-access-key" - TargetType string = "target-type" - FileType string = "file-type" - ExportID string = "export-id" - ExportIDShort string = "e" - OutputPath string = "output-path" - Encryption string = "encryption" - Compression string = "compression" - SourceType string = "source-type" - UserRole string = "role" - AddRole string = "add-role" - DeleteRole string = "delete-role" - Concurrency string = "concurrency" - SQL string = "sql" - TableWhere string = "where" - TableFilter string = "filter" - ParentID string = "parent-id" + ClusterID string = "cluster-id" + ClusterIDShort string = "c" + LocalConcurrency string = "local.concurrency" + CSVBackslashEscape string = "csv.backslash-escape" + CSVDelimiter string = "csv.delimiter" + CSVSeparator string = "csv.separator" + CSVTrimLastSeparator string = "csv.trim-last-separator" + CSVNullValue string = "csv.null-value" + CSVSkipHeader string = "csv.skip-header" + DisplayName string = "display-name" + DisplayNameShort string = "n" + ClusterType string = "cluster-type" + BranchID string = "branch-id" + BranchIDShort string = "b" + Debug string = "debug" + DebugShort string = "D" + LocalFilePath string = "local.file-path" + Force string = "force" + ImportID string = "import-id" + NoColor string = "no-color" + Output string = "output" + OutputShort string = "o" + Password string = "password" + ProjectID string = "project-id" + ProjectIDShort string = "p" + ProfileName string = "profile-name" + Profile string = "profile" + ProfileShort string = "P" + PublicKey string = "public-key" + PrivateKey string = "private-key" + Query string = "query" + QueryShort string = "q" + Region string = "region" + RegionShort string = "r" + LocalTargetDatabase string = "local.target-database" + LocalTargetTable string = "local.target-table" + User string = "user" + UserShort string = "u" + SpendingLimitMonthly string = "spending-limit-monthly" + ServerlessLabels string = "labels" + ServerlessAnnotations string = "annotations" + Monthly string = "monthly" + BackupID string = "backup-id" + BackupTime string = "backup-time" + S3URI string = "s3.uri" + S3AccessKeyID string = "s3.access-key-id" + S3SecretAccessKey string = "s3.secret-access-key" + TargetType string = "target-type" + FileType string = "file-type" + ExportID string = "export-id" + ExportIDShort string = "e" + OutputPath string = "output-path" + Encryption string = "encryption" + Compression string = "compression" + SourceType string = "source-type" + UserRole string = "role" + AddRole string = "add-role" + DeleteRole string = "delete-role" + Concurrency string = "concurrency" + SQL string = "sql" + TableWhere string = "where" + TableFilter string = "filter" + ParentID string = "parent-id" + PublicEndpointDisabled string = "disable-public-endpoint" ) const OutputHelp = "Output format, one of [\"human\" \"json\"]. For the complete result, please use json format." diff --git a/internal/service/cloud/logic.go b/internal/service/cloud/logic.go index cf871ff6..7c1e01ca 100644 --- a/internal/service/cloud/logic.go +++ b/internal/service/cloud/logic.go @@ -236,6 +236,32 @@ func GetSelectedField(mutableFields []string) (string, error) { return field.(string), nil } +func GetSelectedBool(notice string) (bool, error) { + items := []interface{}{ + "true", + "false", + } + + model, err := ui.InitialSelectModel(items, notice) + if err != nil { + return false, errors.Trace(err) + } + + p := tea.NewProgram(model) + bModel, err := p.Run() + if err != nil { + return false, errors.Trace(err) + } + if m, _ := bModel.(ui.SelectModel); m.Interrupted { + return false, util.InterruptError + } + value := bModel.(ui.SelectModel).GetSelectedItem() + if value == nil { + return false, errors.New("no value selected") + } + return value.(string) == "true", nil +} + func GetSpendingLimitField(mutableFields []string) (string, error) { var items = make([]interface{}, 0, len(mutableFields)) for _, item := range mutableFields {