diff --git a/compute/go.mod b/compute/go.mod index f01f029f..e58d6b05 100644 --- a/compute/go.mod +++ b/compute/go.mod @@ -6,7 +6,8 @@ replace github.com/databricks/databricks-sdk-go/databricks => ../databricks require ( github.com/databricks/databricks-sdk-go/databricks v0.0.0-00010101000000-000000000000 - golang.org/x/oauth2 v0.25.0 + github.com/stretchr/testify v1.10.0 + golang.org/x/mod v0.22.0 ) require ( @@ -14,6 +15,7 @@ require ( cloud.google.com/go/auth/oauth2adapt v0.2.6 // indirect cloud.google.com/go/compute/metadata v0.6.0 // indirect github.com/databricks/databricks-sdk-go v0.55.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect @@ -21,14 +23,15 @@ require ( github.com/google/s2a-go v0.1.8 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect github.com/googleapis/gax-go/v2 v2.14.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 // indirect go.opentelemetry.io/otel v1.31.0 // indirect go.opentelemetry.io/otel/metric v1.31.0 // indirect go.opentelemetry.io/otel/trace v1.31.0 // indirect golang.org/x/crypto v0.31.0 // indirect golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 // indirect - golang.org/x/mod v0.22.0 // indirect golang.org/x/net v0.33.0 // indirect + golang.org/x/oauth2 v0.25.0 // indirect golang.org/x/sys v0.28.0 // indirect golang.org/x/text v0.21.0 // indirect golang.org/x/time v0.9.0 // indirect @@ -37,4 +40,5 @@ require ( google.golang.org/grpc v1.69.2 // indirect google.golang.org/protobuf v1.36.1 // indirect gopkg.in/ini.v1 v1.67.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/compute/go.sum b/compute/go.sum index 99fafa3e..3c3a17e2 100644 --- a/compute/go.sum +++ b/compute/go.sum @@ -75,6 +75,8 @@ google.golang.org/grpc v1.69.2 
h1:U3S9QEtbXC0bYNvRtcoklF3xGtLViumSYxWykJS+7AU= google.golang.org/grpc v1.69.2/go.mod h1:vyjdE6jLBI76dgpDojsFGNaHlxdjXN9ghpnd2o7JGZ4= google.golang.org/protobuf v1.36.1 h1:yBPeRvTftaleIgM3PZ/WBIZ7XM/eEYAaEyCwvyjq/gk= google.golang.org/protobuf v1.36.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/compute/v2/ext_commands.go b/compute/v2/ext_commands.go new file mode 100644 index 00000000..5a61efa0 --- /dev/null +++ b/compute/v2/ext_commands.go @@ -0,0 +1,46 @@ +// TODO : Add the missing methods and implement the methods +// This file has not been completely shifted from the SDK-Beta +// as we still don't have the wait for state methods in the SDK-mod +package compute + +import ( + "context" +) + +type CommandExecutorV2 struct { + clustersAPI *ClustersAPI + executionAPI *CommandExecutionAPI + language Language + clusterID string + contextID string +} + +type commandExecutionAPIUtilities interface { + Start(ctx context.Context, clusterID string, language Language) (*CommandExecutorV2, error) +} + +// Destroy deletes the command execution context from the cluster +func (c *CommandExecutorV2) Destroy(ctx context.Context) error { + return c.executionAPI.Destroy(ctx, DestroyContext{ + ClusterId: c.clusterID, + ContextId: c.contextID, + }) +} + +// CommandExecutor creates a spark context and executes a command and then closes context +type CommandExecutor interface { + Execute(ctx context.Context, clusterID, language, commandStr string) Results +} + +// CommandMock mocks the execution of command 
+type CommandMock func(commandStr string) Results + +func (m CommandMock) Execute(_ context.Context, _, _, commandStr string) Results { + return m(commandStr) +} + +// CommandsHighLevelAPI exposes more friendly wrapper over command execution +type CommandsHighLevelAPI struct { + clusters *ClustersAPI + execution *CommandExecutionAPI +} diff --git a/compute/v2/ext_leading_whitespace.go b/compute/v2/ext_leading_whitespace.go new file mode 100644 index 00000000..909b908e --- /dev/null +++ b/compute/v2/ext_leading_whitespace.go @@ -0,0 +1,36 @@ +package compute + +import ( + "strings" +) + +// TrimLeadingWhitespace removes leading whitespace, so that Python code blocks +// that are embedded into Go code still could be interpreted properly. +func TrimLeadingWhitespace(commandStr string) (newCommand string) { + lines := strings.Split(strings.ReplaceAll(commandStr, "\t", " "), "\n") + leadingWhitespace := 1<<31 - 1 + for _, line := range lines { + for pos, char := range line { + if char == ' ' || char == '\t' { + continue + } + // first non-whitespace character + if pos < leadingWhitespace { + leadingWhitespace = pos + } + // is not needed further + break + } + } + for i := 0; i < len(lines); i++ { + if lines[i] == "" || strings.Trim(lines[i], " \t") == "" { + continue + } + if len(lines[i]) < leadingWhitespace { + newCommand += lines[i] + "\n" // or not.. 
+ } else { + newCommand += lines[i][leadingWhitespace:] + "\n" + } + } + return +} diff --git a/compute/v2/ext_leading_whitespace_test.go b/compute/v2/ext_leading_whitespace_test.go new file mode 100644 index 00000000..e206cbdc --- /dev/null +++ b/compute/v2/ext_leading_whitespace_test.go @@ -0,0 +1,16 @@ +package compute + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestTrimLeadingWhitespace(t *testing.T) { + assert.Equal(t, "foo\nbar\n", TrimLeadingWhitespace(` + + foo + bar + + `)) +} diff --git a/compute/v2/ext_library_utilities.go b/compute/v2/ext_library_utilities.go new file mode 100644 index 00000000..512e3379 --- /dev/null +++ b/compute/v2/ext_library_utilities.go @@ -0,0 +1,243 @@ +package compute + +import ( + "context" + "fmt" + "sort" + "strings" + "time" + + "github.com/databricks/databricks-sdk-go/databricks/apierr" + "github.com/databricks/databricks-sdk-go/databricks/log" + "github.com/databricks/databricks-sdk-go/databricks/retries" + "github.com/databricks/databricks-sdk-go/databricks/useragent" +) + +type Wait struct { + ClusterID string + Libraries []Library + IsRunning bool + IsRefresh bool +} + +func (library Library) String() string { + if library.Whl != "" { + return fmt.Sprintf("whl:%s", library.Whl) + } + if library.Jar != "" { + return fmt.Sprintf("jar:%s", library.Jar) + } + if library.Pypi != nil && library.Pypi.Package != "" { + return fmt.Sprintf("pypi:%s%s", library.Pypi.Repo, library.Pypi.Package) + } + if library.Maven != nil && library.Maven.Coordinates != "" { + mvn := library.Maven + return fmt.Sprintf("mvn:%s%s%s", mvn.Repo, mvn.Coordinates, + strings.Join(mvn.Exclusions, "")) + } + if library.Egg != "" { + return fmt.Sprintf("egg:%s", library.Egg) + } + if library.Cran != nil && library.Cran.Package != "" { + return fmt.Sprintf("cran:%s%s", library.Cran.Repo, library.Cran.Package) + } + return "unknown" +} + +func (cll *InstallLibraries) Sort() { + sort.Slice(cll.Libraries, func(i, j int) bool { 
+ return cll.Libraries[i].String() < cll.Libraries[j].String() + }) +} + +// ToLibraryList converts to an entity for convenient comparison +func (cls ClusterLibraryStatuses) ToLibraryList() InstallLibraries { + cll := InstallLibraries{ClusterId: cls.ClusterId} + for _, lib := range cls.LibraryStatuses { + cll.Libraries = append(cll.Libraries, *lib.Library) + } + cll.Sort() + return cll +} + +func (w *Wait) IsNotInScope(lib *Library) bool { + // if we don't know concrete libraries + if len(w.Libraries) == 0 { + return false + } + // if we know concrete libraries + for _, v := range w.Libraries { + if v.String() == lib.String() { + return false + } + } + return true +} + +// IsRetryNeeded returns first bool if there needs to be retry. +// If there needs to be retry, error message will explain why. +// If retry does not need to happen and error is not nil - it failed. +func (cls ClusterLibraryStatuses) IsRetryNeeded(w Wait) (bool, error) { + pending := 0 + ready := 0 + errors := []string{} + for _, lib := range cls.LibraryStatuses { + if lib.IsLibraryForAllClusters { + continue + } + if w.IsNotInScope(lib.Library) { + continue + } + switch lib.Status { + // No action has yet been taken to install the library. This state should be very short lived. + case "PENDING": + pending++ + // Metadata necessary to install the library is being retrieved from the provided repository. + case "RESOLVING": + pending++ + // The library is actively being installed, either by adding resources to Spark + // or executing system commands inside the Spark nodes. + case "INSTALLING": + pending++ + // The library has been successfully installed. + case "INSTALLED": + ready++ + // Installation on a Databricks Runtime 7.0 or above cluster was skipped due to Scala version incompatibility. + case "SKIPPED": + ready++ + // The library has been marked for removal. Libraries can be removed only when clusters are restarted. + case "UNINSTALL_ON_RESTART": + ready++ + //Some step in installation failed. 
More information can be found in the messages field. + case "FAILED": + if w.IsRefresh { + // we're reading library list on a running cluster and some of the libs failed to install + continue + } + errors = append(errors, fmt.Sprintf("%s failed: %s", lib.Library, strings.Join(lib.Messages, ", "))) + continue + } + } + if pending > 0 { + return true, fmt.Errorf("%d libraries are ready, but there are still %d pending", ready, pending) + } + if len(errors) > 0 { + return false, fmt.Errorf("%s", strings.Join(errors, "\n")) + } + return false, nil +} + +type Update struct { + ClusterId string + // The libraries to install. + Install []Library + // The libraries to uninstall. + Uninstall []Library +} + +type librariesAPIUtilities interface { + UpdateAndWait(ctx context.Context, update Update, options ...retries.Option[ClusterLibraryStatuses]) error +} + +func (a *LibrariesAPI) UpdateAndWait(ctx context.Context, update Update, + options ...retries.Option[ClusterLibraryStatuses]) error { + ctx = useragent.InContext(ctx, "sdk-feature", "update-libraries") + if len(update.Uninstall) > 0 { + err := a.Uninstall(ctx, UninstallLibraries{ + ClusterId: update.ClusterId, + Libraries: update.Uninstall, + }) + if err != nil { + return fmt.Errorf("uninstall: %w", err) + } + } + if len(update.Install) > 0 { + err := a.Install(ctx, InstallLibraries{ + ClusterId: update.ClusterId, + Libraries: update.Install, + }) + if err != nil { + return fmt.Errorf("install: %w", err) + } + } + // this helps to avoid erroring out when out-of-list library gets added to + // the cluster manually and therefore fails the wait on error + scope := make([]Library, len(update.Install)+len(update.Uninstall)) + scope = append(scope, update.Install...) + scope = append(scope, update.Uninstall...) + _, err := a.Wait(ctx, Wait{ + ClusterID: update.ClusterId, + Libraries: scope, + IsRunning: true, + IsRefresh: false, + }, options...) 
+ return err +} + +// clusterID string, timeout time.Duration, isActive bool, refresh bool +func (a *LibrariesAPI) Wait(ctx context.Context, wait Wait, + options ...retries.Option[ClusterLibraryStatuses]) (*ClusterLibraryStatuses, error) { + ctx = useragent.InContext(ctx, "sdk-feature", "wait-for-libraries") + i := retries.Info[ClusterLibraryStatuses]{Timeout: 30 * time.Minute} + for _, o := range options { + o(&i) + } + result, err := retries.Poll(ctx, i.Timeout, func() (*ClusterLibraryStatuses, *retries.Err) { + status, err := a.ClusterStatusByClusterId(ctx, wait.ClusterID) + if apierr.IsMissing(err) { + // eventual consistency error + return nil, retries.Continue(err) + } + for _, o := range options { + o(&retries.Info[ClusterLibraryStatuses]{ + Timeout: i.Timeout, + Info: status, + }) + } + if err != nil { + return nil, retries.Halt(err) + } + if !wait.IsRunning { + log.InfoContext(ctx, "Cluster %s is currently not running, so just returning list of %d libraries", + wait.ClusterID, len(status.LibraryStatuses)) + return status, nil + } + retry, err := status.IsRetryNeeded(wait) + if retry { + return status, retries.Continue(err) + } + if err != nil { + return status, retries.Halt(err) + } + return status, nil + }) + if err != nil { + return nil, err + } + if wait.IsRunning { + installed := []LibraryFullStatus{} + cleanup := UninstallLibraries{ + ClusterId: wait.ClusterID, + Libraries: []Library{}, + } + // cleanup libraries that failed to install + for _, v := range result.LibraryStatuses { + if v.Status == "FAILED" { + log.WarningContext(ctx, "Removing failed library %s from %s", + v.Library, wait.ClusterID) + cleanup.Libraries = append(cleanup.Libraries, *v.Library) + continue + } + installed = append(installed, v) + } + // and result contains only the libraries that were successfully installed + result.LibraryStatuses = installed + if len(cleanup.Libraries) > 0 { + err = a.Uninstall(ctx, cleanup) + if err != nil { + return nil, fmt.Errorf("cannot cleanup 
libraries: %w", err) + } + } + } + return result, nil +} diff --git a/compute/v2/ext_node_type.go b/compute/v2/ext_node_type.go new file mode 100644 index 00000000..ca23cfc2 --- /dev/null +++ b/compute/v2/ext_node_type.go @@ -0,0 +1,129 @@ +package compute + +import ( + "context" + "fmt" + "strings" +) + +// NodeTypeRequest is a wrapper for local filtering of node types +type NodeTypeRequest struct { + Id string `json:"id,omitempty"` + MinMemoryGB int32 `json:"min_memory_gb,omitempty"` + GBPerCore int32 `json:"gb_per_core,omitempty"` + MinCores int32 `json:"min_cores,omitempty"` + MinGPUs int32 `json:"min_gpus,omitempty"` + LocalDisk bool `json:"local_disk,omitempty"` + LocalDiskMinSize int32 `json:"local_disk_min_size,omitempty"` + Category string `json:"category,omitempty"` + PhotonWorkerCapable bool `json:"photon_worker_capable,omitempty"` + PhotonDriverCapable bool `json:"photon_driver_capable,omitempty"` + Graviton bool `json:"graviton,omitempty"` + IsIOCacheEnabled bool `json:"is_io_cache_enabled,omitempty"` + SupportPortForwarding bool `json:"support_port_forwarding,omitempty"` + Fleet bool `json:"fleet,omitempty"` +} + +// sort NodeTypes within this struct +func (ntl *ListNodeTypesResponse) sort() { + sortByChain(ntl.NodeTypes, func(i int) sortCmp { + var localDisks, localDiskSizeGB, localNVMeDisk, localNVMeDiskSizeGB int32 + if ntl.NodeTypes[i].NodeInstanceType != nil { + localDisks = int32(ntl.NodeTypes[i].NodeInstanceType.LocalDisks) + localNVMeDisk = int32(ntl.NodeTypes[i].NodeInstanceType.LocalNvmeDisks) + localDiskSizeGB = int32(ntl.NodeTypes[i].NodeInstanceType.LocalDiskSizeGb) + localNVMeDiskSizeGB = int32(ntl.NodeTypes[i].NodeInstanceType.LocalNvmeDiskSizeGb) + } + return sortChain{ + boolAsc(ntl.NodeTypes[i].IsDeprecated), + intAsc(ntl.NodeTypes[i].NumCores), + intAsc(ntl.NodeTypes[i].MemoryMb), + intAsc(localDisks), + intAsc(localDiskSizeGB), + intAsc(localNVMeDisk), + intAsc(localNVMeDiskSizeGB), + intAsc(ntl.NodeTypes[i].NumGpus), + 
strAsc(ntl.NodeTypes[i].InstanceTypeId), + } + }) +} + +func (nt NodeType) shouldBeSkipped() bool { + if nt.NodeInfo == nil { + return false + } + for _, st := range nt.NodeInfo.Status { + switch st { + case CloudProviderNodeStatusNotAvailableInRegion, CloudProviderNodeStatusNotEnabledOnSubscription: + return true + } + } + return false +} + +func (ntl *ListNodeTypesResponse) Smallest(r NodeTypeRequest) (string, error) { + // error is explicitly ignored here, because Azure returns + // apparently too big of a JSON for Go to parse + if len(ntl.NodeTypes) == 0 { + return "", fmt.Errorf("cannot determine smallest node type with empty response") + } + ntl.sort() + for _, nt := range ntl.NodeTypes { + if nt.shouldBeSkipped() { + continue + } + gbs := int32(nt.MemoryMb / 1024) + if r.Fleet != strings.Contains(nt.NodeTypeId, "-fleet.") { + continue + } + if r.MinMemoryGB > 0 && gbs < r.MinMemoryGB { + continue + } + if r.GBPerCore > 0 && (gbs/int32(nt.NumCores)) < r.GBPerCore { + continue + } + if r.MinCores > 0 && int32(nt.NumCores) < r.MinCores { + continue + } + if (r.MinGPUs > 0 && int32(nt.NumGpus) < r.MinGPUs) || (r.MinGPUs == 0 && nt.NumGpus > 0) { + continue + } + if (r.LocalDisk || r.LocalDiskMinSize > 0) && nt.NodeInstanceType != nil && + (nt.NodeInstanceType.LocalDisks < 1 && + nt.NodeInstanceType.LocalNvmeDisks < 1) { + continue + } + if r.LocalDiskMinSize > 0 && nt.NodeInstanceType != nil && + (int32(nt.NodeInstanceType.LocalDiskSizeGb)+int32(nt.NodeInstanceType.LocalNvmeDiskSizeGb)) < r.LocalDiskMinSize { + continue + } + if r.Category != "" && !strings.EqualFold(nt.Category, r.Category) { + continue + } + if r.IsIOCacheEnabled && !nt.IsIoCacheEnabled { + continue + } + if r.SupportPortForwarding && !nt.SupportPortForwarding { + continue + } + if r.PhotonDriverCapable && !nt.PhotonDriverCapable { + continue + } + if r.PhotonWorkerCapable && !nt.PhotonWorkerCapable { + continue + } + if nt.IsGraviton != r.Graviton { + continue + } + return nt.NodeTypeId, nil 
+ } + return "", fmt.Errorf("cannot determine smallest node type") +} + +func (a *ClustersAPI) SelectNodeType(ctx context.Context, r NodeTypeRequest) (string, error) { + nodeTypes, err := a.ListNodeTypes(ctx) + if err != nil { + return "", err + } + return nodeTypes.Smallest(r) +} diff --git a/compute/v2/ext_node_type_test.go b/compute/v2/ext_node_type_test.go new file mode 100644 index 00000000..ca8c7131 --- /dev/null +++ b/compute/v2/ext_node_type_test.go @@ -0,0 +1,206 @@ +package compute + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNodeType(t *testing.T) { + lst := ListNodeTypesResponse{ + NodeTypes: []NodeType{ + { + NodeTypeId: "m-fleet.xlarge", + InstanceTypeId: "m-fleet.xlarge", + MemoryMb: 16384, + NumCores: 4, + }, + { + NodeTypeId: "Random_05", + MemoryMb: 1024, + NumCores: 32, + NodeInstanceType: &NodeInstanceType{ + LocalDisks: 3, + LocalDiskSizeGb: 100, + }, + }, + { + NodeTypeId: "Standard_L80s_v2", + MemoryMb: 655360, + NumCores: 80, + NodeInstanceType: &NodeInstanceType{ + LocalDisks: 2, + InstanceTypeId: "Standard_L80s_v2", + LocalDiskSizeGb: 160, + LocalNvmeDisks: 1, + }, + }, + { + NodeTypeId: "Random_01", + MemoryMb: 8192, + NumCores: 8, + NodeInstanceType: &NodeInstanceType{ + InstanceTypeId: "_", + }, + }, + { + NodeTypeId: "Random_02", + MemoryMb: 8192, + NumCores: 8, + NumGpus: 2, + NodeInstanceType: &NodeInstanceType{ + InstanceTypeId: "_", + }, + }, + { + NodeTypeId: "Random_03", + MemoryMb: 8192, + NumCores: 8, + NumGpus: 1, + NodeInstanceType: &NodeInstanceType{ + InstanceTypeId: "_", + LocalNvmeDisks: 15, + LocalNvmeDiskSizeGb: 235, + }, + }, + { + NodeTypeId: "Random_04", + MemoryMb: 32000, + NumCores: 32, + IsDeprecated: true, + NodeInstanceType: &NodeInstanceType{ + LocalDisks: 2, + LocalDiskSizeGb: 20, + }, + }, + { + NodeTypeId: "Standard_F4s", + MemoryMb: 8192, + NumCores: 4, + NodeInstanceType: &NodeInstanceType{ + LocalDisks: 1, + LocalDiskSizeGb: 16, + LocalNvmeDisks: 0, + }, + }, + }, + } + 
nt, err := lst.Smallest(NodeTypeRequest{LocalDiskMinSize: 200, MinMemoryGB: 8, MinCores: 8, MinGPUs: 1}) + assert.NoError(t, err) + assert.Equal(t, "Random_03", nt) +} + +func TestNodeTypeCategory(t *testing.T) { + lst := ListNodeTypesResponse{ + NodeTypes: []NodeType{ + { + NodeTypeId: "Random_05", + InstanceTypeId: "Random_05", + MemoryMb: 1024, + NumCores: 32, + NodeInstanceType: &NodeInstanceType{ + LocalDisks: 3, + LocalDiskSizeGb: 100, + }, + }, + { + NodeTypeId: "Random_01", + InstanceTypeId: "Random_01", + MemoryMb: 8192, + NumCores: 8, + NodeInstanceType: &NodeInstanceType{ + InstanceTypeId: "_", + }, + Category: "Memory Optimized", + }, + { + NodeTypeId: "Random_02", + InstanceTypeId: "Random_02", + MemoryMb: 8192, + NumCores: 8, + Category: "Storage Optimized", + }, + }, + } + nt, err := lst.Smallest(NodeTypeRequest{Category: "Storage optimized"}) + assert.NoError(t, err) + assert.Equal(t, "Random_02", nt) +} + +func TestNodeTypeCategoryNotAvailable(t *testing.T) { + lst := ListNodeTypesResponse{ + NodeTypes: []NodeType{ + { + NodeTypeId: "Random_05", + InstanceTypeId: "Random_05", + MemoryMb: 1024, + NumCores: 32, + NodeInstanceType: &NodeInstanceType{ + LocalDisks: 3, + LocalDiskSizeGb: 100, + }, + }, + { + NodeTypeId: "Random_01", + InstanceTypeId: "Random_01", + MemoryMb: 8192, + NumCores: 8, + NodeInstanceType: &NodeInstanceType{ + InstanceTypeId: "_", + }, + NodeInfo: &CloudProviderNodeInfo{ + Status: []CloudProviderNodeStatus{ + CloudProviderNodeStatusNotAvailableInRegion, + CloudProviderNodeStatusNotEnabledOnSubscription}, + }, + Category: "Storage Optimized", + }, + { + NodeTypeId: "Random_02", + InstanceTypeId: "Random_02", + MemoryMb: 8192, + NumCores: 8, + Category: "Storage Optimized", + }, + }, + } + nt, err := lst.Smallest(NodeTypeRequest{Category: "Storage optimized"}) + assert.NoError(t, err) + assert.Equal(t, "Random_02", nt) +} + +func TestNodeTypeFleet(t *testing.T) { + lst := ListNodeTypesResponse{ + NodeTypes: []NodeType{ + { + 
NodeTypeId: "Random_05", + InstanceTypeId: "Random_05", + MemoryMb: 1024, + NumCores: 4, + }, + { + NodeTypeId: "m-fleet.xlarge", + InstanceTypeId: "m-fleet.xlarge", + MemoryMb: 16384, + NumCores: 4, + }, + { + NodeTypeId: "m-fleet.2xlarge", + InstanceTypeId: "m-fleet.2xlarge", + MemoryMb: 32768, + NumCores: 8, + }, + }, + } + nt, err := lst.Smallest(NodeTypeRequest{Fleet: true, MinCores: 8}) + assert.NoError(t, err) + assert.Equal(t, "m-fleet.2xlarge", nt) +} + +func TestNodeTypeEmptyList(t *testing.T) { + lst := ListNodeTypesResponse{ + NodeTypes: []NodeType{}, + } + _, err := lst.Smallest(NodeTypeRequest{Fleet: true}) + assert.ErrorContains(t, err, "cannot determine smallest node type with empty response") +} diff --git a/compute/v2/ext_results.go b/compute/v2/ext_results.go new file mode 100644 index 00000000..f7ec9d50 --- /dev/null +++ b/compute/v2/ext_results.go @@ -0,0 +1,100 @@ +package compute + +import ( + "errors" + "html" + "regexp" + "strings" +) + +var ( + // IPython's output prefixes + outRE = regexp.MustCompile(`Out\[[\d\s]+\]:\s`) + // HTML tags + tagRE = regexp.MustCompile(`<[^>]*>`) + // just exception content without exception name + exceptionRE = regexp.MustCompile(`.*Exception:\s+(.*)`) + // execution errors resulting from http errors are sometimes hidden in these keys + executionErrorRE = regexp.MustCompile(`ExecutionError: ([\s\S]*)\n(StatusCode=[0-9]*)\n(StatusDescription=.*)\n`) + // usual error message explanation is hidden in this key + errorMessageRE = regexp.MustCompile(`ErrorMessage=(.+)\n`) +) + +// Failed tells if command execution failed +func (r *Results) Failed() bool { + return r.ResultType == "error" +} + +// Text returns plain text results +func (r *Results) Text() string { + if r.ResultType != "text" { + return "" + } + return outRE.ReplaceAllLiteralString(r.Data.(string), "") +} + +// Err returns error type +func (r *Results) Err() error { + if !r.Failed() { + return nil + } + return errors.New(r.Error()) +} + +// Error 
returns error in a bit more friendly way +func (r *Results) Error() string { + if r.ResultType != "error" { + return "" + } + summary := tagRE.ReplaceAllLiteralString(r.Summary, "") + summary = html.UnescapeString(summary) + + exceptionMatches := exceptionRE.FindStringSubmatch(summary) + if len(exceptionMatches) == 2 { + summary = strings.ReplaceAll(exceptionMatches[1], "; nested exception is:", "") + summary = strings.TrimRight(summary, " ") + return summary + } + + executionErrorMatches := executionErrorRE.FindStringSubmatch(r.Cause) + if len(executionErrorMatches) == 4 { + return strings.Join(executionErrorMatches[1:], "\n") + } + + errorMessageMatches := errorMessageRE.FindStringSubmatch(r.Cause) + if len(errorMessageMatches) == 2 { + return errorMessageMatches[1] + } + + return summary +} + +// Scan scans for results +// TODO: change API, also in terraform (databricks_sql_permissions) +// for now we're adding `pos` field artificially. this must be removed +// before this repo is public. 
+func (r *Results) Scan(dest ...any) bool { + if r.ResultType != ResultTypeTable { + return false + } + if rows, ok := r.Data.([]any); ok { + if r.Pos >= len(rows) { + return false + } + if cols, ok := rows[r.Pos].([]any); ok { + for i := range dest { + switch d := dest[i].(type) { + case *string: + *d = cols[i].(string) + case *int: + *d = cols[i].(int) + case *bool: + *d = cols[i].(bool) + } + } + r.Pos++ + return true + } + } + return false +} diff --git a/compute/v2/ext_results_test.go b/compute/v2/ext_results_test.go new file mode 100644 index 00000000..5f52536e --- /dev/null +++ b/compute/v2/ext_results_test.go @@ -0,0 +1,50 @@ +package compute + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestResults_Error(t *testing.T) { + cr := Results{} + assert.NoError(t, cr.Err()) + cr.ResultType = "error" + assert.EqualError(t, cr.Err(), "") + + cr.Summary = "NotFoundException: Things are going wrong; nested exception is: with something" + assert.Equal(t, "Things are going wrong with something", cr.Error()) + + cr.Summary = "" + cr.Cause = "ExecutionError: \nStatusCode=400\nStatusDescription=ABC\nSomething else" + assert.Equal(t, "\nStatusCode=400\nStatusDescription=ABC", cr.Error()) + + cr.Cause = "ErrorMessage=Error was here\n" + assert.Equal(t, "Error was here", cr.Error()) + + assert.False(t, cr.Scan()) +} + +func TestResults_Scan(t *testing.T) { + cr := Results{ + ResultType: "table", + Data: []interface{}{ + []interface{}{"foo", 1, true}, + []interface{}{"bar", 2, false}, + }, + } + a := "" + b := 0 + c := false + assert.True(t, cr.Scan(&a, &b, &c)) + assert.Equal(t, "foo", a) + assert.Equal(t, 1, b) + assert.Equal(t, true, c) + + assert.True(t, cr.Scan(&a, &b, &c)) + assert.Equal(t, "bar", a) + assert.Equal(t, 2, b) + assert.Equal(t, false, c) + + assert.False(t, cr.Scan(&a, &b, &c)) +} diff --git a/compute/v2/ext_sort.go b/compute/v2/ext_sort.go new file mode 100644 index 00000000..3f6ec86c --- /dev/null +++ 
b/compute/v2/ext_sort.go @@ -0,0 +1,49 @@ +package compute + +import ( + "sort" +) + +// readable chained sorting helper +func sortByChain(s interface{}, fn func(int) sortCmp) { + sort.Slice(s, func(i, j int) bool { + return fn(i).Less(fn(j)) + }) +} + +type sortCmp interface { + Less(o sortCmp) bool +} + +type boolAsc bool + +func (b boolAsc) Less(o sortCmp) bool { + return bool(b) != bool(o.(boolAsc)) && !bool(b) +} + +type intAsc int + +func (ia intAsc) Less(o sortCmp) bool { + return int(ia) < int(o.(intAsc)) +} + +type strAsc string + +func (s strAsc) Less(o sortCmp) bool { + return string(s) < string(o.(strAsc)) +} + +type sortChain []sortCmp + +func (c sortChain) Less(other sortCmp) bool { + o := other.(sortChain) + for i := range c { + if c[i].Less(o[i]) { + return true + } + if o[i].Less(c[i]) { + break + } + } + return false +} diff --git a/compute/v2/ext_spark_version.go b/compute/v2/ext_spark_version.go new file mode 100644 index 00000000..88985570 --- /dev/null +++ b/compute/v2/ext_spark_version.go @@ -0,0 +1,104 @@ +package compute + +import ( + "context" + "fmt" + "regexp" + "sort" + "strings" + + "golang.org/x/mod/semver" +) + +// SparkVersionRequest - filtering request +type SparkVersionRequest struct { + Id string `json:"id,omitempty"` + LongTermSupport bool `json:"long_term_support,omitempty" tf:"optional,default:false"` + Beta bool `json:"beta,omitempty" tf:"optional,default:false,conflicts:long_term_support"` + Latest bool `json:"latest,omitempty" tf:"optional,default:true"` + ML bool `json:"ml,omitempty" tf:"optional,default:false"` + Genomics bool `json:"genomics,omitempty" tf:"optional,default:false"` + GPU bool `json:"gpu,omitempty" tf:"optional,default:false"` + Scala string `json:"scala,omitempty" tf:"optional,default:2.12"` + SparkVersion string `json:"spark_version,omitempty" tf:"optional,default:"` + Photon bool `json:"photon,omitempty" tf:"optional,default:false"` + Graviton bool `json:"graviton,omitempty" tf:"optional,default:false"` 
+} + +type sparkVersionsType []string + +func (s sparkVersionsType) Len() int { + return len(s) +} +func (s sparkVersionsType) Swap(i, j int) { + s[i], s[j] = s[j], s[i] +} + +var dbrVersionRegex = regexp.MustCompile(`^(\d+\.\d+)\.x-.*`) + +func extractDbrVersions(s string) string { + m := dbrVersionRegex.FindStringSubmatch(s) + if len(m) > 1 { + return m[1] + } + return s +} + +func (s sparkVersionsType) Less(i, j int) bool { + return semver.Compare("v"+extractDbrVersions(s[i]), "v"+extractDbrVersions(s[j])) > 0 +} + +func (sv GetSparkVersionsResponse) Select(req SparkVersionRequest) (string, error) { + var versions []string + for _, version := range sv.Versions { + if strings.Contains(version.Key, "-scala"+req.Scala) { + matches := ((!strings.Contains(version.Key, "apache-spark-")) && + (strings.Contains(version.Key, "-ml-") == req.ML) && + (strings.Contains(version.Key, "-hls-") == req.Genomics) && + (strings.Contains(version.Key, "-gpu-") == req.GPU) && + (strings.Contains(version.Key, "-photon-") == req.Photon) && + (strings.Contains(version.Key, "-aarch64-") == req.Graviton) && + (strings.Contains(version.Name, "Beta") == req.Beta)) + if matches && req.LongTermSupport { + matches = (matches && (strings.Contains(version.Name, "LTS") || strings.Contains(version.Key, "-esr-"))) + } + if matches && len(req.SparkVersion) > 0 { + matches = (matches && strings.Contains(version.Name, "Apache Spark "+req.SparkVersion)) + } + if matches { + versions = append(versions, version.Key) + } + } + } + if len(versions) < 1 { + return "", fmt.Errorf("spark versions query returned no results. Please change your search criteria and try again") + } else if len(versions) > 1 { + if req.Latest { + sort.Sort(sparkVersionsType(versions)) + } else { + return "", fmt.Errorf("spark versions query returned multiple results %#v. 
Please change your search criteria and try again", versions) + } + } + return versions[0], nil +} + +// SelectSparkVersion returns latest DBR version matching the request parameters. +// If there are multiple versions matching the request, it will error (if latest = false) +// or return the latest version. +// Possible parameters are: +// - LongTermSupport: LTS versions only +// - Beta: Beta versions only +// - ML: ML versions only +// - Genomics: Genomics versions only +// - GPU: GPU versions only +// - Scala: Scala version +// - SparkVersion: Apache Spark version +// - Photon: Photon versions only (deprecated) +// - Graviton: Graviton versions only (deprecated) +func (a *ClustersAPI) SelectSparkVersion(ctx context.Context, r SparkVersionRequest) (string, error) { + sv, err := a.SparkVersions(ctx) + if err != nil { + return "", err + } + return sv.Select(r) +} diff --git a/compute/v2/ext_utilities.go b/compute/v2/ext_utilities.go new file mode 100644 index 00000000..f72f079b --- /dev/null +++ b/compute/v2/ext_utilities.go @@ -0,0 +1,32 @@ +// TODO : Add the missing methods and implement the methods +// This file has not been completely shifted from the SDK-Beta +// as we still don't have the wait for state methods in the SDK-mod +package compute + +import ( + "strings" + "sync" +) + +type clustersAPIUtilities interface { +} + +// getOrCreateClusterMutex guards "mounting" cluster creation to prevent multiple +// redundant instances created at the same name. Compute package private property. 
+// https://github.com/databricks/terraform-provider-databricks/issues/445 +var getOrCreateClusterMutex sync.Mutex + +func (c *ClusterDetails) IsRunningOrResizing() bool { + return c.State == StateRunning || c.State == StateResizing +} + +// use mutex around starting a cluster by ID +var mu sync.Mutex + +func (a *ClustersAPI) isErrFailedToReach(err error) bool { + if err == nil { + return false + } + // TODO: get a bit better handling of these + return strings.HasPrefix(err.Error(), "failed to reach") +}