From 29df26b5c2e2db1e4c8d63f9f6f14db7d6cf7947 Mon Sep 17 00:00:00 2001 From: sacha-c Date: Thu, 5 Dec 2024 09:29:05 +0100 Subject: [PATCH] feat(#17): rework cli interface --- README.md | 7 +- internal/cli/patrol.go | 133 ++++++++++++++-------------- internal/cli/patrol_test.go | 49 ++++------- internal/gitlab/client.go | 12 +-- internal/gitlab/gitlab.go | 137 ++++++++--------------------- internal/gitlab/gitlab_test.go | 47 +++++----- internal/patrol/patrol.go | 54 +++++++++--- internal/patrol/patrol_test.go | 40 +++++++-- internal/publish/to_gitlab_test.go | 4 +- internal/publish/to_slack.go | 10 +-- internal/publish/to_slack_test.go | 4 +- internal/slack/slack.go | 12 +-- internal/slack/slack_test.go | 56 +++++------- 13 files changed, 261 insertions(+), 304 deletions(-) diff --git a/README.md b/README.md index 9263e94..1d2c137 100644 --- a/README.md +++ b/README.md @@ -107,10 +107,9 @@ Only the **Reporting** and **Scanning** sections of configuration parameters are In this case you may choose to create a config file such as the following: ```toml -gitlab-groups = ["namespace/group", "namespace/group/cool-repo"] -gitlab-projects = ["namespace/group/cool-repo"] -report-slack-channel = "sheriff-report-test" -report-gitlab-issue = true +url = ["namespace/group", "namespace/group/cool-repo"] +report-to-slack-channel = "sheriff-report-test" +report-to-gitlab-issue = true ``` And if you wish to specify a different file, you can do so with `sheriff patrol --config your-config-file.toml`. diff --git a/internal/cli/patrol.go b/internal/cli/patrol.go index ed88d67..4e09ca8 100644 --- a/internal/cli/patrol.go +++ b/internal/cli/patrol.go @@ -3,24 +3,17 @@ package cli import ( "errors" "fmt" - "regexp" + "net/url" "sheriff/internal/git" "sheriff/internal/gitlab" "sheriff/internal/patrol" "sheriff/internal/scanner" "sheriff/internal/slack" - zerolog "github.com/rs/zerolog/log" "github.com/urfave/cli/v2" "github.com/urfave/cli/v2/altsrc" ) -// Regexes very loosely defined based on GitLab's reserved names: -// https://docs.gitlab.com/ee/user/reserved_names.html#limitations-on-usernames-project-and-group-names-and-slugs -// In reality the regex should be more restrictive about special characters, for now we're just checking for slashes and non-whitespace characters. -const groupPathRegex = "^\\S+(\\/\\S+)*$" // Matches paths like "group" or "group/subgroup" ... -const projectPathRegex = "^\\S+(\\/\\S+)+$" // Matches paths like "group/project" or "group/subgroup/project" ... - type CommandCategory string const ( @@ -32,14 +25,12 @@ const ( const configFlag = "config" const verboseFlag = "verbose" -const testingFlag = "testing" -const groupsFlag = "gitlab-groups" -const projectsFlag = "gitlab-projects" -const reportSlackChannelFlag = "report-slack-channel" -const reportSlackProjectChannelFlag = "report-slack-project-channel" -const reportGitlabFlag = "report-gitlab-issue" -const silentReport = "silent" -const publicSlackChannelFlag = "public-slack-channel" +const urlFlag = "url" +const reportToEmailFlag = "report-to-email" +const reportToIssueFlag = "report-to-issue" +const reportToSlackChannel = "report-to-slack-channel" +const reportEnableProjectReportToFlag = "report-enable-project-report-to" +const silentReportFlag = "silent" const gitlabTokenFlag = "gitlab-token" const slackTokenFlag = "slack-token" @@ -47,58 +38,46 @@ var sensitiveFlags = []string{gitlabTokenFlag, slackTokenFlag} var PatrolFlags = []cli.Flag{ &cli.StringFlag{ - Name: configFlag, - Value: "sheriff.toml", + Name: configFlag, + Aliases: []string{"c"}, + Value: "sheriff.toml", }, &cli.BoolFlag{ Name: verboseFlag, + Aliases: []string{"v"}, Usage: "Enable verbose logging", Category: string(Miscellaneous), Value: false, }, altsrc.NewStringSliceFlag(&cli.StringSliceFlag{ - Name: groupsFlag, - Usage: "Gitlab groups to scan for vulnerabilities (list argument which can be repeated)", + Name: urlFlag, + Usage: "Groups and projects to scan for vulnerabilities (list argument which can be repeated)", Category: string(Scanning), - Action: validatePaths(groupPathRegex), }), altsrc.NewStringSliceFlag(&cli.StringSliceFlag{ - Name: projectsFlag, - Usage: "Gitlab projects to scan for vulnerabilities (list argument which can be repeated)", - Category: string(Scanning), - Action: validatePaths(projectPathRegex), - }), - altsrc.NewBoolFlag(&cli.BoolFlag{ - Name: testingFlag, - Usage: "Enable testing mode. This can enable features that are not safe for production use.", - Category: string(Miscellaneous), - Value: false, - }), - altsrc.NewStringFlag(&cli.StringFlag{ - Name: reportSlackChannelFlag, - Usage: "Enable reporting to Slack through messages in the specified channel.", + Name: reportToEmailFlag, + Usage: "Enable reporting to the provided list of emails", Category: string(Reporting), }), altsrc.NewBoolFlag(&cli.BoolFlag{ - Name: reportSlackProjectChannelFlag, - Usage: "Enable reporting to Slack through messages in the specified project's channel. Requires a project-level configuration file specifying the channel.", + Name: reportToIssueFlag, + Usage: "Enable or disable reporting to the project's issue on the associated platform (gitlab, github, ...)", Category: string(Reporting), }), - altsrc.NewBoolFlag(&cli.BoolFlag{ - Name: reportGitlabFlag, - Usage: "Enable reporting to GitLab through issue creation in projects affected by vulnerabilities.", + altsrc.NewStringFlag(&cli.StringFlag{ + Name: reportToSlackChannel, + Usage: "Enable reporting to the provided slack channel", Category: string(Reporting), - Value: false, }), altsrc.NewBoolFlag(&cli.BoolFlag{ - Name: silentReport, - Usage: "Disable report output to stdout.", + Name: reportEnableProjectReportToFlag, + Usage: "Enable project-level configuration for '--report-to'.", Category: string(Reporting), - Value: false, + Value: true, }), altsrc.NewBoolFlag(&cli.BoolFlag{ - Name: publicSlackChannelFlag, - Usage: "Allow the slack report to be posted to a public channel. Note that reports may contain sensitive information which should not be disclosed on a public channel, for this reason this flag will only be enabled when combined with the testing flag.", + Name: silentReportFlag, + Usage: "Disable report output to stdout.", Category: string(Reporting), Value: false, }), @@ -121,10 +100,10 @@ var PatrolFlags = []cli.Flag{ func PatrolAction(cCtx *cli.Context) error { verbose := cCtx.Bool(verboseFlag) - var publicChannelsEnabled bool - if cCtx.Bool(testingFlag) { - zerolog.Warn().Msg("Testing mode enabled. This may enable features that are not safe for production use.") - publicChannelsEnabled = cCtx.Bool(publicSlackChannelFlag) + // Parse options + locations, err := parseUrls(cCtx.StringSlice(urlFlag)) + if err != nil { + return errors.Join(errors.New("failed to parse `--url` options"), err) } // Create services @@ -133,7 +112,7 @@ func PatrolAction(cCtx *cli.Context) error { return errors.Join(errors.New("failed to create GitLab service"), err) } - slackService, err := slack.New(cCtx.String(slackTokenFlag), publicChannelsEnabled, verbose) + slackService, err := slack.New(cCtx.String(slackTokenFlag), verbose) if err != nil { return errors.Join(errors.New("failed to create Slack service"), err) } @@ -145,13 +124,15 @@ func PatrolAction(cCtx *cli.Context) error { // Do the patrol if warn, err := patrolService.Patrol( - cCtx.StringSlice(groupsFlag), - cCtx.StringSlice(projectsFlag), - cCtx.Bool(reportGitlabFlag), - cCtx.String(reportSlackChannelFlag), - cCtx.Bool(reportSlackProjectChannelFlag), - cCtx.Bool(silentReport), - verbose, + patrol.PatrolArgs{ + Locations: locations, + ReportToIssue: cCtx.Bool(reportToIssueFlag), + ReportToEmails: cCtx.StringSlice(reportToEmailFlag), + ReportToSlackChannel: cCtx.String(reportToSlackChannel), + EnableProjectReportTo: cCtx.Bool(reportEnableProjectReportToFlag), + SilentReport: cCtx.Bool(silentReportFlag), + Verbose: verbose, + }, ); err != nil { return errors.Join(errors.New("failed to scan"), err) } else if warn != nil { @@ -161,20 +142,34 @@ func PatrolAction(cCtx *cli.Context) error { return nil } -func validatePaths(regex string) func(*cli.Context, []string) error { - return func(_ *cli.Context, groups []string) (err error) { - rgx, err := regexp.Compile(regex) - if err != nil { - return err +func parseUrls(uris []string) ([]patrol.ProjectLocation, error) { + locations := make([]patrol.ProjectLocation, len(uris)) + for i, uri := range uris { + parsed, err := url.Parse(uri) + if err != nil || parsed == nil { + return nil, errors.Join(fmt.Errorf("failed to parse uri"), err) } - for _, path := range groups { - matched := rgx.Match([]byte(path)) + if !parsed.IsAbs() { + return nil, fmt.Errorf("url missing platform scheme %v", uri) + } - if !matched { - return fmt.Errorf("invalid group path: %v", path) - } + if parsed.Scheme == string(patrol.Github) { + return nil, fmt.Errorf("github is currently unsupported, but is on our roadmap :)") // TODO #9 + } else if parsed.Scheme != string(patrol.Gitlab) { + return nil, fmt.Errorf("unsupport platform %v", parsed.Scheme) + } + + path, err := url.JoinPath(parsed.Host, parsed.Path) + if err != nil { + return nil, fmt.Errorf("failed to join host and path %v", uri) + } + + locations[i] = patrol.ProjectLocation{ + Type: patrol.PlatformType(parsed.Scheme), + Path: path, } - return } + + return locations, nil } diff --git a/internal/cli/patrol_test.go b/internal/cli/patrol_test.go index ad3412f..e5028c4 100644 --- a/internal/cli/patrol_test.go +++ b/internal/cli/patrol_test.go @@ -2,6 +2,8 @@ package cli import ( "flag" + "fmt" + "sheriff/internal/patrol" "testing" "github.com/stretchr/testify/assert" @@ -16,44 +18,31 @@ func TestPatrolActionEmptyRun(t *testing.T) { assert.Nil(t, err) } -func TestValidatePathGroupPathRegex(t *testing.T) { +func TestParseUrls(t *testing.T) { testCases := []struct { - paths []string - want bool + paths []string + wantProjectLocation *patrol.ProjectLocation + wantError bool }{ - {[]string{"group"}, true}, - {[]string{"group/subgroup"}, true}, - {[]string{"group/subgroup", "not a path"}, false}, + {[]string{"gitlab://namespace/project"}, &patrol.ProjectLocation{Type: "gitlab", Path: "namespace/project"}, false}, + {[]string{"gitlab://namespace/subgroup/project"}, &patrol.ProjectLocation{Type: "gitlab", Path: "namespace/subgroup/project"}, false}, + {[]string{"gitlab://namespace"}, &patrol.ProjectLocation{Type: "gitlab", Path: "namespace"}, false}, + {[]string{"github://organization"}, &patrol.ProjectLocation{Type: "github", Path: "organization"}, true}, + {[]string{"github://organization/project"}, &patrol.ProjectLocation{Type: "github", Path: "organization/project"}, true}, + {[]string{"unknown://namespace/project"}, nil, true}, + {[]string{"unknown://not a path"}, nil, true}, + {[]string{"not a url"}, nil, true}, } for _, tc := range testCases { - err := validatePaths(groupPathRegex)(nil, tc.paths) + urls, err := parseUrls(tc.paths) - if tc.want { - assert.Nil(t, err) - } else { - assert.NotNil(t, err) - } - } -} - -func TestValidatePathProjectPathRegex(t *testing.T) { - testCases := []struct { - paths []string - want bool - }{ - {[]string{"project"}, false}, // top-level projects don't exist - {[]string{"group/project"}, true}, - {[]string{"group/project", "not a path"}, false}, - } + fmt.Print(urls) - for _, tc := range testCases { - err := validatePaths(projectPathRegex)(nil, tc.paths) - - if tc.want { - assert.Nil(t, err, tc.paths) + if tc.wantError { + assert.NotNil(t, err) } else { - assert.NotNil(t, err, tc.paths) + assert.Equal(t, tc.wantProjectLocation, &(urls[0])) } } } diff --git a/internal/gitlab/client.go b/internal/gitlab/client.go index 451301a..eeb30b1 100644 --- a/internal/gitlab/client.go +++ b/internal/gitlab/client.go @@ -10,8 +10,8 @@ import ( ) type iclient interface { - ListGroups(opt *gitlab.ListGroupsOptions, options ...gitlab.RequestOptionFunc) ([]*gitlab.Group, *gitlab.Response, error) - ListGroupProjects(groupId int, opt *gitlab.ListGroupProjectsOptions, options ...gitlab.RequestOptionFunc) ([]*gitlab.Project, *gitlab.Response, error) + GetProject(pid interface{}, opt *gitlab.GetProjectOptions, options ...gitlab.RequestOptionFunc) (*gitlab.Project, *gitlab.Response, error) + ListGroupProjects(gid interface{}, opt *gitlab.ListGroupProjectsOptions, options ...gitlab.RequestOptionFunc) ([]*gitlab.Project, *gitlab.Response, error) ListProjectIssues(projectId interface{}, opt *gitlab.ListProjectIssuesOptions, options ...gitlab.RequestOptionFunc) ([]*gitlab.Issue, *gitlab.Response, error) CreateIssue(projectId interface{}, opt *gitlab.CreateIssueOptions, options ...gitlab.RequestOptionFunc) (*gitlab.Issue, *gitlab.Response, error) UpdateIssue(projectId interface{}, issueId int, opt *gitlab.UpdateIssueOptions, options ...gitlab.RequestOptionFunc) (*gitlab.Issue, *gitlab.Response, error) @@ -21,12 +21,12 @@ type client struct { client *gitlab.Client } -func (c *client) ListGroups(opt *gitlab.ListGroupsOptions, options ...gitlab.RequestOptionFunc) ([]*gitlab.Group, *gitlab.Response, error) { - return c.client.Groups.ListGroups(opt, options...) +func (c *client) GetProject(gid interface{}, opt *gitlab.GetProjectOptions, options ...gitlab.RequestOptionFunc) (*gitlab.Project, *gitlab.Response, error) { + return c.client.Projects.GetProject(gid, opt, options...) } -func (c *client) ListGroupProjects(groupId int, opt *gitlab.ListGroupProjectsOptions, options ...gitlab.RequestOptionFunc) ([]*gitlab.Project, *gitlab.Response, error) { - return c.client.Groups.ListGroupProjects(groupId, opt, options...) +func (c *client) ListGroupProjects(gid interface{}, opt *gitlab.ListGroupProjectsOptions, options ...gitlab.RequestOptionFunc) ([]*gitlab.Project, *gitlab.Response, error) { + return c.client.Groups.ListGroupProjects(gid, opt, options...) } func (c *client) ListProjectIssues(projectId interface{}, opt *gitlab.ListProjectIssuesOptions, options ...gitlab.RequestOptionFunc) ([]*gitlab.Issue, *gitlab.Response, error) { diff --git a/internal/gitlab/gitlab.go b/internal/gitlab/gitlab.go index 9814e5c..f3c2724 100644 --- a/internal/gitlab/gitlab.go +++ b/internal/gitlab/gitlab.go @@ -3,7 +3,6 @@ package gitlab import ( "errors" "fmt" - "strings" "sync" "github.com/elliotchance/pie/v2" @@ -15,7 +14,7 @@ const VulnerabilityIssueTitle = "Sheriff - 🚨 Vulnerability report" // IService is the interface of the GitLab service as needed by sheriff type IService interface { - GetProjectList(groupPaths []string, projectPaths []string) (projects []gitlab.Project, warn error) + GetProjectList(paths []string) (projects []gitlab.Project, warn error) CloseVulnerabilityIssue(project gitlab.Project) error OpenVulnerabilityIssue(project gitlab.Project, report string) (*gitlab.Issue, error) } @@ -36,24 +35,13 @@ func New(gitlabToken string) (IService, error) { return &s, nil } -func (s *service) GetProjectList(groupPaths []string, projectPaths []string) (projects []gitlab.Project, warn error) { - projects, pwarn := s.gatherProjects(projectPaths) +func (s *service) GetProjectList(paths []string) (projects []gitlab.Project, warn error) { + projects, pwarn := s.gatherProjectsFromGroupsOrProjects(paths) if pwarn != nil { pwarn = errors.Join(errors.New("errors occured when gathering projects"), pwarn) warn = errors.Join(pwarn, warn) } - groupsProjects, gpwarn := s.gatherGroupsProjects(groupPaths) - if gpwarn != nil { - gpwarn = errors.Join(errors.New("errors occured when gathering groups projects"), gpwarn) - warn = errors.Join(gpwarn, warn) - } - - projects = append(projects, groupsProjects...) - - // Filter unique projects -- there may be duplicates between groups, other groups and projects - projects = filterUniqueProjects(projects) - projectsNamespaces := pie.Map(projects, func(p gitlab.Project) string { return p.PathWithNamespace }) log.Info().Strs("projects", projectsNamespaces).Msg("Projects to scan") @@ -126,95 +114,48 @@ func (s *service) OpenVulnerabilityIssue(project gitlab.Project, report string) return } -func (s *service) getGroup(groupPath string) (*gitlab.Group, error) { - log.Info().Str("group", groupPath).Msg("Getting group") - groups, _, err := s.client.ListGroups(&gitlab.ListGroupsOptions{ - Search: gitlab.Ptr(groupPath), - }) - if err != nil { - return nil, errors.Join(fmt.Errorf("failed to fetch list of groups like %v", groupPath), err) - } - - for _, group := range groups { - if group.FullPath == groupPath { - return group, nil - } - } - - return nil, fmt.Errorf("group %v not found", groupPath) -} - -func (s *service) getProject(path string) (*gitlab.Project, error) { - log.Info().Str("path", path).Msg("Getting project") - - lastSlash := strings.LastIndex(path, "/") - - if lastSlash == -1 { - return nil, fmt.Errorf("invalid project path %v", path) - } - - groupPath := path[:lastSlash] - - group, err := s.getGroup(groupPath) - if err != nil { - return nil, errors.Join(fmt.Errorf("failed to fetch group %v", groupPath), err) - } - - projects, _, lgerr := s.listGroupProjects(group.ID) - if lgerr != nil { - return nil, errors.Join(fmt.Errorf("failed to fetch list of projects like %v", path), err) - } - for _, project := range projects { - if project.PathWithNamespace == path { - return &project, nil - } - } - - return nil, fmt.Errorf("project %v not found", path) -} - -func (s *service) gatherGroupsProjects(groupPaths []string) (projects []gitlab.Project, warn error) { - for _, groupPath := range groupPaths { - group, gerr := s.getGroup(groupPath) +// This function receives a list of paths which can be gitlab projects or groups +// and returns the list of projects within those paths and the list of projects contained within those groups and their subgroups. +func (s *service) gatherProjectsFromGroupsOrProjects(paths []string) (projects []gitlab.Project, warn error) { + for _, path := range paths { + gp, gpwarn, gerr := s.getProjectsFromGroupOrProject(path) if gerr != nil { - log.Error().Err(gerr).Str("group", groupPath).Msg("Failed to fetch group") - gerr = errors.Join(fmt.Errorf("failed to fetch group %v", groupPath), gerr) + log.Error().Err(gerr).Str("group", path).Msg("Failed to fetch group") + gerr = errors.Join(fmt.Errorf("failed to fetch group %v", path), gerr) warn = errors.Join(gerr, warn) continue } - - if groupProjects, gpwarn, gperr := s.listGroupProjects(group.ID); gperr != nil { - log.Error().Err(gpwarn).Str("group", groupPath).Msg("Failed to fetch projects of group") - gperr = errors.Join(fmt.Errorf("failed to fetch projects of group %v", groupPath), gperr) - warn = errors.Join(gperr, warn) - } else if gpwarn != nil { - gpwarn = errors.Join(fmt.Errorf("failed to fetch projects of group %v", groupPath), gpwarn) + if gpwarn != nil { warn = errors.Join(gpwarn, warn) - - projects = append(projects, groupProjects...) - } else { - projects = append(projects, groupProjects...) } + + projects = append(projects, gp...) } + // Filter unique projects -- there may be duplicates between groups, other groups and projects + projects = filterUniqueProjects(projects) + return } -func (s *service) gatherProjects(projectPaths []string) (projects []gitlab.Project, warn error) { - for _, projectPath := range projectPaths { - log.Info().Str("project", projectPath).Msg("Getting project") - p, err := s.getProject(projectPath) - if err != nil { - log.Error().Err(err).Str("project", projectPath).Msg("Failed to fetch project") - err = errors.Join(fmt.Errorf("failed to fetch project %v", projectPath), err) - warn = errors.Join(err, warn) - continue +// This function receives a path that could either be a gitlab group, or a gitlab path. +// It first tries to get the path as a group. +// +// If it succeeds then it returns all projects of that group & its subgroups. +// If it fails then it tries to get the path as a project. +func (s *service) getProjectsFromGroupOrProject(path string) (projects []gitlab.Project, warn error, err error) { + gp, gpwarn, gperr := s.listGroupProjects(path) + if gperr != nil { + log.Debug().Str("path", path).Msg("failed to fetch as group. trying as project") + p, _, perr := s.client.GetProject(path, &gitlab.GetProjectOptions{}) + if perr != nil { + return nil, nil, errors.Join(fmt.Errorf("failed to get group %v", path), gperr) } - projects = append(projects, *p) + return []gitlab.Project{*p}, nil, nil } - return + return gp, gpwarn, nil } // getVulnerabilityIssue returns the vulnerability issue for the given project @@ -235,8 +176,8 @@ func (s *service) getVulnerabilityIssue(project gitlab.Project) (issue *gitlab.I } // listGroupProjects returns the list of projects for the given group ID -func (s *service) listGroupProjects(groupID int) (projects []gitlab.Project, warn error, err error) { - projectPtrs, response, err := s.client.ListGroupProjects(groupID, +func (s *service) listGroupProjects(path string) (projects []gitlab.Project, warn error, err error) { + projectPtrs, response, err := s.client.ListGroupProjects(path, &gitlab.ListGroupProjectsOptions{ Archived: gitlab.Ptr(false), Simple: gitlab.Ptr(true), @@ -252,11 +193,11 @@ func (s *service) listGroupProjects(groupID int) (projects []gitlab.Project, war projects, errCount := dereferenceProjectsPointers(projectPtrs) if errCount > 0 { - log.Warn().Int("groupID", groupID).Int("count", errCount).Msg("Found nil projects, skipping them.") + log.Warn().Str("path", path).Int("count", errCount).Msg("Found nil projects, skipping them.") } if response.TotalPages > 1 { - nextProjects, lgwarn := s.listGroupNextProjects(groupID, response.TotalPages) + nextProjects, lgwarn := s.listGroupNextProjects(path, response.TotalPages) if lgwarn != nil { lgwarn = errors.Join(errors.New("errors occured when fetching next pages"), lgwarn) warn = errors.Join(lgwarn, warn) @@ -278,7 +219,7 @@ func ToChan[T any](s []T) <-chan T { } // listGroupNextProjects returns the list of projects for the given group ID from the next pages -func (s *service) listGroupNextProjects(groupID int, totalPages int) (projects []gitlab.Project, warn error) { +func (s *service) listGroupNextProjects(path string, totalPages int) (projects []gitlab.Project, warn error) { var wg sync.WaitGroup nextProjectsChan := make(chan []gitlab.Project, totalPages) warnChan := make(chan error, totalPages) @@ -287,8 +228,8 @@ func (s *service) listGroupNextProjects(groupID int, totalPages int) (projects [ go func(reportsChan chan<- []gitlab.Project) { defer wg.Done() - log.Info().Int("groupID", groupID).Int("page", p).Msg("Fetching projects of next page") - projectPtrs, _, err := s.client.ListGroupProjects(groupID, + log.Info().Str("path", path).Int("page", p).Msg("Fetching projects of next page") + projectPtrs, _, err := s.client.ListGroupProjects(path, &gitlab.ListGroupProjectsOptions{ Archived: gitlab.Ptr(false), Simple: gitlab.Ptr(true), @@ -299,13 +240,13 @@ func (s *service) listGroupNextProjects(groupID int, totalPages int) (projects [ }, }) if err != nil { - log.Error().Err(err).Int("groupID", groupID).Int("page", p).Msg("Failed to fetch projects of next page, these projects will be missing.") + log.Error().Err(err).Str("path", path).Int("page", p).Msg("Failed to fetch projects of next page, these projects will be missing.") warnChan <- err } projects, errCount := dereferenceProjectsPointers(projectPtrs) if errCount > 0 { - log.Warn().Int("groupID", groupID).Int("page", p).Int("count", errCount).Msg("Found nil projects, skipping them.") + log.Warn().Str("path", path).Int("page", p).Int("count", errCount).Msg("Found nil projects, skipping them.") } nextProjectsChan <- projects diff --git a/internal/gitlab/gitlab_test.go b/internal/gitlab/gitlab_test.go index 07795c2..37e7e48 100644 --- a/internal/gitlab/gitlab_test.go +++ b/internal/gitlab/gitlab_test.go @@ -1,6 +1,7 @@ package gitlab import ( + "errors" "testing" "github.com/stretchr/testify/assert" @@ -17,12 +18,11 @@ func TestNewService(t *testing.T) { func TestGetProjectListWithTopLevelGroup(t *testing.T) { mockClient := mockClient{} - mockClient.On("ListGroups", mock.Anything, mock.Anything).Return([]*gitlab.Group{{ID: 1, FullPath: "group"}}, nil, nil) - mockClient.On("ListGroupProjects", 1, mock.Anything, mock.Anything).Return([]*gitlab.Project{{Name: "Hello World"}}, &gitlab.Response{}, nil) + mockClient.On("ListGroupProjects", "group", mock.Anything, mock.Anything).Return([]*gitlab.Project{{Name: "Hello World"}}, &gitlab.Response{}, nil) svc := service{&mockClient} - projects, err := svc.GetProjectList([]string{"group"}, []string{}) + projects, err := svc.GetProjectList([]string{"group"}) assert.Nil(t, err) assert.NotEmpty(t, projects) @@ -32,12 +32,11 @@ func TestGetProjectListWithTopLevelGroup(t *testing.T) { func TestGetProjectListWithSubGroup(t *testing.T) { mockClient := mockClient{} - mockClient.On("ListGroups", mock.Anything, mock.Anything).Return([]*gitlab.Group{{ID: 1, FullPath: "group/subgroup"}}, nil, nil) - mockClient.On("ListGroupProjects", 1, mock.Anything, mock.Anything).Return([]*gitlab.Project{{Name: "Hello World"}}, &gitlab.Response{}, nil) + mockClient.On("ListGroupProjects", "group/subgroup", mock.Anything, mock.Anything).Return([]*gitlab.Project{{Name: "Hello World"}}, &gitlab.Response{}, nil) svc := service{&mockClient} - projects, err := svc.GetProjectList([]string{"group/subgroup"}, []string{}) + projects, err := svc.GetProjectList([]string{"group/subgroup"}) assert.Nil(t, err) assert.NotEmpty(t, projects) @@ -47,12 +46,12 @@ func TestGetProjectListWithSubGroup(t *testing.T) { func TestGetProjectListWithProjects(t *testing.T) { mockClient := mockClient{} - mockClient.On("ListGroups", mock.Anything, mock.Anything).Return([]*gitlab.Group{{ID: 1, FullPath: "group/subgroup"}}, nil, nil) - mockClient.On("ListGroupProjects", 1, mock.Anything, mock.Anything).Return([]*gitlab.Project{{Name: "Hello World", PathWithNamespace: "group/subgroup/project"}}, &gitlab.Response{}, nil) + mockClient.On("ListGroupProjects", "group/subgroup/project", mock.Anything, mock.Anything).Return([]*gitlab.Project{}, &gitlab.Response{}, errors.New("no group")) + mockClient.On("GetProject", "group/subgroup/project", mock.Anything, mock.Anything).Return(&gitlab.Project{Name: "Hello World", PathWithNamespace: "group/subgroup/project"}, &gitlab.Response{}, nil) svc := service{&mockClient} - projects, err := svc.GetProjectList([]string{}, []string{"group/subgroup/project"}) + projects, err := svc.GetProjectList([]string{"group/subgroup/project"}) assert.Nil(t, err) assert.NotEmpty(t, projects) @@ -63,18 +62,16 @@ func TestGetProjectListWithProjects(t *testing.T) { func TestGetProjectListWithGroupAndProjects(t *testing.T) { project1 := &gitlab.Project{ID: 1, PathWithNamespace: "group/subgroup/project"} project2 := &gitlab.Project{ID: 2, PathWithNamespace: "group/project"} - group1 := &gitlab.Group{ID: 1, FullPath: "group"} - group2 := &gitlab.Group{ID: 2, FullPath: "group/subgroup"} mockClient := mockClient{} - mockClient.On("ListGroups", &gitlab.ListGroupsOptions{Search: gitlab.Ptr("group")}, mock.Anything).Return([]*gitlab.Group{group1}, nil, nil) - mockClient.On("ListGroups", &gitlab.ListGroupsOptions{Search: gitlab.Ptr("group/subgroup")}, mock.Anything).Return([]*gitlab.Group{group2}, nil, nil) - mockClient.On("ListGroupProjects", 1, mock.Anything, mock.Anything).Return([]*gitlab.Project{project1, project2}, &gitlab.Response{}, nil) - mockClient.On("ListGroupProjects", 2, mock.Anything, mock.Anything).Return([]*gitlab.Project{project1}, &gitlab.Response{}, nil) + mockClient.On("ListGroupProjects", "group", mock.Anything, mock.Anything).Return([]*gitlab.Project{project1, project2}, &gitlab.Response{}, nil) + mockClient.On("ListGroupProjects", "group/subgroup", mock.Anything, mock.Anything).Return([]*gitlab.Project{project1}, &gitlab.Response{}, nil) + mockClient.On("ListGroupProjects", project1.PathWithNamespace, mock.Anything, mock.Anything).Return([]*gitlab.Project{}, &gitlab.Response{}, errors.New("no group")) + mockClient.On("GetProject", project1.PathWithNamespace, mock.Anything, mock.Anything).Return(project1, &gitlab.Response{}, nil) svc := service{&mockClient} - projects, err := svc.GetProjectList([]string{group1.FullPath}, []string{project1.PathWithNamespace}) + projects, err := svc.GetProjectList([]string{"group", "group/subgroup", project1.PathWithNamespace}) assert.Nil(t, err) assert.NotEmpty(t, projects) @@ -85,13 +82,11 @@ func TestGetProjectListWithGroupAndProjects(t *testing.T) { } func TestGetProjectListWithNextPage(t *testing.T) { - group := &gitlab.Group{ID: 1, FullPath: "group/subgroup"} project1 := &gitlab.Project{ID: 1} project2 := &gitlab.Project{ID: 2} mockClient := mockClient{} - mockClient.On("ListGroups", mock.Anything, mock.Anything).Return([]*gitlab.Group{group}, nil, nil) - mockClient.On("ListGroupProjects", mock.Anything, &gitlab.ListGroupProjectsOptions{ + mockClient.On("ListGroupProjects", "group/subgroup", &gitlab.ListGroupProjectsOptions{ Archived: gitlab.Ptr(false), Simple: gitlab.Ptr(true), IncludeSubGroups: gitlab.Ptr(true), @@ -100,7 +95,7 @@ func TestGetProjectListWithNextPage(t *testing.T) { Page: 1, }, }, mock.Anything).Return([]*gitlab.Project{project1}, &gitlab.Response{NextPage: 2, TotalPages: 2}, nil) - mockClient.On("ListGroupProjects", mock.Anything, &gitlab.ListGroupProjectsOptions{ + mockClient.On("ListGroupProjects", "group/subgroup", &gitlab.ListGroupProjectsOptions{ Archived: gitlab.Ptr(false), Simple: gitlab.Ptr(true), IncludeSubGroups: gitlab.Ptr(true), @@ -112,7 +107,7 @@ func TestGetProjectListWithNextPage(t *testing.T) { svc := service{&mockClient} - projects, err := svc.GetProjectList([]string{group.FullPath}, []string{}) + projects, err := svc.GetProjectList([]string{"group/subgroup"}) assert.Nil(t, err) assert.Len(t, projects, 2) @@ -208,17 +203,17 @@ type mockClient struct { mock.Mock } -func (c *mockClient) ListGroups(opt *gitlab.ListGroupsOptions, options ...gitlab.RequestOptionFunc) ([]*gitlab.Group, *gitlab.Response, error) { - args := c.Called(opt, options) +func (c *mockClient) GetProject(pid interface{}, opt *gitlab.GetProjectOptions, options ...gitlab.RequestOptionFunc) (*gitlab.Project, *gitlab.Response, error) { + args := c.Called(pid, opt, options) var r *gitlab.Response if resp := args.Get(1); resp != nil { r = args.Get(1).(*gitlab.Response) } - return args.Get(0).([]*gitlab.Group), r, args.Error(2) + return args.Get(0).(*gitlab.Project), r, args.Error(2) } -func (c *mockClient) ListGroupProjects(groupId int, opt *gitlab.ListGroupProjectsOptions, options ...gitlab.RequestOptionFunc) ([]*gitlab.Project, *gitlab.Response, error) { - args := c.Called(groupId, opt, options) +func (c *mockClient) ListGroupProjects(gid interface{}, opt *gitlab.ListGroupProjectsOptions, options ...gitlab.RequestOptionFunc) ([]*gitlab.Project, *gitlab.Response, error) { + args := c.Called(gid, opt, options) var r *gitlab.Response if resp := args.Get(1); resp != nil { r = args.Get(1).(*gitlab.Response) diff --git a/internal/patrol/patrol.go b/internal/patrol/patrol.go index 3aa1c5c..291759a 100644 --- a/internal/patrol/patrol.go +++ b/internal/patrol/patrol.go @@ -14,6 +14,7 @@ import ( "sort" "sync" + "github.com/elliotchance/pie/v2" "github.com/rs/zerolog/log" gogitlab "github.com/xanzy/go-gitlab" ) @@ -21,10 +22,33 @@ import ( const tempScanDir = "tmp_scans" const projectConfigFileName = "sheriff.toml" +type PlatformType string + +const ( + Gitlab PlatformType = "gitlab" + Github PlatformType = "github" +) + +type ProjectLocation struct { + Type PlatformType + Path string +} + +// PatrolArgs is a struct to store the arguments for the main patrol function. +type PatrolArgs struct { + Locations []ProjectLocation + ReportToEmails []string + ReportToSlackChannel string + ReportToIssue bool + EnableProjectReportTo bool + SilentReport bool + Verbose bool +} + // securityPatroller is the interface of the main security scanner service of this tool. type securityPatroller interface { // Scans the given Gitlab groups and projects, creates and publishes the necessary reports - Patrol(grouiPaths []string, projectPaths []string, gitlabIssue bool, slackChannel string, reportProjectSlack bool, silentReport bool, verbose bool) (warn error, err error) + Patrol(args PatrolArgs) (warn error, err error) } // sheriffService is the implementation of the SecurityPatroller interface. @@ -48,8 +72,8 @@ func New(gitlabService gitlab.IService, slackService slack.IService, gitService } // Patrol scans the given Gitlab groups and projects, creates and publishes the necessary reports. -func (s *sheriffService) Patrol(groupPaths []string, projectPaths []string, gitlabIssue bool, slackChannel string, reportProjectSlack bool, silentReport bool, verbose bool) (warn error, err error) { - scanReports, swarn, err := s.scanAndGetReports(groupPaths, projectPaths) +func (s *sheriffService) Patrol(args PatrolArgs) (warn error, err error) { + scanReports, swarn, err := s.scanAndGetReports(args.Locations) if err != nil { return nil, errors.Join(errors.New("failed to scan projects"), err) } @@ -63,7 +87,7 @@ func (s *sheriffService) Patrol(groupPaths []string, projectPaths []string, gitl return swarn, nil } - if gitlabIssue { + if args.ReportToIssue { log.Info().Msg("Creating issue in affected projects") if gwarn := publish.PublishAsGitlabIssues(scanReports, s.gitlabService); gwarn != nil { gwarn = errors.Join(errors.New("errors occured when creating issues"), gwarn) @@ -73,17 +97,18 @@ func (s *sheriffService) Patrol(groupPaths []string, projectPaths []string, gitl } if s.slackService != nil { - if slackChannel != "" { - log.Info().Str("slackChannel", slackChannel).Msg("Posting report to slack channel") + if args.ReportToSlackChannel != "" { + log.Info().Str("slackChannel", args.ReportToSlackChannel).Msg("Posting report to slack channel") - if err := publish.PublishAsGeneralSlackMessage(slackChannel, scanReports, groupPaths, projectPaths, s.slackService); err != nil { + paths := pie.Map(args.Locations, func(v ProjectLocation) string { return v.Path }) + if err := publish.PublishAsGeneralSlackMessage(args.ReportToSlackChannel, scanReports, paths, s.slackService); err != nil { log.Error().Err(err).Msg("Failed to post slack report") err = errors.Join(errors.New("failed to post slack report"), err) warn = errors.Join(err, warn) } } - if reportProjectSlack { + if args.EnableProjectReportTo { log.Info().Msg("Posting report to project slack channel") if swarn := publish.PublishAsSpecificChannelSlackMessage(scanReports, s.slackService); swarn != nil { swarn = errors.Join(errors.New("errors occured when posting to project slack channel"), swarn) @@ -93,12 +118,12 @@ func (s *sheriffService) Patrol(groupPaths []string, projectPaths []string, gitl } } - publish.PublishToConsole(scanReports, silentReport) + publish.PublishToConsole(scanReports, args.SilentReport) return warn, nil } -func (s *sheriffService) scanAndGetReports(groupPaths []string, projectPaths []string) (reports []scanner.Report, warn error, err error) { +func (s *sheriffService) scanAndGetReports(locations []ProjectLocation) (reports []scanner.Report, warn error, err error) { // Create a temporary directory to store the scans err = os.MkdirAll(tempScanDir, os.ModePerm) if err != nil { @@ -106,9 +131,14 @@ func (s *sheriffService) scanAndGetReports(groupPaths []string, projectPaths []s } defer os.RemoveAll(tempScanDir) log.Info().Str("path", tempScanDir).Msg("Created temporary directory") - log.Info().Strs("groups", groupPaths).Strs("projects", projectPaths).Msg("Getting the list of projects to scan") - projects, pwarn := s.gitlabService.GetProjectList(groupPaths, projectPaths) + gitlabLocs := pie.Map( + pie.Filter(locations, func(v ProjectLocation) bool { return v.Type == Gitlab }), + func(v ProjectLocation) string { return v.Path }, + ) + log.Info().Strs("locations", gitlabLocs).Msg("Getting the list of projects to scan") + + projects, pwarn := s.gitlabService.GetProjectList(gitlabLocs) if pwarn != nil { pwarn = errors.Join(errors.New("errors occured when getting project list"), pwarn) warn = errors.Join(pwarn, warn) diff --git a/internal/patrol/patrol_test.go b/internal/patrol/patrol_test.go index b4ffe25..b644b39 100644 --- a/internal/patrol/patrol_test.go +++ b/internal/patrol/patrol_test.go @@ -18,7 +18,7 @@ func TestNewService(t *testing.T) { func TestScanNoProjects(t *testing.T) { mockGitlabService := &mockGitlabService{} - mockGitlabService.On("GetProjectList", []string{"group/to/scan"}, []string{}).Return([]gitlab.Project{}, nil) + mockGitlabService.On("GetProjectList", []string{"group/to/scan"}).Return([]gitlab.Project{}, nil) mockSlackService := &mockSlackService{} @@ -30,7 +30,15 @@ func TestScanNoProjects(t *testing.T) { svc := New(mockGitlabService, mockSlackService, mockGitService, mockOSVService) - warn, err := svc.Patrol([]string{"group/to/scan"}, []string{}, true, "channel", true, false, false) + warn, err := svc.Patrol(PatrolArgs{ + Locations: []ProjectLocation{{Type: Gitlab, Path: "group/to/scan"}}, + ReportToEmails: []string{}, + ReportToSlackChannel: "channel", + ReportToIssue: true, + EnableProjectReportTo: true, + Verbose: true, + SilentReport: false, + }) assert.Nil(t, err) assert.Nil(t, warn) @@ -40,7 +48,7 @@ func TestScanNoProjects(t *testing.T) { func TestScanNonVulnerableProject(t *testing.T) { mockGitlabService := &mockGitlabService{} - mockGitlabService.On("GetProjectList", []string{"group/to/scan"}, []string{}).Return([]gitlab.Project{{Name: "Hello World", HTTPURLToRepo: "https://gitlab.com/group/to/scan.git"}}, nil) + mockGitlabService.On("GetProjectList", []string{"group/to/scan"}).Return([]gitlab.Project{{Name: "Hello World", HTTPURLToRepo: "https://gitlab.com/group/to/scan.git"}}, nil) mockGitlabService.On("CloseVulnerabilityIssue", mock.Anything).Return(nil) mockSlackService := &mockSlackService{} @@ -55,7 +63,15 @@ func TestScanNonVulnerableProject(t *testing.T) { svc := New(mockGitlabService, mockSlackService, mockGitService, mockOSVService) - warn, err := svc.Patrol([]string{"group/to/scan"}, []string{}, true, "channel", true, false, false) + warn, err := svc.Patrol(PatrolArgs{ + Locations: []ProjectLocation{{Type: Gitlab, Path: "group/to/scan"}}, + ReportToEmails: []string{}, + ReportToSlackChannel: "channel", + ReportToIssue: true, + EnableProjectReportTo: true, + Verbose: true, + SilentReport: false, + }) assert.Nil(t, err) assert.Nil(t, warn) @@ -65,7 +81,7 @@ func TestScanNonVulnerableProject(t *testing.T) { func TestScanVulnerableProject(t *testing.T) { mockGitlabService := &mockGitlabService{} - mockGitlabService.On("GetProjectList", []string{"group/to/scan"}, []string{}).Return([]gitlab.Project{{Name: "Hello World", HTTPURLToRepo: "https://gitlab.com/group/to/scan.git"}}, nil) + mockGitlabService.On("GetProjectList", []string{"group/to/scan"}).Return([]gitlab.Project{{Name: "Hello World", HTTPURLToRepo: "https://gitlab.com/group/to/scan.git"}}, nil) mockGitlabService.On("OpenVulnerabilityIssue", mock.Anything, mock.Anything).Return(&gitlab.Issue{}, nil) mockSlackService := &mockSlackService{} @@ -88,7 +104,15 @@ func TestScanVulnerableProject(t *testing.T) { svc := New(mockGitlabService, mockSlackService, mockGitService, mockOSVService) - warn, err := svc.Patrol([]string{"group/to/scan"}, []string{}, true, "channel", true, false, false) + warn, err := svc.Patrol(PatrolArgs{ + Locations: []ProjectLocation{{Type: Gitlab, Path: "group/to/scan"}}, + ReportToEmails: []string{}, + ReportToSlackChannel: "channel", + ReportToIssue: true, + EnableProjectReportTo: true, + Verbose: true, + SilentReport: false, + }) assert.Nil(t, err) assert.Nil(t, warn) @@ -134,8 +158,8 @@ type mockGitlabService struct { mock.Mock } -func (c *mockGitlabService) GetProjectList(groupPaths []string, projectPaths []string) ([]gitlab.Project, error) { - args := c.Called(groupPaths, projectPaths) +func (c *mockGitlabService) GetProjectList(paths []string) ([]gitlab.Project, error) { + args := c.Called(paths) return args.Get(0).([]gitlab.Project), args.Error(1) } diff --git a/internal/publish/to_gitlab_test.go b/internal/publish/to_gitlab_test.go index ebafe09..d1b86ca 100644 --- a/internal/publish/to_gitlab_test.go +++ b/internal/publish/to_gitlab_test.go @@ -206,8 +206,8 @@ type mockGitlabService struct { mock.Mock } -func (c *mockGitlabService) GetProjectList(groupPaths []string, projectPaths []string) ([]gitlab.Project, error) { - args := c.Called(groupPaths, projectPaths) +func (c *mockGitlabService) GetProjectList(paths []string) ([]gitlab.Project, error) { + args := c.Called(paths) return args.Get(0).([]gitlab.Project), args.Error(1) } diff --git a/internal/publish/to_slack.go b/internal/publish/to_slack.go index 82f6eb0..a04275b 100644 --- a/internal/publish/to_slack.go +++ b/internal/publish/to_slack.go @@ -16,10 +16,10 @@ import ( ) // PublishAsGeneralSlackMessage publishes a report of the vulnerabilities scanned to a slack channel -func PublishAsGeneralSlackMessage(channelName string, reports []scanner.Report, groups []string, projects []string, s slack.IService) (err error) { +func PublishAsGeneralSlackMessage(channelName string, reports []scanner.Report, paths []string, s slack.IService) (err error) { vulnerableReportsBySeverityKind := groupVulnReportsByMaxSeverityKind(reports) - summary := formatSummary(vulnerableReportsBySeverityKind, len(reports), groups, projects) + summary := formatSummary(vulnerableReportsBySeverityKind, len(reports), paths) ts, err := s.PostMessage(channelName, summary...) if err != nil { @@ -132,7 +132,7 @@ func formatSubtitleList(entity string, list []string) *goslack.ContextBlock { } // formatSummary creates a message block with a summary of the reports -func formatSummary(reportsBySeverityKind map[scanner.SeverityScoreKind][]scanner.Report, totalReports int, groups []string, projects []string) []goslack.MsgOption { +func formatSummary(reportsBySeverityKind map[scanner.SeverityScoreKind][]scanner.Report, totalReports int, paths []string) []goslack.MsgOption { title := goslack.NewHeaderBlock( goslack.NewTextBlockObject( "plain_text", @@ -140,8 +140,7 @@ func formatSummary(reportsBySeverityKind map[scanner.SeverityScoreKind][]scanner true, false, ), ) - subtitleGroups := formatSubtitleList("groups", groups) - subtitleProjects := formatSubtitleList("specific projects", projects) + subtitleGroups := formatSubtitleList("urls", paths) subtitleCount := goslack.NewContextBlock("subtitleCount", goslack.NewTextBlockObject("mrkdwn", fmt.Sprintf("Total projects scanned: %v", totalReports), false, false)) counts := pie.Map(severityScoreOrder, func(kind scanner.SeverityScoreKind) *goslack.TextBlockObject { @@ -161,7 +160,6 @@ func formatSummary(reportsBySeverityKind map[scanner.SeverityScoreKind][]scanner blocks := []goslack.Block{ title, subtitleGroups, - subtitleProjects, subtitleCount, countsTitle, countsBlock, diff --git a/internal/publish/to_slack_test.go b/internal/publish/to_slack_test.go index 73eec20..65d4523 100644 --- a/internal/publish/to_slack_test.go +++ b/internal/publish/to_slack_test.go @@ -24,7 +24,7 @@ func TestPublishAsGeneralSlackMessage(t *testing.T) { }, } - err := PublishAsGeneralSlackMessage("channel", report, []string{"path/to/group"}, []string{"path/to/project"}, mockSlackService) + err := PublishAsGeneralSlackMessage("channel", report, []string{"path/to/group", "path/to/project"}, mockSlackService) assert.Nil(t, err) mockSlackService.AssertExpectations(t) @@ -70,7 +70,7 @@ func TestFormatSummary(t *testing.T) { }, } - msgOpts := formatSummary(groupVulnReportsByMaxSeverityKind(report), len(report), []string{"path/to/group"}, []string{"path/to/project"}) + msgOpts := formatSummary(groupVulnReportsByMaxSeverityKind(report), len(report), []string{"path/to/group", "path/to/project"}) assert.NotNil(t, msgOpts) assert.Len(t, msgOpts, 1) diff --git a/internal/slack/slack.go b/internal/slack/slack.go index 65ae300..62bd79c 100644 --- a/internal/slack/slack.go +++ b/internal/slack/slack.go @@ -14,18 +14,17 @@ type IService interface { } type service struct { - client iclient - isPublicChannelsEnabled bool + client iclient } // New creates a new Slack service -func New(token string, isPublicChannelsEnabled bool, debug bool) (IService, error) { +func New(token string, debug bool) (IService, error) { slackClient := slack.New(token, slack.OptionDebug(debug)) if slackClient == nil { return nil, errors.New("failed to create slack client") } - s := service{&client{client: slackClient}, isPublicChannelsEnabled} + s := service{&client{client: slackClient}} return &s, nil } @@ -52,10 +51,7 @@ func (s *service) PostMessage(channelName string, options ...slack.MsgOption) (t func (s *service) findSlackChannel(channelName string) (channel *slack.Channel, err error) { var nextCursor string var channels []slack.Channel - var channelTypes = []string{"private_channel"} - if s.isPublicChannelsEnabled { - channelTypes = append(channelTypes, "public_channel") - } + var channelTypes = []string{"private_channel", "public_channel"} for { if channels, nextCursor, err = s.client.GetConversations(&slack.GetConversationsParameters{ diff --git a/internal/slack/slack_test.go b/internal/slack/slack_test.go index dd97fa5..c4a1cf5 100644 --- a/internal/slack/slack_test.go +++ b/internal/slack/slack_test.go @@ -9,7 +9,7 @@ import ( ) func TestNewService(t *testing.T) { - s, err := New("token", false, false) + s, err := New("token", false) assert.Nil(t, err) assert.NotNil(t, s) @@ -35,7 +35,7 @@ func TestPostMessage(t *testing.T) { ) mockClient.On("PostMessage", channelID, mock.Anything).Return("", "", nil) - svc := service{&mockClient, false} + svc := service{&mockClient} _, err := svc.PostMessage(channelName, message) @@ -47,43 +47,33 @@ func TestFindSlackChannel(t *testing.T) { channelID := "1234" channelName := "random channel" - testCases := []struct { - isPublicChannelsEnabled bool - want []string - }{ - {true, []string{"private_channel", "public_channel"}}, - {false, []string{"private_channel"}}, - } - - for _, tc := range testCases { - mockClient := mockClient{} - mockClient.On("GetConversations", &slack.GetConversationsParameters{ - ExcludeArchived: true, - Cursor: "", - Types: tc.want, - Limit: 1000, - }).Return( - []slack.Channel{ - { - GroupConversation: slack.GroupConversation{ - Conversation: slack.Conversation{ID: channelID}, - Name: channelName, - }, + mockClient := mockClient{} + mockClient.On("GetConversations", &slack.GetConversationsParameters{ + ExcludeArchived: true, + Cursor: "", + Types: []string{"private_channel", "public_channel"}, + Limit: 1000, + }).Return( + []slack.Channel{ + { + GroupConversation: slack.GroupConversation{ + Conversation: slack.Conversation{ID: channelID}, + Name: channelName, }, }, - "", - nil, - ) + }, + "", + nil, + ) - svc := service{&mockClient, tc.isPublicChannelsEnabled} + svc := service{&mockClient} - channel, err := svc.findSlackChannel(channelName) + channel, err := svc.findSlackChannel(channelName) - assert.Nil(t, err) - assert.NotNil(t, channel) - assert.Equal(t, channelID, channel.ID) + assert.Nil(t, err) + assert.NotNil(t, channel) + assert.Equal(t, channelID, channel.ID) - } } type mockClient struct {