From f2d9f711921381d317409c6f7e58cacde4a0b506 Mon Sep 17 00:00:00 2001 From: Richard Gomez Date: Thu, 8 Aug 2024 10:56:58 -0400 Subject: [PATCH] fix(github): use scm-base-url for clients --- analyze/analyze.go | 6 +++--- providers/github/client.go | 40 +++++++++++++++++++++++++++----------- 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/analyze/analyze.go b/analyze/analyze.go index d4da3d06..0f1ca387 100644 --- a/analyze/analyze.go +++ b/analyze/analyze.go @@ -80,7 +80,7 @@ func (a *Analyzer) AnalyzeOrg(ctx context.Context, org string, numberOfGoroutine log.Debug().Err(err).Msgf("Failed to get provider version for %s", provider) } - log.Debug().Msgf("Provider: %s, Version: %s", provider, providerVersion) + log.Debug().Msgf("Provider: %s, Version: %s, BaseURL: %s", provider, providerVersion, a.ScmClient.GetProviderBaseURL()) log.Debug().Msgf("Fetching list of repositories for organization: %s on %s", org, provider) orgReposBatches := a.ScmClient.GetOrgRepos(ctx, org) @@ -177,7 +177,7 @@ func (a *Analyzer) AnalyzeRepo(ctx context.Context, repoString string, ref strin log.Debug().Err(err).Msgf("Failed to get provider version for %s", provider) } - log.Debug().Msgf("Provider: %s, Version: %s", provider, providerVersion) + log.Debug().Msgf("Provider: %s, Version: %s, BaseURL: %s", provider, providerVersion, a.ScmClient.GetProviderBaseURL()) pkgsupplyClient := pkgsupply.NewStaticClient() inventory := scanner.NewInventory(a.Opa, pkgsupplyClient, provider, providerVersion) @@ -225,7 +225,7 @@ func (a *Analyzer) AnalyzeLocalRepo(ctx context.Context, repoPath string) error log.Debug().Err(err).Msgf("Failed to get provider version for %s", provider) } - log.Debug().Msgf("Provider: %s, Version: %s", provider, providerVersion) + log.Debug().Msgf("Provider: %s, Version: %s, BaseURL: %s", provider, providerVersion, a.ScmClient.GetProviderBaseURL()) pkgsupplyClient := pkgsupply.NewStaticClient() inventory := scanner.NewInventory(a.Opa, pkgsupplyClient, provider, providerVersion) diff --git a/providers/github/client.go b/providers/github/client.go index 29b5df2a..f597fb09 100644 --- a/providers/github/client.go +++ b/providers/github/client.go @@ -19,17 +19,19 @@ import ( ) const GitHub string = "github" +const defaultDomain string = "github.com" func NewGithubSCMClient(ctx context.Context, baseURL string, token string) (*ScmClient, error) { - client, err := NewClient(ctx, token) + domain := defaultDomain + if baseURL != "" { + domain = baseURL + } + + client, err := NewClient(ctx, token, domain) if err != nil { return nil, err } - domain := "github.com" - if baseURL != "" { - domain = baseURL - } return &ScmClient{ client: client, baseURL: domain, @@ -131,19 +133,35 @@ type Client struct { Token string } -func NewClient(ctx context.Context, token string) (*Client, error) { +func NewClient(ctx context.Context, token string, domain string) (*Client, error) { rateLimiter, err := github_ratelimit.NewRateLimitWaiterClient(nil) if err != nil { return nil, err } - restClient := github.NewClient(rateLimiter).WithAuthToken(token) - src := oauth2.StaticTokenSource( - &oauth2.Token{AccessToken: token}, + var ( + // REST client + restClient = github.NewClient(rateLimiter).WithAuthToken(token) + // GraphQL client + src = oauth2.StaticTokenSource( + &oauth2.Token{AccessToken: token}, + ) + httpClient = oauth2.NewClient(ctx, src) + graphQLClient *githubv4.Client ) - httpClient := oauth2.NewClient(ctx, src) - graphQLClient := githubv4.NewClient(httpClient) + if domain == defaultDomain { + graphQLClient = githubv4.NewClient(httpClient) + } else { + baseURL := fmt.Sprintf("https://%s/", domain) + restClient, err = restClient.WithEnterpriseURLs(baseURL, baseURL) + if err != nil { + return nil, err + } + + graphQLClient = githubv4.NewEnterpriseClient(baseURL+"api/graphql", httpClient) + } + return &Client{ restClient: restClient, graphQLClient: graphQLClient,