Skip to content

Commit

Permalink
add public endpoint disabled (#215)
Browse files Browse the repository at this point in the history
  • Loading branch information
FingerLeader authored Aug 2, 2024
1 parent 4cdd305 commit 2d06852
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 85 deletions.
15 changes: 15 additions & 0 deletions internal/cli/serverless/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ func (c CreateOpts) NonInteractiveFlags() []string {
flag.ProjectID,
flag.SpendingLimitMonthly,
flag.Encryption,
flag.PublicEndpointDisabled,
}
}

Expand Down Expand Up @@ -124,6 +125,7 @@ func CreateCmd(h *internal.Helper) *cobra.Command {
var projectID string
var spendingLimitMonthly int32
var encryption bool
var publicEndpointDisabled bool
if opts.interactive {
cmd.Annotations[telemetry.InteractiveMode] = "true"
if !h.IOStreams.CanPrompt {
Expand Down Expand Up @@ -243,6 +245,10 @@ func CreateCmd(h *internal.Helper) *cobra.Command {
if err != nil {
return errors.Trace(err)
}
publicEndpointDisabled, err = cmd.Flags().GetBool(flag.PublicEndpointDisabled)
if err != nil {
return errors.Trace(err)
}
// check clusterName
err = checkClusterName(clusterName)
if err != nil {
Expand Down Expand Up @@ -273,6 +279,14 @@ func CreateCmd(h *internal.Helper) *cobra.Command {
}
}

if publicEndpointDisabled {
v1Cluster.Endpoints = &serverlessModel.TidbCloudOpenApiserverlessv1beta1ClusterEndpoints{
Public: &serverlessModel.EndpointsPublic{
Disabled: publicEndpointDisabled,
},
}
}

if h.IOStreams.CanPrompt {
err := CreateAndSpinnerWait(ctx, d, v1Cluster, h)
if err != nil {
Expand All @@ -294,6 +308,7 @@ func CreateCmd(h *internal.Helper) *cobra.Command {
createCmd.Flags().StringP(flag.ProjectID, flag.ProjectIDShort, "", "The ID of the project, in which the cluster will be created. (default: \"default project\")")
createCmd.Flags().Int32(flag.SpendingLimitMonthly, 0, "Maximum monthly spending limit in USD cents. (optional)")
createCmd.Flags().Bool(flag.Encryption, false, "Whether Enhanced Encryption at Rest is enabled. (optional)")
createCmd.Flags().Bool(flag.PublicEndpointDisabled, false, "Whether the public endpoint is disabled. (optional)")
return createCmd
}

Expand Down
74 changes: 52 additions & 22 deletions internal/cli/serverless/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"tidbcloud-cli/internal/ui"
"tidbcloud-cli/internal/util"
serverlessApi "tidbcloud-cli/pkg/tidbcloud/v1beta1/serverless/client/serverless_service"
serverlessModel "tidbcloud-cli/pkg/tidbcloud/v1beta1/serverless/models"

"github.com/charmbracelet/bubbles/textinput"
tea "github.com/charmbracelet/bubbletea"
Expand All @@ -44,21 +45,28 @@ func (c UpdateOpts) NonInteractiveFlags() []string {
flag.DisplayName,
flag.ServerlessAnnotations,
flag.ServerlessLabels,
flag.PublicEndpointDisabled,
}
}

type mutableField string

const (
DisplayName mutableField = "displayName"
Annotations mutableField = "annotations"
Labels mutableField = "labels"
DisplayName mutableField = "displayName"
Annotations mutableField = "annotations"
Labels mutableField = "labels"
PublicEndpointDisabled mutableField = "endpoints.public.disabled"
)

const (
PublicEndpointDisabledHumanReadable = "disable public endpoint"
)

var mutableFields = []string{
string(DisplayName),
string(Labels),
string(Annotations),
string(PublicEndpointDisabledHumanReadable),
}

func UpdateCmd(h *internal.Helper) *cobra.Command {
Expand Down Expand Up @@ -94,8 +102,8 @@ func UpdateCmd(h *internal.Helper) *cobra.Command {
if err != nil {
return err
}
cmd.MarkFlagsMutuallyExclusive(flag.DisplayName, flag.ServerlessAnnotations, flag.ServerlessLabels)
cmd.MarkFlagsOneRequired(flag.DisplayName, flag.ServerlessAnnotations, flag.ServerlessLabels)
cmd.MarkFlagsMutuallyExclusive(flag.DisplayName, flag.ServerlessAnnotations, flag.ServerlessLabels, flag.PublicEndpointDisabled)
cmd.MarkFlagsOneRequired(flag.DisplayName, flag.ServerlessAnnotations, flag.ServerlessLabels, flag.PublicEndpointDisabled)
}
return nil
},
Expand All @@ -109,6 +117,7 @@ func UpdateCmd(h *internal.Helper) *cobra.Command {
var clusterID string
var fieldName string
var displayName, labels, annotations string
var publicEndpointDisabled bool
if opts.interactive {
cmd.Annotations[telemetry.InteractiveMode] = "true"
if !h.IOStreams.CanPrompt {
Expand All @@ -132,22 +141,30 @@ func UpdateCmd(h *internal.Helper) *cobra.Command {
return err
}

// variables for input
inputModel, err := GetUpdateClusterInput()
if err != nil {
return err
}
fieldValue := inputModel.(ui.TextInputModel).Inputs[0].Value()

switch fieldName {
case string(DisplayName):
displayName = fieldValue
case string(Annotations):
annotations = fieldValue
case string(Labels):
labels = fieldValue
default:
return errors.Errorf("invalid field %s", fieldName)
if fieldName == string(PublicEndpointDisabledHumanReadable) {
publicEndpointDisabled, err = cloud.GetSelectedBool("Disable the public endpoint of the cluster?")
if err != nil {
return err
}
} else {
// variables for input
inputModel, err := GetUpdateClusterInput()
if err != nil {
return err
}

fieldValue := inputModel.(ui.TextInputModel).Inputs[0].Value()

switch fieldName {
case string(DisplayName):
displayName = fieldValue
case string(Annotations):
annotations = fieldValue
case string(Labels):
labels = fieldValue
default:
return errors.Errorf("invalid field %s", fieldName)
}
}

} else {
Expand All @@ -168,6 +185,10 @@ func UpdateCmd(h *internal.Helper) *cobra.Command {
if err != nil {
return errors.Trace(err)
}
publicEndpointDisabled, err = cmd.Flags().GetBool(flag.PublicEndpointDisabled)
if err != nil {
return errors.Trace(err)
}
}

body := &serverlessApi.ServerlessServicePartialUpdateClusterBody{
Expand All @@ -193,8 +214,16 @@ func UpdateCmd(h *internal.Helper) *cobra.Command {
body.Cluster.Annotations = annotationsMap
fieldName = string(Annotations)
}
// if fieldName is PublicEndpointDisabled, means this field is changed in Interactive mode
if cmd.Flags().Changed(flag.PublicEndpointDisabled) || fieldName == string(PublicEndpointDisabledHumanReadable) {
body.Cluster.Endpoints = &serverlessModel.TidbCloudOpenApiserverlessv1beta1ClusterEndpoints{
Public: &serverlessModel.EndpointsPublic{
Disabled: publicEndpointDisabled,
},
}
fieldName = string(PublicEndpointDisabled)
}
body.UpdateMask = &fieldName

params := serverlessApi.NewServerlessServicePartialUpdateClusterParams().WithClusterClusterID(clusterID).
WithBody(*body).WithContext(ctx)
_, err = d.PartialUpdateCluster(params)
Expand All @@ -210,6 +239,7 @@ func UpdateCmd(h *internal.Helper) *cobra.Command {
updateCmd.Flags().StringP(flag.DisplayName, flag.DisplayNameShort, "", "The new displayName of the cluster to be updated.")
updateCmd.Flags().String(flag.ServerlessLabels, "", "The labels of the cluster to be added or updated.\nInteractive example: {\"label1\":\"value1\",\"label2\":\"value2\"}.\nNonInteractive example: \"{\\\"label1\\\":\\\"value1\\\",\\\"label2\\\":\\\"value2\\\"}\".")
updateCmd.Flags().String(flag.ServerlessAnnotations, "", "The annotations of the cluster to be added or updated.\nInteractive example: {\"annotation1\":\"value1\",\"annotation2\":\"value2\"}.\nNonInteractive example: \"{\\\"annotation1\\\":\\\"value1\\\",\\\"annotation2\\\":\\\"value2\\\"}\".")
updateCmd.Flags().Bool(flag.PublicEndpointDisabled, false, "Disable the public endpoint of the cluster.")
return updateCmd
}

Expand Down
127 changes: 64 additions & 63 deletions internal/flag/flag.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,69 +15,70 @@
package flag

const (
ClusterID string = "cluster-id"
ClusterIDShort string = "c"
LocalConcurrency string = "local.concurrency"
CSVBackslashEscape string = "csv.backslash-escape"
CSVDelimiter string = "csv.delimiter"
CSVSeparator string = "csv.separator"
CSVTrimLastSeparator string = "csv.trim-last-separator"
CSVNullValue string = "csv.null-value"
CSVSkipHeader string = "csv.skip-header"
DisplayName string = "display-name"
DisplayNameShort string = "n"
ClusterType string = "cluster-type"
BranchID string = "branch-id"
BranchIDShort string = "b"
Debug string = "debug"
DebugShort string = "D"
LocalFilePath string = "local.file-path"
Force string = "force"
ImportID string = "import-id"
NoColor string = "no-color"
Output string = "output"
OutputShort string = "o"
Password string = "password"
ProjectID string = "project-id"
ProjectIDShort string = "p"
ProfileName string = "profile-name"
Profile string = "profile"
ProfileShort string = "P"
PublicKey string = "public-key"
PrivateKey string = "private-key"
Query string = "query"
QueryShort string = "q"
Region string = "region"
RegionShort string = "r"
LocalTargetDatabase string = "local.target-database"
LocalTargetTable string = "local.target-table"
User string = "user"
UserShort string = "u"
SpendingLimitMonthly string = "spending-limit-monthly"
ServerlessLabels string = "labels"
ServerlessAnnotations string = "annotations"
Monthly string = "monthly"
BackupID string = "backup-id"
BackupTime string = "backup-time"
S3URI string = "s3.uri"
S3AccessKeyID string = "s3.access-key-id"
S3SecretAccessKey string = "s3.secret-access-key"
TargetType string = "target-type"
FileType string = "file-type"
ExportID string = "export-id"
ExportIDShort string = "e"
OutputPath string = "output-path"
Encryption string = "encryption"
Compression string = "compression"
SourceType string = "source-type"
UserRole string = "role"
AddRole string = "add-role"
DeleteRole string = "delete-role"
Concurrency string = "concurrency"
SQL string = "sql"
TableWhere string = "where"
TableFilter string = "filter"
ParentID string = "parent-id"
ClusterID string = "cluster-id"
ClusterIDShort string = "c"
LocalConcurrency string = "local.concurrency"
CSVBackslashEscape string = "csv.backslash-escape"
CSVDelimiter string = "csv.delimiter"
CSVSeparator string = "csv.separator"
CSVTrimLastSeparator string = "csv.trim-last-separator"
CSVNullValue string = "csv.null-value"
CSVSkipHeader string = "csv.skip-header"
DisplayName string = "display-name"
DisplayNameShort string = "n"
ClusterType string = "cluster-type"
BranchID string = "branch-id"
BranchIDShort string = "b"
Debug string = "debug"
DebugShort string = "D"
LocalFilePath string = "local.file-path"
Force string = "force"
ImportID string = "import-id"
NoColor string = "no-color"
Output string = "output"
OutputShort string = "o"
Password string = "password"
ProjectID string = "project-id"
ProjectIDShort string = "p"
ProfileName string = "profile-name"
Profile string = "profile"
ProfileShort string = "P"
PublicKey string = "public-key"
PrivateKey string = "private-key"
Query string = "query"
QueryShort string = "q"
Region string = "region"
RegionShort string = "r"
LocalTargetDatabase string = "local.target-database"
LocalTargetTable string = "local.target-table"
User string = "user"
UserShort string = "u"
SpendingLimitMonthly string = "spending-limit-monthly"
ServerlessLabels string = "labels"
ServerlessAnnotations string = "annotations"
Monthly string = "monthly"
BackupID string = "backup-id"
BackupTime string = "backup-time"
S3URI string = "s3.uri"
S3AccessKeyID string = "s3.access-key-id"
S3SecretAccessKey string = "s3.secret-access-key"
TargetType string = "target-type"
FileType string = "file-type"
ExportID string = "export-id"
ExportIDShort string = "e"
OutputPath string = "output-path"
Encryption string = "encryption"
Compression string = "compression"
SourceType string = "source-type"
UserRole string = "role"
AddRole string = "add-role"
DeleteRole string = "delete-role"
Concurrency string = "concurrency"
SQL string = "sql"
TableWhere string = "where"
TableFilter string = "filter"
ParentID string = "parent-id"
PublicEndpointDisabled string = "disable-public-endpoint"
)

const OutputHelp = "Output format, one of [\"human\" \"json\"]. For the complete result, please use json format."
26 changes: 26 additions & 0 deletions internal/service/cloud/logic.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,32 @@ func GetSelectedField(mutableFields []string) (string, error) {
return field.(string), nil
}

func GetSelectedBool(notice string) (bool, error) {
items := []interface{}{
"true",
"false",
}

model, err := ui.InitialSelectModel(items, notice)
if err != nil {
return false, errors.Trace(err)
}

p := tea.NewProgram(model)
bModel, err := p.Run()
if err != nil {
return false, errors.Trace(err)
}
if m, _ := bModel.(ui.SelectModel); m.Interrupted {
return false, util.InterruptError
}
value := bModel.(ui.SelectModel).GetSelectedItem()
if value == nil {
return false, errors.New("no value selected")
}
return value.(string) == "true", nil
}

func GetSpendingLimitField(mutableFields []string) (string, error) {
var items = make([]interface{}, 0, len(mutableFields))
for _, item := range mutableFields {
Expand Down

0 comments on commit 2d06852

Please sign in to comment.