Skip to content
This repository was archived by the owner on Jul 15, 2024. It is now read-only.

Commit ea586e4

Browse files
author
Fardin Khanjani
committed
feat: This commit adds pull request support to SCM generator so the generator
can create ArgoCD apps for PRs as well. Fixes #466 Signed-off-by: Fardin Khanjani <[email protected]>
1 parent e900eab commit ea586e4

File tree

9 files changed

+241
-31
lines changed

9 files changed

+241
-31
lines changed

api/v1alpha1/applicationset_types.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,8 @@ type SCMProviderGeneratorGithub struct {
314314
TokenRef *SecretRef `json:"tokenRef,omitempty"`
315315
// Scan all branches instead of just the default branch.
316316
AllBranches bool `json:"allBranches,omitempty"`
317+
// Scan all pull requests
318+
AllPullRequests bool `json:"allPullRequests,omitempty"`
317319
}
318320

319321
// SCMProviderGeneratorGitlab defines a connection info specific to Gitlab.
@@ -328,6 +330,8 @@ type SCMProviderGeneratorGitlab struct {
328330
TokenRef *SecretRef `json:"tokenRef,omitempty"`
329331
// Scan all branches instead of just the default branch.
330332
AllBranches bool `json:"allBranches,omitempty"`
333+
// Scan all pull requests
334+
AllPullRequests bool `json:"allPullRequests,omitempty"`
331335
}
332336

333337
// SCMProviderGeneratorFilter is a single repository filter.
@@ -342,6 +346,10 @@ type SCMProviderGeneratorFilter struct {
342346
LabelMatch *string `json:"labelMatch,omitempty"`
343347
// A regex which must match the branch name.
344348
BranchMatch *string `json:"branchMatch,omitempty"`
349+
// A regex which must match the pull request tile.
350+
PullRequestTitleMatch *string `json:"pullRequestTitleMatch,omitempty"`
351+
// A regex which must match at least one pull request label.
352+
PullRequestLabelMatch *string `json:"pullRequestLabelMatch,omitempty"`
345353
}
346354

347355
// PullRequestGenerator defines a generator that scrapes a PullRequest API to find candidate pull requests.

pkg/generators/scm_provider.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ func (g *SCMProviderGenerator) GenerateParams(appSetGenerator *argoprojiov1alpha
6464
if err != nil {
6565
return nil, fmt.Errorf("error fetching Github token: %v", err)
6666
}
67-
provider, err = scm_provider.NewGithubProvider(ctx, providerConfig.Github.Organization, token, providerConfig.Github.API, providerConfig.Github.AllBranches)
67+
provider, err = scm_provider.NewGithubProvider(ctx, providerConfig.Github.Organization, token, providerConfig.Github.API, providerConfig.Github.AllBranches, providerConfig.Github.AllPullRequests)
6868
if err != nil {
6969
return nil, fmt.Errorf("error initializing Github service: %v", err)
7070
}
@@ -73,7 +73,7 @@ func (g *SCMProviderGenerator) GenerateParams(appSetGenerator *argoprojiov1alpha
7373
if err != nil {
7474
return nil, fmt.Errorf("error fetching Gitlab token: %v", err)
7575
}
76-
provider, err = scm_provider.NewGitlabProvider(ctx, providerConfig.Gitlab.Group, token, providerConfig.Gitlab.API, providerConfig.Gitlab.AllBranches, providerConfig.Gitlab.IncludeSubgroups)
76+
provider, err = scm_provider.NewGitlabProvider(ctx, providerConfig.Gitlab.Group, token, providerConfig.Gitlab.API, providerConfig.Gitlab.AllBranches, providerConfig.Gitlab.IncludeSubgroups, providerConfig.Gitlab.AllPullRequests)
7777
if err != nil {
7878
return nil, fmt.Errorf("error initializing Gitlab service: %v", err)
7979
}

pkg/services/scm_provider/github.go

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,15 @@ import (
1111
)
1212

1313
type GithubProvider struct {
14-
client *github.Client
15-
organization string
16-
allBranches bool
14+
client *github.Client
15+
organization string
16+
allBranches bool
17+
allPullRequests bool
1718
}
1819

1920
var _ SCMProviderService = &GithubProvider{}
2021

21-
func NewGithubProvider(ctx context.Context, organization string, token string, url string, allBranches bool) (*GithubProvider, error) {
22+
func NewGithubProvider(ctx context.Context, organization string, token string, url string, allBranches bool, allPullRequests bool) (*GithubProvider, error) {
2223
var ts oauth2.TokenSource
2324
// Undocumented environment variable to set a default token, to be used in testing to dodge anonymous rate limits.
2425
if token == "" {
@@ -40,7 +41,7 @@ func NewGithubProvider(ctx context.Context, organization string, token string, u
4041
return nil, err
4142
}
4243
}
43-
return &GithubProvider{client: client, organization: organization, allBranches: allBranches}, nil
44+
return &GithubProvider{client: client, organization: organization, allBranches: allBranches, allPullRequests: allPullRequests}, nil
4445
}
4546

4647
func (g *GithubProvider) GetBranches(ctx context.Context, repo *Repository) ([]*Repository, error) {
@@ -64,6 +65,32 @@ func (g *GithubProvider) GetBranches(ctx context.Context, repo *Repository) ([]*
6465
return repos, nil
6566
}
6667

68+
func (g *GithubProvider) GetPullRequests(ctx context.Context, repo *Repository) ([]*Repository, error) {
69+
repos := []*Repository{}
70+
pullRequests, err := g.listPullRequests(ctx, repo)
71+
if err != nil {
72+
return nil, fmt.Errorf("error listing pull requests for %s/%s: %v", repo.Organization, repo.Repository, err)
73+
}
74+
75+
// go-github's PullRequest type does not have a GetLabel() function.
76+
var labels []string
77+
for _, pullRequest := range pullRequests {
78+
for _, label := range pullRequest.Labels {
79+
labels = append(labels, label.GetName())
80+
}
81+
repos = append(repos, &Repository{
82+
Organization: repo.Organization,
83+
Repository: repo.Repository,
84+
URL: repo.URL,
85+
Branch: pullRequest.GetTitle(),
86+
SHA: pullRequest.GetHead().GetSHA(),
87+
Labels: labels,
88+
RepositoryId: repo.RepositoryId,
89+
})
90+
}
91+
return repos, nil
92+
}
93+
6794
func (g *GithubProvider) ListRepos(ctx context.Context, cloneProtocol string) ([]*Repository, error) {
6895
opt := &github.RepositoryListByOrgOptions{
6996
ListOptions: github.ListOptions{PerPage: 100},
@@ -104,7 +131,7 @@ func (g *GithubProvider) ListRepos(ctx context.Context, cloneProtocol string) ([
104131

105132
func (g *GithubProvider) RepoHasPath(ctx context.Context, repo *Repository, path string) (bool, error) {
106133
_, _, resp, err := g.client.Repositories.GetContents(ctx, repo.Organization, repo.Repository, path, &github.RepositoryContentGetOptions{
107-
Ref: repo.Branch,
134+
Ref: repo.SHA,
108135
})
109136
// 404s are not an error here, just a normal false.
110137
if resp != nil && resp.StatusCode == 404 {
@@ -153,3 +180,33 @@ func (g *GithubProvider) listBranches(ctx context.Context, repo *Repository) ([]
153180
}
154181
return branches, nil
155182
}
183+
184+
func (g *GithubProvider) listPullRequests(ctx context.Context, repo *Repository) ([]github.PullRequest, error) {
185+
186+
if !g.allPullRequests {
187+
return nil, nil
188+
}
189+
190+
opt := &github.PullRequestListOptions{
191+
ListOptions: github.ListOptions{PerPage: 100},
192+
}
193+
194+
githubPullRequests := []github.PullRequest{}
195+
196+
for {
197+
allPullRequests, resp, err := g.client.PullRequests.List(ctx, repo.Organization, repo.Repository, opt)
198+
if err != nil {
199+
return nil, err
200+
}
201+
202+
for _, pr := range allPullRequests {
203+
githubPullRequests = append(githubPullRequests, *pr)
204+
}
205+
206+
if resp.NextPage == 0 {
207+
break
208+
}
209+
opt.Page = resp.NextPage
210+
}
211+
return githubPullRequests, nil
212+
}

pkg/services/scm_provider/github_test.go

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,10 @@ func checkRateLimit(t *testing.T, err error) {
3636

3737
func TestGithubListRepos(t *testing.T) {
3838
cases := []struct {
39-
name, proto, url string
40-
hasError, allBranches bool
41-
branches []string
42-
filters []v1alpha1.SCMProviderGeneratorFilter
39+
name, proto, url string
40+
hasError, allBranches, allPullRequests bool
41+
branches []string
42+
filters []v1alpha1.SCMProviderGeneratorFilter
4343
}{
4444
{
4545
name: "blank protocol",
@@ -67,11 +67,17 @@ func TestGithubListRepos(t *testing.T) {
6767
url: "[email protected]:argoproj/applicationset.git",
6868
branches: []string{"master", "release-0.1.0"},
6969
},
70+
{
71+
name: "all pull requests",
72+
allPullRequests: true,
73+
url: "[email protected]:argoproj/applicationset.git",
74+
branches: []string{"pr-1", "pr-2"},
75+
},
7076
}
7177

7278
for _, c := range cases {
7379
t.Run(c.name, func(t *testing.T) {
74-
provider, _ := NewGithubProvider(context.Background(), "argoproj", "", "", c.allBranches)
80+
provider, _ := NewGithubProvider(context.Background(), "argoproj", "", "", c.allBranches, c.allPullRequests)
7581
rawRepos, err := ListRepos(context.Background(), provider, c.filters, c.proto)
7682
if c.hasError {
7783
assert.Error(t, err)
@@ -98,7 +104,7 @@ func TestGithubListRepos(t *testing.T) {
98104
}
99105

100106
func TestGithubHasPath(t *testing.T) {
101-
host, _ := NewGithubProvider(context.Background(), "argoproj", "", "", false)
107+
host, _ := NewGithubProvider(context.Background(), "argoproj", "", "", false, false)
102108
repo := &Repository{
103109
Organization: "argoproj",
104110
Repository: "applicationset",

pkg/services/scm_provider/gitlab.go

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@ type GitlabProvider struct {
1313
organization string
1414
allBranches bool
1515
includeSubgroups bool
16+
allPullRequests bool
1617
}
1718

1819
var _ SCMProviderService = &GitlabProvider{}
1920

20-
func NewGitlabProvider(ctx context.Context, organization string, token string, url string, allBranches, includeSubgroups bool) (*GitlabProvider, error) {
21+
func NewGitlabProvider(ctx context.Context, organization string, token string, url string, allBranches, includeSubgroups, allPullRequests bool) (*GitlabProvider, error) {
2122
// Undocumented environment variable to set a default token, to be used in testing to dodge anonymous rate limits.
2223
if token == "" {
2324
token = os.Getenv("GITLAB_TOKEN")
@@ -36,7 +37,7 @@ func NewGitlabProvider(ctx context.Context, organization string, token string, u
3637
return nil, err
3738
}
3839
}
39-
return &GitlabProvider{client: client, organization: organization, allBranches: allBranches, includeSubgroups: includeSubgroups}, nil
40+
return &GitlabProvider{client: client, organization: organization, allBranches: allBranches, includeSubgroups: includeSubgroups, allPullRequests: allPullRequests}, nil
4041
}
4142

4243
func (g *GitlabProvider) GetBranches(ctx context.Context, repo *Repository) ([]*Repository, error) {
@@ -60,6 +61,28 @@ func (g *GitlabProvider) GetBranches(ctx context.Context, repo *Repository) ([]*
6061
return repos, nil
6162
}
6263

64+
func (g *GitlabProvider) GetPullRequests(ctx context.Context, repo *Repository) ([]*Repository, error) {
65+
repos := []*Repository{}
66+
67+
pullRequests, err := g.listPullRequests(ctx, repo)
68+
if err != nil {
69+
return nil, err
70+
}
71+
72+
for _, pullRequest := range pullRequests {
73+
repos = append(repos, &Repository{
74+
Organization: repo.Organization,
75+
Repository: repo.Repository,
76+
URL: repo.URL,
77+
Branch: pullRequest.Title,
78+
SHA: pullRequest.SHA,
79+
Labels: pullRequest.Labels,
80+
RepositoryId: repo.RepositoryId,
81+
})
82+
}
83+
return repos, nil
84+
}
85+
6386
func (g *GitlabProvider) ListRepos(ctx context.Context, cloneProtocol string) ([]*Repository, error) {
6487
opt := &gitlab.ListGroupProjectsOptions{
6588
ListOptions: gitlab.ListOptions{PerPage: 100},
@@ -149,3 +172,26 @@ func (g *GitlabProvider) listBranches(_ context.Context, repo *Repository) ([]gi
149172
}
150173
return branches, nil
151174
}
175+
176+
func (g *GitlabProvider) listPullRequests(_ context.Context, repo *Repository) ([]gitlab.MergeRequest, error) {
177+
opt := &gitlab.ListProjectMergeRequestsOptions{
178+
ListOptions: gitlab.ListOptions{PerPage: 100},
179+
}
180+
181+
pullRequests := []gitlab.MergeRequest{}
182+
for {
183+
gitlabPullRequests, resp, err := g.client.MergeRequests.ListProjectMergeRequests(repo.RepositoryId, opt)
184+
if err != nil {
185+
return nil, err
186+
}
187+
for _, gitlabPullRequest := range gitlabPullRequests {
188+
pullRequests = append(pullRequests, *gitlabPullRequest)
189+
}
190+
191+
if resp.NextPage == 0 {
192+
break
193+
}
194+
opt.Page = resp.NextPage
195+
}
196+
return pullRequests, nil
197+
}

pkg/services/scm_provider/gitlab_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@ import (
1010

1111
func TestGitlabListRepos(t *testing.T) {
1212
cases := []struct {
13-
name, proto, url string
14-
hasError, allBranches, includeSubgroups bool
15-
branches []string
16-
filters []v1alpha1.SCMProviderGeneratorFilter
13+
name, proto, url string
14+
hasError, allBranches, includeSubgroups, allPullRequests bool
15+
branches []string
16+
filters []v1alpha1.SCMProviderGeneratorFilter
1717
}{
1818
{
1919
name: "blank protocol",
@@ -45,7 +45,7 @@ func TestGitlabListRepos(t *testing.T) {
4545

4646
for _, c := range cases {
4747
t.Run(c.name, func(t *testing.T) {
48-
provider, _ := NewGitlabProvider(context.Background(), "test-argocd-proton", "", "", c.allBranches, c.includeSubgroups)
48+
provider, _ := NewGitlabProvider(context.Background(), "test-argocd-proton", "", "", c.allBranches, c.includeSubgroups, c.allPullRequests)
4949
rawRepos, err := ListRepos(context.Background(), provider, c.filters, c.proto)
5050
if c.hasError {
5151
assert.NotNil(t, err)
@@ -72,7 +72,7 @@ func TestGitlabListRepos(t *testing.T) {
7272
}
7373

7474
func TestGitlabHasPath(t *testing.T) {
75-
host, _ := NewGitlabProvider(context.Background(), "test-argocd-proton", "", "", false, true)
75+
host, _ := NewGitlabProvider(context.Background(), "test-argocd-proton", "", "", false, true, false)
7676
repo := &Repository{
7777
Organization: "test-argocd-proton",
7878
Repository: "argocd",

pkg/services/scm_provider/mock.go

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,25 @@ func (m *MockProvider) GetBranches(_ context.Context, repo *Repository) ([]*Repo
4444
branchRepos = append(branchRepos, candidateRepo)
4545
}
4646
}
47-
4847
}
4948
return branchRepos, nil
5049
}
50+
51+
func (m *MockProvider) GetPullRequests(_ context.Context, repo *Repository) ([]*Repository, error) {
52+
pullRequestRepos := []*Repository{}
53+
for _, candidateRepo := range m.Repos {
54+
if candidateRepo.Repository == repo.Repository {
55+
found := false
56+
for _, alreadySetRepo := range pullRequestRepos {
57+
if alreadySetRepo.Branch == candidateRepo.Branch {
58+
found = true
59+
break
60+
}
61+
}
62+
if !found {
63+
pullRequestRepos = append(pullRequestRepos, candidateRepo)
64+
}
65+
}
66+
}
67+
return pullRequestRepos, nil
68+
}

pkg/services/scm_provider/types.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,18 @@ type SCMProviderService interface {
2020
ListRepos(context.Context, string) ([]*Repository, error)
2121
RepoHasPath(context.Context, *Repository, string) (bool, error)
2222
GetBranches(context.Context, *Repository) ([]*Repository, error)
23+
GetPullRequests(context.Context, *Repository) ([]*Repository, error)
2324
}
2425

2526
// A compiled version of SCMProviderGeneratorFilter for performance.
2627
type Filter struct {
27-
RepositoryMatch *regexp.Regexp
28-
PathsExist []string
29-
LabelMatch *regexp.Regexp
30-
BranchMatch *regexp.Regexp
31-
FilterType FilterType
28+
RepositoryMatch *regexp.Regexp
29+
PathsExist []string
30+
LabelMatch *regexp.Regexp
31+
BranchMatch *regexp.Regexp
32+
PullRequestTitleMatch *regexp.Regexp
33+
PullRequestLabelMatch *regexp.Regexp
34+
FilterType FilterType
3235
}
3336

3437
// A convenience type for indicating where to apply a filter
@@ -39,4 +42,5 @@ const (
3942
FilterTypeUndefined FilterType = iota
4043
FilterTypeBranch
4144
FilterTypeRepo
45+
FilterTypePullRequest
4246
)

0 commit comments

Comments
 (0)