diff --git a/cmd/log.go b/cmd/log.go new file mode 100644 index 00000000..48dc7b13 --- /dev/null +++ b/cmd/log.go @@ -0,0 +1,62 @@ +package cmd + +import ( + "fmt" + "sort" + "strings" + + "github.com/ekristen/aws-nuke/v3/resources" + "github.com/fatih/color" +) + +var ( + ReasonSkip = *color.New(color.FgYellow) + ReasonError = *color.New(color.FgRed) + ReasonRemoveTriggered = *color.New(color.FgGreen) + ReasonWaitPending = *color.New(color.FgBlue) + ReasonSuccess = *color.New(color.FgGreen) +) + +var ( + ColorRegion = *color.New(color.Bold) + ColorResourceType = *color.New() + ColorResourceID = *color.New(color.Bold) + ColorResourceProperties = *color.New(color.Italic) +) + +// Format the resource properties in sorted order ready for printing. +// This ensures that multiple runs of aws-nuke produce stable output so +// that they can be compared with each other. +func Sorted(m map[string]string) string { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + sort.Strings(keys) + sorted := make([]string, 0, len(m)) + for k := range keys { + sorted = append(sorted, fmt.Sprintf("%s: \"%s\"", keys[k], m[keys[k]])) + } + return fmt.Sprintf("[%s]", strings.Join(sorted, ", ")) +} + +func Log(region *Region, resourceType string, r resources.Resource, c color.Color, msg string) { + ColorRegion.Printf("%s", region.Name) + fmt.Printf(" - ") + ColorResourceType.Print(resourceType) + fmt.Printf(" - ") + + rString, ok := r.(resources.LegacyStringer) + if ok { + ColorResourceID.Print(rString.String()) + fmt.Printf(" - ") + } + + rProp, ok := r.(resources.ResourcePropertyGetter) + if ok { + ColorResourceProperties.Print(Sorted(rProp.Properties())) + fmt.Printf(" - ") + } + + c.Printf("%s\n", msg) +} diff --git a/cmd/nuke.go b/cmd/nuke.go new file mode 100644 index 00000000..6ed507fe --- /dev/null +++ b/cmd/nuke.go @@ -0,0 +1,318 @@ +package cmd + +import ( + "fmt" + "time" + + "github.com/ekristen/aws-nuke/v3/pkg/awsutil" + config "github.com/ekristen/aws-nuke/v3/pkg/config" + "github.com/ekristen/aws-nuke/v3/pkg/types" + "github.com/ekristen/aws-nuke/v3/resources" + "github.com/sirupsen/logrus" +) + +type Nuke struct { + Parameters NukeParameters + Account awsutil.Account + Config *config.Nuke + + ResourceTypes types.Collection + + items Queue +} + +func NewNuke(params NukeParameters, account awsutil.Account) *Nuke { + n := Nuke{ + Parameters: params, + Account: account, + } + + return &n +} + +func (n *Nuke) Run() error { + var err error + + if n.Parameters.ForceSleep < 3 && n.Parameters.NoDryRun { + return fmt.Errorf("Value for --force-sleep cannot be less than 3 seconds if --no-dry-run is set. This is for your own protection.") + } + forceSleep := time.Duration(n.Parameters.ForceSleep) * time.Second + + fmt.Printf("aws-nuke version %s - %s - %s\n\n", BuildVersion, BuildDate, BuildHash) + + err = n.Config.ValidateAccount(n.Account.ID(), n.Account.Aliases()) + if err != nil { + return err + } + + fmt.Printf("Do you really want to nuke the account with "+ + "the ID %s and the alias '%s'?\n", n.Account.ID(), n.Account.Alias()) + if n.Parameters.Force { + fmt.Printf("Waiting %v before continuing.\n", forceSleep) + time.Sleep(forceSleep) + } else { + fmt.Printf("Do you want to continue? Enter account alias to continue.\n") + err = Prompt(n.Account.Alias()) + if err != nil { + return err + } + } + + err = n.Scan() + if err != nil { + return err + } + + if n.items.Count(ItemStateNew) == 0 { + fmt.Println("No resource to delete.") + return nil + } + + if !n.Parameters.NoDryRun { + fmt.Println("The above resources would be deleted with the supplied configuration. Provide --no-dry-run to actually destroy resources.") + return nil + } + + fmt.Printf("Do you really want to nuke these resources on the account with "+ + "the ID %s and the alias '%s'?\n", n.Account.ID(), n.Account.Alias()) + if n.Parameters.Force { + fmt.Printf("Waiting %v before continuing.\n", forceSleep) + time.Sleep(forceSleep) + } else { + fmt.Printf("Do you want to continue? Enter account alias to continue.\n") + err = Prompt(n.Account.Alias()) + if err != nil { + return err + } + } + + failCount := 0 + waitingCount := 0 + + for { + n.HandleQueue() + + if n.items.Count(ItemStatePending, ItemStateWaiting, ItemStateNew) == 0 && n.items.Count(ItemStateFailed) > 0 { + if failCount >= 2 { + logrus.Errorf("There are resources in failed state, but none are ready for deletion, anymore.") + fmt.Println() + + for _, item := range n.items { + if item.State != ItemStateFailed { + continue + } + + item.Print() + logrus.Error(item.Reason) + } + + return fmt.Errorf("failed") + } + + failCount = failCount + 1 + } else { + failCount = 0 + } + if n.Parameters.MaxWaitRetries != 0 && n.items.Count(ItemStateWaiting, ItemStatePending) > 0 && n.items.Count(ItemStateNew) == 0 { + if waitingCount >= n.Parameters.MaxWaitRetries { + return fmt.Errorf("Max wait retries of %d exceeded.\n\n", n.Parameters.MaxWaitRetries) + } + waitingCount = waitingCount + 1 + } else { + waitingCount = 0 + } + if n.items.Count(ItemStateNew, ItemStatePending, ItemStateFailed, ItemStateWaiting) == 0 { + break + } + + time.Sleep(5 * time.Second) + } + + fmt.Printf("Nuke complete: %d failed, %d skipped, %d finished.\n\n", + n.items.Count(ItemStateFailed), n.items.Count(ItemStateFiltered), n.items.Count(ItemStateFinished)) + + return nil +} + +func (n *Nuke) Scan() error { + accountConfig := n.Config.Accounts[n.Account.ID()] + + resourceTypes := ResolveResourceTypes( + resources.GetListerNames(), + resources.GetCloudControlMapping(), + []types.Collection{ + n.Parameters.Targets, + n.Config.ResourceTypes.Targets, + accountConfig.ResourceTypes.Targets, + }, + []types.Collection{ + n.Parameters.Excludes, + n.Config.ResourceTypes.Excludes, + accountConfig.ResourceTypes.Excludes, + }, + []types.Collection{ + n.Parameters.CloudControl, + n.Config.ResourceTypes.CloudControl, + accountConfig.ResourceTypes.CloudControl, + }, + ) + + queue := make(Queue, 0) + + for _, regionName := range n.Config.Regions { + region := NewRegion(regionName, n.Account.ResourceTypeToServiceType, n.Account.NewSession) + + items := Scan(region, resourceTypes) + for item := range items { + ffGetter, ok := item.Resource.(resources.FeatureFlagGetter) + if ok { + ffGetter.FeatureFlags(n.Config.FeatureFlags) + } + + queue = append(queue, item) + err := n.Filter(item) + if err != nil { + return err + } + + if item.State != ItemStateFiltered || !n.Parameters.Quiet { + item.Print() + } + } + } + + fmt.Printf("Scan complete: %d total, %d nukeable, %d filtered.\n\n", + queue.CountTotal(), queue.Count(ItemStateNew), queue.Count(ItemStateFiltered)) + + n.items = queue + + return nil +} + +func (n *Nuke) Filter(item *Item) error { + + checker, ok := item.Resource.(resources.Filter) + if ok { + err := checker.Filter() + if err != nil { + item.State = ItemStateFiltered + item.Reason = err.Error() + + // Not returning the error, since it could be because of a failed + // request to the API. We do not want to block the whole nuking, + // because of an issue on AWS side. + return nil + } + } + + accountFilters, err := n.Config.Filters(n.Account.ID()) + if err != nil { + return err + } + + itemFilters, ok := accountFilters[item.Type] + if !ok { + return nil + } + + for _, filter := range itemFilters { + prop, err := item.GetProperty(filter.Property) + if err != nil { + logrus.Warnf(err.Error()) + continue + } + match, err := filter.Match(prop) + if err != nil { + return err + } + + if IsTrue(filter.Invert) { + match = !match + } + + if match { + item.State = ItemStateFiltered + item.Reason = "filtered by config" + return nil + } + } + + return nil +} + +func (n *Nuke) HandleQueue() { + listCache := make(map[string]map[string][]resources.Resource) + + for _, item := range n.items { + switch item.State { + case ItemStateNew: + n.HandleRemove(item) + item.Print() + case ItemStateFailed: + n.HandleRemove(item) + n.HandleWait(item, listCache) + item.Print() + case ItemStatePending: + n.HandleWait(item, listCache) + item.State = ItemStateWaiting + item.Print() + case ItemStateWaiting: + n.HandleWait(item, listCache) + item.Print() + } + + } + + fmt.Println() + fmt.Printf("Removal requested: %d waiting, %d failed, %d skipped, %d finished\n\n", + n.items.Count(ItemStateWaiting, ItemStatePending), n.items.Count(ItemStateFailed), + n.items.Count(ItemStateFiltered), n.items.Count(ItemStateFinished)) +} + +func (n *Nuke) HandleRemove(item *Item) { + err := item.Resource.Remove() + if err != nil { + item.State = ItemStateFailed + item.Reason = err.Error() + return + } + + item.State = ItemStatePending + item.Reason = "" +} + +func (n *Nuke) HandleWait(item *Item, cache map[string]map[string][]resources.Resource) { + var err error + region := item.Region.Name + _, ok := cache[region] + if !ok { + cache[region] = map[string][]resources.Resource{} + } + left, ok := cache[region][item.Type] + if !ok { + left, err = item.List() + if err != nil { + item.State = ItemStateFailed + item.Reason = err.Error() + return + } + cache[region][item.Type] = left + } + + for _, r := range left { + if item.Equals(r) { + checker, ok := r.(resources.Filter) + if ok { + err := checker.Filter() + if err != nil { + break + } + } + + return + } + } + + item.State = ItemStateFinished + item.Reason = "" +} diff --git a/cmd/params.go b/cmd/params.go new file mode 100644 index 00000000..b1583aed --- /dev/null +++ b/cmd/params.go @@ -0,0 +1,29 @@ +package cmd + +import ( + "fmt" + "strings" +) + +type NukeParameters struct { + ConfigPath string + + Targets []string + Excludes []string + CloudControl []string + + NoDryRun bool + Force bool + ForceSleep int + Quiet bool + + MaxWaitRetries int +} + +func (p *NukeParameters) Validate() error { + if strings.TrimSpace(p.ConfigPath) == "" { + return fmt.Errorf("You have to specify the --config flag.\n") + } + + return nil +} diff --git a/cmd/queue.go b/cmd/queue.go new file mode 100644 index 00000000..962a1600 --- /dev/null +++ b/cmd/queue.go @@ -0,0 +1,122 @@ +package cmd + +import ( + "fmt" + + "github.com/ekristen/aws-nuke/v3/resources" +) + +type ItemState int + +// States of Items based on the latest request to AWS. +const ( + ItemStateNew ItemState = iota + ItemStatePending + ItemStateWaiting + ItemStateFailed + ItemStateFiltered + ItemStateFinished +) + +// An Item describes an actual AWS resource entity with the current state and +// some metadata. +type Item struct { + Resource resources.Resource + + State ItemState + Reason string + + Region *Region + Type string +} + +func (i *Item) Print() { + switch i.State { + case ItemStateNew: + Log(i.Region, i.Type, i.Resource, ReasonWaitPending, "would remove") + case ItemStatePending: + Log(i.Region, i.Type, i.Resource, ReasonWaitPending, "triggered remove") + case ItemStateWaiting: + Log(i.Region, i.Type, i.Resource, ReasonWaitPending, "waiting") + case ItemStateFailed: + Log(i.Region, i.Type, i.Resource, ReasonError, "failed") + case ItemStateFiltered: + Log(i.Region, i.Type, i.Resource, ReasonSkip, i.Reason) + case ItemStateFinished: + Log(i.Region, i.Type, i.Resource, ReasonSuccess, "removed") + } +} + +// List gets all resource items of the same resource type like the Item. +func (i *Item) List() ([]resources.Resource, error) { + lister := resources.GetLister(i.Type) + sess, err := i.Region.Session(i.Type) + if err != nil { + return nil, err + } + return lister(sess) +} + +func (i *Item) GetProperty(key string) (string, error) { + if key == "" { + stringer, ok := i.Resource.(resources.LegacyStringer) + if !ok { + return "", fmt.Errorf("%T does not support legacy IDs", i.Resource) + } + return stringer.String(), nil + } + + getter, ok := i.Resource.(resources.ResourcePropertyGetter) + if !ok { + return "", fmt.Errorf("%T does not support custom properties", i.Resource) + } + + return getter.Properties().Get(key), nil +} + +func (i *Item) Equals(o resources.Resource) bool { + iType := fmt.Sprintf("%T", i.Resource) + oType := fmt.Sprintf("%T", o) + if iType != oType { + return false + } + + iStringer, iOK := i.Resource.(resources.LegacyStringer) + oStringer, oOK := o.(resources.LegacyStringer) + if iOK != oOK { + return false + } + if iOK && oOK { + return iStringer.String() == oStringer.String() + } + + iGetter, iOK := i.Resource.(resources.ResourcePropertyGetter) + oGetter, oOK := o.(resources.ResourcePropertyGetter) + if iOK != oOK { + return false + } + if iOK && oOK { + return iGetter.Properties().Equals(oGetter.Properties()) + } + + return false +} + +type Queue []*Item + +func (q Queue) CountTotal() int { + return len(q) +} + +func (q Queue) Count(states ...ItemState) int { + count := 0 + for _, item := range q { + for _, state := range states { + if item.State == state { + count = count + 1 + break + } + } + } + return count +} diff --git a/cmd/region.go b/cmd/region.go new file mode 100644 index 00000000..f53d5cba --- /dev/null +++ b/cmd/region.go @@ -0,0 +1,62 @@ +package cmd + +import ( + "fmt" + "sync" + + "github.com/aws/aws-sdk-go/aws/session" + "github.com/ekristen/aws-nuke/v3/pkg/awsutil" +) + +// SessionFactory support for custom endpoints +type SessionFactory func(regionName, svcType string) (*session.Session, error) + +// ResourceTypeResolver returns the service type from the resourceType +type ResourceTypeResolver func(regionName, resourceType string) string + +type Region struct { + Name string + NewSession SessionFactory + ResTypeResolver ResourceTypeResolver + + cache map[string]*session.Session + lock *sync.RWMutex +} + +func NewRegion(name string, typeResolver ResourceTypeResolver, sessionFactory SessionFactory) *Region { + return &Region{ + Name: name, + NewSession: sessionFactory, + ResTypeResolver: typeResolver, + lock: &sync.RWMutex{}, + cache: make(map[string]*session.Session), + } +} + +func (region *Region) Session(resourceType string) (*session.Session, error) { + svcType := region.ResTypeResolver(region.Name, resourceType) + if svcType == "" { + return nil, awsutil.ErrSkipRequest(fmt.Sprintf( + "No service available in region '%s' to handle '%s'", + region.Name, resourceType)) + } + + // Need to read + region.lock.RLock() + sess := region.cache[svcType] + region.lock.RUnlock() + if sess != nil { + return sess, nil + } + + // Need to write: + region.lock.Lock() + sess, err := region.NewSession(region.Name, svcType) + if err != nil { + region.lock.Unlock() + return nil, err + } + region.cache[svcType] = sess + region.lock.Unlock() + return sess, nil +} diff --git a/cmd/root.go b/cmd/root.go new file mode 100644 index 00000000..ad3c2d4b --- /dev/null +++ b/cmd/root.go @@ -0,0 +1,186 @@ +package cmd + +import ( + "fmt" + "os" + "sort" + + "github.com/aws/aws-sdk-go/aws/endpoints" + "github.com/ekristen/aws-nuke/v3/pkg/awsutil" + "github.com/ekristen/aws-nuke/v3/pkg/config" + "github.com/ekristen/aws-nuke/v3/resources" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" +) + +func NewRootCommand() *cobra.Command { + var ( + params NukeParameters + creds awsutil.Credentials + defaultRegion string + verbose bool + ) + + command := &cobra.Command{ + Use: "aws-nuke", + Short: "aws-nuke removes every resource from AWS", + Long: `A tool which removes every resource from an AWS account. Use it with caution, since it cannot distinguish between production and non-production.`, + } + + command.PreRun = func(cmd *cobra.Command, args []string) { + log.SetLevel(log.InfoLevel) + if verbose { + log.SetLevel(log.DebugLevel) + } + log.SetFormatter(&log.TextFormatter{ + EnvironmentOverrideColors: true, + }) + } + + command.RunE = func(cmd *cobra.Command, args []string) error { + var err error + + err = params.Validate() + if err != nil { + return err + } + + if !creds.HasKeys() && !creds.HasProfile() && defaultRegion != "" { + creds.AccessKeyID = os.Getenv("AWS_ACCESS_KEY_ID") + creds.SecretAccessKey = os.Getenv("AWS_SECRET_ACCESS_KEY") + } + err = creds.Validate() + if err != nil { + return err + } + + command.SilenceUsage = true + + config, err := config.Load(params.ConfigPath) + if err != nil { + log.Errorf("Failed to parse config file %s", params.ConfigPath) + return err + } + + if defaultRegion != "" { + awsutil.DefaultRegionID = defaultRegion + switch defaultRegion { + case endpoints.UsEast1RegionID, endpoints.UsEast2RegionID, endpoints.UsWest1RegionID, endpoints.UsWest2RegionID: + awsutil.DefaultAWSPartitionID = endpoints.AwsPartitionID + case endpoints.UsGovEast1RegionID, endpoints.UsGovWest1RegionID: + awsutil.DefaultAWSPartitionID = endpoints.AwsUsGovPartitionID + case endpoints.CnNorth1RegionID, endpoints.CnNorthwest1RegionID: + awsutil.DefaultAWSPartitionID = endpoints.AwsCnPartitionID + default: + if config.CustomEndpoints.GetRegion(defaultRegion) == nil { + err = fmt.Errorf("The custom region '%s' must be specified in the configuration 'endpoints'", defaultRegion) + log.Error(err.Error()) + return err + } + } + } + + account, err := awsutil.NewAccount(&creds, config.CustomEndpoints) + if err != nil { + return err + } + + n := NewNuke(params, *account) + + n.Config = config + + return n.Run() + } + + command.PersistentFlags().BoolVarP( + &verbose, "verbose", "v", false, + "Enables debug output.") + + command.PersistentFlags().StringVarP( + ¶ms.ConfigPath, "config", "c", "", + "(required) Path to the nuke config file.") + + command.PersistentFlags().StringVar( + &creds.Profile, "profile", "", + "Name of the AWS profile name for accessing the AWS API. "+ + "Cannot be used together with --access-key-id and --secret-access-key.") + command.PersistentFlags().StringVar( + &creds.AccessKeyID, "access-key-id", "", + "AWS access key ID for accessing the AWS API. "+ + "Must be used together with --secret-access-key. "+ + "Cannot be used together with --profile.") + command.PersistentFlags().StringVar( + &creds.SecretAccessKey, "secret-access-key", "", + "AWS secret access key for accessing the AWS API. "+ + "Must be used together with --access-key-id. "+ + "Cannot be used together with --profile.") + command.PersistentFlags().StringVar( + &creds.SessionToken, "session-token", "", + "AWS session token for accessing the AWS API. "+ + "Must be used together with --access-key-id and --secret-access-key. "+ + "Cannot be used together with --profile.") + command.PersistentFlags().StringVar( + &creds.AssumeRoleArn, "assume-role-arn", "", + "AWS IAM role arn to assume. "+ + "The credentials provided via --access-key-id or --profile must "+ + "be allowed to assume this role. ") + command.PersistentFlags().StringVar( + &defaultRegion, "default-region", "", + "Custom default region name.") + + command.PersistentFlags().StringSliceVarP( + ¶ms.Targets, "target", "t", []string{}, + "Limit nuking to certain resource types (eg IAMServerCertificate). "+ + "This flag can be used multiple times.") + command.PersistentFlags().StringSliceVarP( + ¶ms.Excludes, "exclude", "e", []string{}, + "Prevent nuking of certain resource types (eg IAMServerCertificate). "+ + "This flag can be used multiple times.") + command.PersistentFlags().StringSliceVar( + ¶ms.CloudControl, "cloud-control", []string{}, + "Nuke given resource via Cloud Control API. "+ + "If there is an old-style method for the same resource, the old-style one will not be executed. "+ + "Note that old-style and cloud-control filters are not compatible! "+ + "This flag can be used multiple times.") + command.PersistentFlags().BoolVar( + ¶ms.NoDryRun, "no-dry-run", false, + "If specified, it actually deletes found resources. "+ + "Otherwise it just lists all candidates.") + command.PersistentFlags().BoolVar( + ¶ms.Force, "force", false, + "Don't ask for confirmation before deleting resources. "+ + "Instead it waits 15s before continuing. Set --force-sleep to change the wait time.") + command.PersistentFlags().IntVar( + ¶ms.ForceSleep, "force-sleep", 15, + "If specified and --force is set, wait this many seconds before deleting resources. "+ + "Defaults to 15.") + command.PersistentFlags().IntVar( + ¶ms.MaxWaitRetries, "max-wait-retries", 0, + "If specified, the program will exit if resources are stuck in waiting for this many iterations. "+ + "0 (default) disables early exit.") + command.PersistentFlags().BoolVarP( + ¶ms.Quiet, "quiet", "q", false, + "Don't show filtered resources.") + + command.AddCommand(NewVersionCommand()) + command.AddCommand(NewResourceTypesCommand()) + + return command +} + +func NewResourceTypesCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "resource-types", + Short: "lists all available resource types", + Run: func(cmd *cobra.Command, args []string) { + names := resources.GetListerNames() + sort.Strings(names) + + for _, resourceType := range names { + fmt.Println(resourceType) + } + }, + } + + return cmd +} diff --git a/cmd/scan.go b/cmd/scan.go new file mode 100644 index 00000000..e0138713 --- /dev/null +++ b/cmd/scan.go @@ -0,0 +1,88 @@ +package cmd + +import ( + "context" + "fmt" + "runtime/debug" + + "github.com/ekristen/aws-nuke/v3/pkg/awsutil" + "github.com/ekristen/aws-nuke/v3/pkg/util" + "github.com/ekristen/aws-nuke/v3/resources" + log "github.com/sirupsen/logrus" + "golang.org/x/sync/semaphore" +) + +const ScannerParallelQueries = 16 + +func Scan(region *Region, resourceTypes []string) <-chan *Item { + s := &scanner{ + items: make(chan *Item, 100), + semaphore: semaphore.NewWeighted(ScannerParallelQueries), + } + go s.run(region, resourceTypes) + + return s.items +} + +type scanner struct { + items chan *Item + semaphore *semaphore.Weighted +} + +func (s *scanner) run(region *Region, resourceTypes []string) { + ctx := context.Background() + + for _, resourceType := range resourceTypes { + s.semaphore.Acquire(ctx, 1) + go s.list(region, resourceType) + } + + // Wait for all routines to finish. + s.semaphore.Acquire(ctx, ScannerParallelQueries) + + close(s.items) +} + +func (s *scanner) list(region *Region, resourceType string) { + defer func() { + if r := recover(); r != nil { + err := fmt.Errorf("%v\n\n%s", r.(error), string(debug.Stack())) + dump := util.Indent(fmt.Sprintf("%v", err), " ") + log.Errorf("Listing %s failed:\n%s", resourceType, dump) + } + }() + defer s.semaphore.Release(1) + + lister := resources.GetLister(resourceType) + var rs []resources.Resource + sess, err := region.Session(resourceType) + if err == nil { + rs, err = lister(sess) + } + if err != nil { + _, ok := err.(awsutil.ErrSkipRequest) + if ok { + log.Debugf("skipping request: %v", err) + return + } + + _, ok = err.(awsutil.ErrUnknownEndpoint) + if ok { + log.Warnf("skipping request: %v", err) + return + } + + dump := util.Indent(fmt.Sprintf("%v", err), " ") + log.Errorf("Listing %s failed:\n%s", resourceType, dump) + return + } + + for _, r := range rs { + s.items <- &Item{ + Region: region, + Resource: r, + State: ItemStateNew, + Type: resourceType, + } + } +} diff --git a/cmd/util.go b/cmd/util.go new file mode 100644 index 00000000..f1d7a73f --- /dev/null +++ b/cmd/util.go @@ -0,0 +1,60 @@ +package cmd + +import ( + "bufio" + "fmt" + "os" + "strings" + + "github.com/ekristen/aws-nuke/v3/pkg/types" +) + +func Prompt(expect string) error { + fmt.Print("> ") + reader := bufio.NewReader(os.Stdin) + text, err := reader.ReadString('\n') + if err != nil { + return err + } + + if strings.TrimSpace(text) != expect { + return fmt.Errorf("aborted") + } + fmt.Println() + + return nil +} + +func ResolveResourceTypes( + base types.Collection, mapping map[string]string, + include, exclude, cloudControl []types.Collection) types.Collection { + + for _, cl := range cloudControl { + oldStyle := types.Collection{} + for _, c := range cl { + os, found := mapping[c] + if found { + oldStyle = append(oldStyle, os) + } + } + + base = base.Union(cl) + base = base.Remove(oldStyle) + } + + for _, i := range include { + if len(i) > 0 { + base = base.Intersect(i) + } + } + + for _, e := range exclude { + base = base.Remove(e) + } + + return base +} + +func IsTrue(s string) bool { + return strings.TrimSpace(strings.ToLower(s)) == "true" +} diff --git a/cmd/util_test.go b/cmd/util_test.go new file mode 100644 index 00000000..a279d8e4 --- /dev/null +++ b/cmd/util_test.go @@ -0,0 +1,95 @@ +package cmd + +import ( + "fmt" + "sort" + "testing" + + "github.com/ekristen/aws-nuke/v3/pkg/types" +) + +func TestResolveResourceTypes(t *testing.T) { + cases := []struct { + name string + base types.Collection + mapping map[string]string + include []types.Collection + exclude []types.Collection + cloudControl []types.Collection + result types.Collection + }{ + { + base: types.Collection{"a", "b", "c", "d"}, + include: []types.Collection{{"a", "b", "c"}}, + result: types.Collection{"a", "b", "c"}, + }, + { + base: types.Collection{"a", "b", "c", "d"}, + exclude: []types.Collection{{"b", "d"}}, + result: types.Collection{"a", "c"}, + }, + { + base: types.Collection{"a", "b"}, + include: []types.Collection{{}}, + result: types.Collection{"a", "b"}, + }, + { + base: types.Collection{"c", "b"}, + exclude: []types.Collection{{}}, + result: types.Collection{"c", "b"}, + }, + { + base: types.Collection{"a", "b", "c", "d"}, + include: []types.Collection{{"a", "b", "c"}}, + exclude: []types.Collection{{"a"}}, + result: types.Collection{"b", "c"}, + }, + { + name: "CloudControlAdd", + base: types.Collection{"a", "b"}, + cloudControl: []types.Collection{{"x"}}, + result: types.Collection{"a", "b", "x"}, + }, + { + name: "CloudControlReplaceOldStyle", + base: types.Collection{"a", "b", "c"}, + mapping: map[string]string{"z": "b"}, + cloudControl: []types.Collection{{"z"}}, + result: types.Collection{"a", "z", "c"}, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + r := ResolveResourceTypes(tc.base, tc.mapping, tc.include, tc.exclude, tc.cloudControl) + + sort.Strings(r) + sort.Strings(tc.result) + + var ( + want = fmt.Sprint(tc.result) + have = fmt.Sprint(r) + ) + + if want != have { + t.Fatalf("Wrong result. Want: %s. Have: %s", want, have) + } + }) + } +} + +func TestIsTrue(t *testing.T) { + falseStrings := []string{"", "false", "treu", "foo"} + for _, fs := range falseStrings { + if IsTrue(fs) { + t.Fatalf("IsTrue falsely returned 'true' for: %s", fs) + } + } + + trueStrings := []string{"true", " true", "true ", " TrUe "} + for _, ts := range trueStrings { + if !IsTrue(ts) { + t.Fatalf("IsTrue falsely returned 'false' for: %s", ts) + } + } +} diff --git a/cmd/version.go b/cmd/version.go new file mode 100644 index 00000000..a4479752 --- /dev/null +++ b/cmd/version.go @@ -0,0 +1,35 @@ +package cmd + +import ( + "fmt" + "runtime/debug" + + "github.com/spf13/cobra" +) + +var ( + BuildVersion = "unknown" + BuildDate = "unknown" + BuildHash = "unknown" + BuildEnvironment = "unknown" +) + +func NewVersionCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "version", + Short: "shows version of this application", + Run: func(cmd *cobra.Command, args []string) { + fmt.Printf("version: %s\n", BuildVersion) + fmt.Printf("build date: %s\n", BuildDate) + fmt.Printf("scm hash: %s\n", BuildHash) + fmt.Printf("environment: %s\n", BuildEnvironment) + + bi, ok := debug.ReadBuildInfo() + if ok && bi != nil { + fmt.Printf("go version: %s\n", bi.GoVersion) + } + }, + } + + return cmd +} diff --git a/go.mod b/go.mod index cd60d0de..a46f4474 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,7 @@ require ( github.com/gotidy/ptr v1.4.0 github.com/pkg/errors v0.9.1 github.com/sirupsen/logrus v1.9.3 + github.com/spf13/cobra v1.8.1 github.com/stretchr/testify v1.10.0 github.com/urfave/cli/v2 v2.27.5 go.uber.org/ratelimit v0.3.1 @@ -40,6 +41,7 @@ require ( github.com/benbjohnson/clock v1.3.0 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.5 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/kr/pretty v0.3.1 // indirect github.com/mattn/go-colorable v0.1.13 // indirect @@ -48,6 +50,7 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.11.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect github.com/stevenle/topsort v0.2.0 // indirect github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 // indirect golang.org/x/mod v0.17.0 // indirect diff --git a/go.sum b/go.sum index 76269efa..9d5b27f4 100644 --- a/go.sum +++ b/go.sum @@ -136,6 +136,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gotidy/ptr v1.4.0 h1:7++suUs+HNHMnyz6/AW3SE+4EnBhupPSQTSI7QNijVc= github.com/gotidy/ptr v1.4.0/go.mod h1:MjRBG6/IETiiZGWI8LrRtISXEji+8b/jigmj2q0mEyM= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= @@ -166,6 +168,10 @@ github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= +github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stevenle/topsort v0.2.0 h1:LLWgtp34HPX6/RBDRS0kElVxGOTzGBLI1lSAa5Lb46k= github.com/stevenle/topsort v0.2.0/go.mod h1:ck2WG2/ZrOr6dLApQ/5Xrqy5wv3T0qhKYWE7r9tkibc= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/pkg/awsutil/errors.go b/pkg/awsutil/errors.go index d6d91de6..6059405c 100644 --- a/pkg/awsutil/errors.go +++ b/pkg/awsutil/errors.go @@ -2,3 +2,15 @@ package awsutil const ErrCodeInvalidAction = "InvalidAction" const ErrCodeOperationNotPermitted = "OperationNotPermitted" + +type ErrSkipRequest string + +func (err ErrSkipRequest) Error() string { + return string(err) +} + +type ErrUnknownEndpoint string + +func (err ErrUnknownEndpoint) Error() string { + return string(err) +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 88d1c9cb..3ced373d 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -1,16 +1,72 @@ package config import ( + "bytes" "fmt" + "io/ioutil" "os" "strings" "gopkg.in/yaml.v3" + "github.com/ekristen/aws-nuke/v3/pkg/types" + log "github.com/sirupsen/logrus" + "github.com/ekristen/libnuke/pkg/config" "github.com/ekristen/libnuke/pkg/settings" ) +type ResourceTypes struct { + Targets types.Collection `yaml:"targets"` + Excludes types.Collection `yaml:"excludes"` + CloudControl types.Collection `yaml:"cloud-control"` +} + +func Load(path string) (*Nuke, error) { + var err error + + raw, err := ioutil.ReadFile(path) + if err != nil { + return nil, err + } + + config := new(Nuke) + dec := yaml.NewDecoder(bytes.NewReader(raw)) + dec.KnownFields(true) + err = dec.Decode(&config) + if err != nil { + return nil, err + } + + if err := config.resolveDeprecations(); err != nil { + return nil, err + } + + return config, nil +} + +type Account struct { + Filters Filters `yaml:"filters"` + ResourceTypes ResourceTypes `yaml:"resource-types"` + Presets []string `yaml:"presets"` +} + +type PresetDefinitions struct { + Filters Filters `yaml:"filters"` +} + +type Nuke struct { + // Deprecated: Use AccountBlocklist instead. + AccountBlacklist []string `yaml:"account-blacklist"` + AccountBlocklist []string `yaml:"account-blocklist"` + Regions []string `yaml:"regions"` + Accounts map[string]Account `yaml:"accounts"` + ResourceTypes ResourceTypes `yaml:"resource-types"` + Presets map[string]PresetDefinitions `yaml:"presets"` + FeatureFlags FeatureFlags `yaml:"feature-flags"` + CustomEndpoints CustomEndpoints `yaml:"endpoints"` +} + // New creates a new extended configuration from a file. This is necessary because we are extended the default // libnuke configuration to contain additional attributes that are specific to the AWS Nuke tool. func New(opts config.Options) (*Config, error) { @@ -256,3 +312,134 @@ func (endpoints CustomEndpoints) GetURL(region, serviceType string) string { } return s.URL } + +func (c *Nuke) ResolveBlocklist() []string { + if c.AccountBlocklist != nil { + return c.AccountBlocklist + } + + log.Warn("deprecated configuration key 'account-blacklist' - please use 'account-blocklist' instead") + return c.AccountBlacklist +} + +func (c *Nuke) HasBlocklist() bool { + var blocklist = c.ResolveBlocklist() + return blocklist != nil && len(blocklist) > 0 +} + +func (c *Nuke) InBlocklist(searchID string) bool { + for _, blocklistID := range c.ResolveBlocklist() { + if blocklistID == searchID { + return true + } + } + + return false +} + +func (c *Nuke) ValidateAccount(accountID string, aliases []string) error { + if !c.HasBlocklist() { + return fmt.Errorf("The config file contains an empty blocklist. " + + "For safety reasons you need to specify at least one account ID. " + + "This should be your production account.") + } + + if c.InBlocklist(accountID) { + return fmt.Errorf("You are trying to nuke the account with the ID %s, "+ + "but it is blocklisted. Aborting.", accountID) + } + + if len(aliases) == 0 { + return fmt.Errorf("The specified account doesn't have an alias. " + + "For safety reasons you need to specify an account alias. " + + "Your production account should contain the term 'prod'.") + } + + for _, alias := range aliases { + if strings.Contains(strings.ToLower(alias), "prod") { + return fmt.Errorf("You are trying to nuke an account with the alias '%s', "+ + "but it has the substring 'prod' in it. Aborting.", alias) + } + } + + if _, ok := c.Accounts[accountID]; !ok { + return fmt.Errorf("Your account ID '%s' isn't listed in the config. "+ + "Aborting.", accountID) + } + + return nil +} + +func (c *Nuke) Filters(accountID string) (Filters, error) { + account := c.Accounts[accountID] + filters := account.Filters + + if filters == nil { + filters = Filters{} + } + + if account.Presets == nil { + return filters, nil + } + + for _, presetName := range account.Presets { + notFound := fmt.Errorf("Could not find filter preset '%s'", presetName) + if c.Presets == nil { + return nil, notFound + } + + preset, ok := c.Presets[presetName] + if !ok { + return nil, notFound + } + + filters.Merge(preset.Filters) + } + + return filters, nil +} + +func (c *Nuke) resolveDeprecations() error { + deprecations := map[string]string{ + "EC2DhcpOptions": "EC2DHCPOptions", + "EC2InternetGatewayAttachement": "EC2InternetGatewayAttachment", + "EC2NatGateway": "EC2NATGateway", + "EC2Vpc": "EC2VPC", + "EC2VpcEndpoint": "EC2VPCEndpoint", + "EC2VpnConnection": "EC2VPNConnection", + "EC2VpnGateway": "EC2VPNGateway", + "EC2VpnGatewayAttachement": "EC2VPNGatewayAttachment", + "ECRrepository": "ECRRepository", + "IamGroup": "IAMGroup", + "IamGroupPolicyAttachement": "IAMGroupPolicyAttachment", + "IamInstanceProfile": "IAMInstanceProfile", + "IamInstanceProfileRole": "IAMInstanceProfileRole", + "IamPolicy": "IAMPolicy", + "IamRole": "IAMRole", + "IamRolePolicyAttachement": "IAMRolePolicyAttachment", + "IamServerCertificate": "IAMServerCertificate", + "IamUser": "IAMUser", + "IamUserAccessKeys": "IAMUserAccessKey", + "IamUserGroupAttachement": "IAMUserGroupAttachment", + "IamUserPolicyAttachement": "IAMUserPolicyAttachment", + "RDSCluster": "RDSDBCluster", + } + + for _, a := range c.Accounts { + for resourceType, resources := range a.Filters { + replacement, ok := deprecations[resourceType] + if !ok { + continue + } + log.Warnf("deprecated resource type '%s' - converting to '%s'\n", resourceType, replacement) + + if _, ok := a.Filters[replacement]; ok { + return fmt.Errorf("using deprecated resource type and replacement: '%s','%s'", resourceType, replacement) + } + + a.Filters[replacement] = resources + delete(a.Filters, resourceType) + } + } + return nil +} diff --git a/pkg/config/filter.go b/pkg/config/filter.go new file mode 100644 index 00000000..019ed8eb --- /dev/null +++ b/pkg/config/filter.go @@ -0,0 +1,131 @@ +package config + +import ( + "fmt" + "regexp" + "strconv" + "strings" + "time" + + "github.com/mb0/glob" +) + +type FilterType string + +const ( + FilterTypeEmpty FilterType = "" + FilterTypeExact = "exact" + FilterTypeGlob = "glob" + FilterTypeRegex = "regex" + FilterTypeContains = "contains" + FilterTypeDateOlderThan = "dateOlderThan" +) + +type Filters map[string][]Filter + +func (f Filters) Merge(f2 Filters) { + for resourceType, filter := range f2 { + f[resourceType] = append(f[resourceType], filter...) + } +} + +type Filter struct { + Property string + Type FilterType + Value string + Invert string +} + +func (f Filter) Match(o string) (bool, error) { + switch f.Type { + case FilterTypeEmpty: + fallthrough + + case FilterTypeExact: + return f.Value == o, nil + + case FilterTypeContains: + return strings.Contains(o, f.Value), nil + + case FilterTypeGlob: + return glob.Match(f.Value, o) + + case FilterTypeRegex: + re, err := regexp.Compile(f.Value) + if err != nil { + return false, err + } + return re.MatchString(o), nil + + case FilterTypeDateOlderThan: + if o == "" { + return false, nil + } + duration, err := time.ParseDuration(f.Value) + if err != nil { + return false, err + } + fieldTime, err := parseDate(o) + if err != nil { + return false, err + } + fieldTimeWithOffset := fieldTime.Add(duration) + + return fieldTimeWithOffset.After(time.Now()), nil + + default: + return false, fmt.Errorf("unknown type %s", f.Type) + } +} + +func parseDate(input string) (time.Time, error) { + if i, err := strconv.ParseInt(input, 10, 64); err == nil { + t := time.Unix(i, 0) + return t, nil + } + + formats := []string{ + "2006-01-02", + "2006/01/02", + "2006-01-02T15:04:05Z", + "2006-01-02 15:04:05 -0700 MST", // Date format used by AWS for CreateTime on ASGs + time.RFC3339Nano, // Format of t.MarshalText() and t.MarshalJSON() + time.RFC3339, + } + for _, f := range formats { + t, err := time.Parse(f, input) + if err == nil { + return t, nil + } + } + return time.Now(), fmt.Errorf("unable to parse time %s", input) +} + +func (f *Filter) UnmarshalYAML(unmarshal func(interface{}) error) error { + var value string + + if unmarshal(&value) == nil { + f.Type = FilterTypeExact + f.Value = value + return nil + } + + m := map[string]string{} + err := unmarshal(m) + if err != nil { + return err + } + + f.Type = FilterType(m["type"]) + f.Value = m["value"] + f.Property = m["property"] + f.Invert = m["invert"] + return nil +} + +func NewExactFilter(value string) Filter { + return Filter{ + Type: FilterTypeExact, + Value: value, + } +} diff --git a/pkg/types/collection.go b/pkg/types/collection.go new file mode 100644 index 00000000..408a696b --- /dev/null +++ b/pkg/types/collection.go @@ -0,0 +1,50 @@ +package types + +type Collection []string + +func (c Collection) Intersect(o Collection) Collection { + mo := o.toMap() + + result := Collection{} + for _, t := range c { + if mo[t] { + result = append(result, t) + } + } + + return result +} + +func (c Collection) Remove(o Collection) Collection { + mo := o.toMap() + + result := Collection{} + for _, t := range c { + if !mo[t] { + result = append(result, t) + } + } + + return result +} + +func (c Collection) Union(o Collection) Collection { + ms := c.toMap() + + result := []string(c) + for _, oi := range o { + if !ms[oi] { + result = append(result, oi) + } + } + + return Collection(result) +} + +func (c Collection) toMap() map[string]bool { + m := map[string]bool{} + for _, t := range c { + m[t] = true + } + return m +} diff --git a/pkg/types/collection_test.go b/pkg/types/collection_test.go new file mode 100644 index 00000000..d454073c --- /dev/null +++ b/pkg/types/collection_test.go @@ -0,0 +1,50 @@ +package types_test + +import ( + "fmt" + "testing" + + "github.com/ekristen/aws-nuke/v3/pkg/types" +) + +func TestSetInterset(t *testing.T) { + s1 := types.Collection{"a", "b", "c"} + s2 := types.Collection{"b", "a", "d"} + + r := s1.Intersect(s2) + + want := fmt.Sprint([]string{"a", "b"}) + have := fmt.Sprint(r) + + if want != have { + t.Errorf("Wrong result. Want: %s. Have: %s", want, have) + } +} + +func TestSetRemove(t *testing.T) { + s1 := types.Collection{"a", "b", "c"} + s2 := types.Collection{"b", "a", "d"} + + r := s1.Remove(s2) + + want := fmt.Sprint([]string{"c"}) + have := fmt.Sprint(r) + + if want != have { + t.Errorf("Wrong result. Want: %s. Have: %s", want, have) + } +} + +func TestSetUnion(t *testing.T) { + s1 := types.Collection{"a", "b", "c"} + s2 := types.Collection{"b", "a", "d"} + + r := s1.Union(s2) + + want := fmt.Sprint([]string{"a", "b", "c", "d"}) + have := fmt.Sprint(r) + + if want != have { + t.Errorf("Wrong result. Want: %s. Have: %s", want, have) + } +} diff --git a/pkg/types/properties.go b/pkg/types/properties.go new file mode 100644 index 00000000..bb2c27e9 --- /dev/null +++ b/pkg/types/properties.go @@ -0,0 +1,137 @@ +package types + +import ( + "fmt" + "sort" + "strings" +) + +type Properties map[string]string + +func NewProperties() Properties { + return make(Properties) +} + +func (p Properties) String() string { + parts := []string{} + for k, v := range p { + parts = append(parts, fmt.Sprintf(`%s: "%v"`, k, v)) + } + + sort.Strings(parts) + + return fmt.Sprintf("[%s]", strings.Join(parts, ", ")) +} + +func (p Properties) Set(key string, value interface{}) Properties { + if value == nil { + return p + } + + switch v := value.(type) { + case *string: + if v == nil { + return p + } + p[key] = *v + case []byte: + p[key] = string(v) + case *bool: + if v == nil { + return p + } + p[key] = fmt.Sprint(*v) + case *int64: + if v == nil { + return p + } + p[key] = fmt.Sprint(*v) + case *int: + if v == nil { + return p + } + p[key] = fmt.Sprint(*v) + default: + // Fallback to Stringer interface. This produces gibberish on pointers, + // but is the only way to avoid reflection. + p[key] = fmt.Sprint(value) + } + + return p +} + +func (p Properties) SetTag(tagKey *string, tagValue interface{}) Properties { + return p.SetTagWithPrefix("", tagKey, tagValue) +} + +func (p Properties) SetTagWithPrefix(prefix string, tagKey *string, tagValue interface{}) Properties { + if tagKey == nil { + return p + } + + keyStr := strings.TrimSpace(*tagKey) + prefix = strings.TrimSpace(prefix) + + if keyStr == "" { + return p + } + + if prefix != "" { + keyStr = fmt.Sprintf("%s:%s", prefix, keyStr) + } + + keyStr = fmt.Sprintf("tag:%s", keyStr) + + return p.Set(keyStr, tagValue) +} + +func (p Properties) SetPropertyWithPrefix(prefix string, propertyKey string, propertyValue interface{}) Properties { + keyStr := strings.TrimSpace(propertyKey) + prefix = strings.TrimSpace(prefix) + + if keyStr == "" { + return p + } + + if prefix != "" { + keyStr = fmt.Sprintf("%s:%s", prefix, keyStr) + } + + return p.Set(keyStr, propertyValue) +} + +func (p Properties) Get(key string) string { + value, ok := p[key] + if !ok { + return "" + } + + return value +} + +func (p Properties) Equals(o Properties) bool { + if p == nil && o == nil { + return true + } + + if p == nil || o == nil { + return false + } + + if len(p) != len(o) { + return false + } + + for k, pv := range p { + ov, ok := o[k] + if !ok { + return false + } + + if pv != ov { + return false + } + } + + return true +} diff --git a/pkg/types/properties_test.go b/pkg/types/properties_test.go new file mode 100644 index 00000000..ad78d765 --- /dev/null +++ b/pkg/types/properties_test.go @@ -0,0 +1,201 @@ +package types_test + +import ( + "fmt" + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/ekristen/aws-nuke/v3/pkg/types" +) + +func TestPropertiesEquals(t *testing.T) { + cases := []struct { + p1, p2 types.Properties + result bool + }{ + { + p1: nil, + p2: nil, + result: true, + }, + { + p1: nil, + p2: types.NewProperties(), + result: false, + }, + { + p1: types.NewProperties(), + p2: types.NewProperties(), + result: true, + }, + { + p1: types.NewProperties().Set("blub", "blubber"), + p2: types.NewProperties().Set("blub", "blubber"), + result: true, + }, + { + p1: types.NewProperties().Set("blub", "foo"), + p2: types.NewProperties().Set("blub", "bar"), + result: false, + }, + { + p1: types.NewProperties().Set("bim", "baz").Set("blub", "blubber"), + p2: types.NewProperties().Set("bim", "baz").Set("blub", "blubber"), + result: true, + }, + { + p1: types.NewProperties().Set("bim", "baz").Set("blub", "foo"), + p2: types.NewProperties().Set("bim", "baz").Set("blub", "bar"), + result: false, + }, + } + + for i, tc := range cases { + t.Run(fmt.Sprint(i), func(t *testing.T) { + if tc.p1.Equals(tc.p2) != tc.result { + t.Errorf("Test Case failed. Want %t. Got %t.", !tc.result, tc.result) + t.Errorf("p1: %s", tc.p1.String()) + t.Errorf("p2: %s", tc.p2.String()) + } else if tc.p2.Equals(tc.p1) != tc.result { + t.Errorf("Test Case reverse check failed. Want %t. Got %t.", !tc.result, tc.result) + t.Errorf("p1: %s", tc.p1.String()) + t.Errorf("p2: %s", tc.p2.String()) + } + }) + } +} + +func TestPropertiesSetTag(t *testing.T) { + cases := []struct { + name string + key *string + value interface{} + want string + }{ + { + name: "string", + key: aws.String("name"), + value: "blubber", + want: `[tag:name: "blubber"]`, + }, + { + name: "string_ptr", + key: aws.String("name"), + value: aws.String("blubber"), + want: `[tag:name: "blubber"]`, + }, + { + name: "int", + key: aws.String("int"), + value: 42, + want: `[tag:int: "42"]`, + }, + { + name: "nil", + key: aws.String("nothing"), + value: nil, + want: `[]`, + }, + { + name: "empty_key", + key: aws.String(""), + value: "empty", + want: `[]`, + }, + { + name: "nil_key", + key: nil, + value: "empty", + want: `[]`, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + p := types.NewProperties() + + p.SetTag(tc.key, tc.value) + have := p.String() + + if tc.want != have { + t.Errorf("'%s' != '%s'", tc.want, have) + } + }) + } +} + +func TestPropertiesSetTagWithPrefix(t *testing.T) { + cases := []struct { + name string + prefix string + key *string + value interface{} + want string + }{ + { + name: "empty", + prefix: "", + key: aws.String("name"), + value: "blubber", + want: `[tag:name: "blubber"]`, + }, + { + name: "nonempty", + prefix: "bish", + key: aws.String("bash"), + value: "bosh", + want: `[tag:bish:bash: "bosh"]`, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + p := types.NewProperties() + + p.SetTagWithPrefix(tc.prefix, tc.key, tc.value) + have := p.String() + + if tc.want != have { + t.Errorf("'%s' != '%s'", tc.want, have) + } + }) + } +} + +func TestPropertiesSetPropertiesWithPrefix(t *testing.T) { + cases := []struct { + name string + prefix string + key string + value interface{} + want string + }{ + { + name: "empty", + prefix: "", + key: "OwnerID", + value: aws.String("123456789012"), + want: `[OwnerID: "123456789012"]`, + }, + { + name: "nonempty", + prefix: "igw", + key: "OwnerID", + value: aws.String("123456789012"), + want: `[igw:OwnerID: "123456789012"]`, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + p := types.NewProperties() + + p.SetPropertyWithPrefix(tc.prefix, tc.key, tc.value) + have := p.String() + + if tc.want != have { + t.Errorf("'%s' != '%s'", tc.want, have) + } + }) + } +} diff --git a/pkg/util/indent.go b/pkg/util/indent.go new file mode 100644 index 00000000..2a09e535 --- /dev/null +++ b/pkg/util/indent.go @@ -0,0 +1,18 @@ +package util + +func Indent(s, prefix string) string { + return string(IndentBytes([]byte(s), []byte(prefix))) +} + +func IndentBytes(b, prefix []byte) []byte { + var res []byte + bol := true + for _, c := range b { + if bol && c != '\n' { + res = append(res, prefix...) + } + res = append(res, c) + bol = c == '\n' + } + return res +} diff --git a/resources/interface.go b/resources/interface.go new file mode 100644 index 00000000..310808c5 --- /dev/null +++ b/resources/interface.go @@ -0,0 +1,88 @@ +package resources + +import ( + "fmt" + "strings" + + "github.com/aws/aws-sdk-go/aws/session" + "github.com/ekristen/aws-nuke/v3/pkg/config" + "github.com/ekristen/aws-nuke/v3/pkg/types" +) + +type ResourceListers map[string]ResourceLister + +type ResourceLister func(s *session.Session) ([]Resource, error) + +type Resource interface { + Remove() error +} + +type Filter interface { + Resource + Filter() error +} + +type LegacyStringer interface { + Resource + String() string +} + +type ResourcePropertyGetter interface { + Resource + Properties() types.Properties +} + +type FeatureFlagGetter interface { + Resource + FeatureFlags(config.FeatureFlags) +} + +var resourceListers = make(ResourceListers) + +func register(name string, lister ResourceLister, opts ...registerOption) { + _, exists := resourceListers[name] + if exists { + panic(fmt.Sprintf("a resource with the name %s already exists", name)) + } + + resourceListers[name] = lister + + for _, opt := range opts { + opt(name, lister) + } +} + +var cloudControlMapping = map[string]string{} + +func GetCloudControlMapping() map[string]string { + return cloudControlMapping +} + +type registerOption func(name string, lister ResourceLister) + +func mapCloudControl(typeName string) registerOption { + return func(name string, lister ResourceLister) { + _, exists := cloudControlMapping[typeName] + if exists { + panic(fmt.Sprintf("a cloud control mapping for %s already exists", typeName)) + } + + cloudControlMapping[typeName] = name + } +} + +func GetLister(name string) ResourceLister { + if strings.HasPrefix(name, "AWS::") { + registerCloudControl(name) + } + return resourceListers[name] +} + +func GetListerNames() []string { + names := []string{} + for resourceType := range resourceListers { + names = append(names, resourceType) + } + + return names +}