diff --git a/pkg/fanal/image/registry/azure/azure.go b/pkg/fanal/image/registry/azure/azure.go index 491fce4e27c3..c57de39f21e5 100644 --- a/pkg/fanal/image/registry/azure/azure.go +++ b/pkg/fanal/image/registry/azure/azure.go @@ -10,30 +10,33 @@ import ( "github.com/Azure/azure-sdk-for-go/profiles/preview/preview/containerregistry/runtime/containerregistry" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/aquasecurity/trivy/pkg/fanal/image/registry/intf" "golang.org/x/xerrors" "github.com/aquasecurity/trivy/pkg/fanal/types" ) -type Registry struct { +type RegistryClient struct { domain string } +type Registry struct { +} + const ( azureURL = "azurecr.io" scope = "https://management.azure.com/.default" scheme = "https" ) -func (r *Registry) CheckOptions(domain string, _ types.RegistryOptions) error { +func (r *Registry) CheckOptions(domain string, _ types.RegistryOptions) (intf.RegistryClient, error) { if !strings.HasSuffix(domain, azureURL) { - return xerrors.Errorf("Azure registry: %w", types.InvalidURLPattern) + return nil, xerrors.Errorf("Azure registry: %w", types.InvalidURLPattern) } - r.domain = domain - return nil + return &RegistryClient{domain: domain}, nil } -func (r *Registry) GetCredential(ctx context.Context) (string, string, error) { +func (r *RegistryClient) GetCredential(ctx context.Context) (string, string, error) { cred, err := azidentity.NewDefaultAzureCredential(nil) if err != nil { return "", "", xerrors.Errorf("unable to generate acr credential error: %w", err) diff --git a/pkg/fanal/image/registry/azure/azure_test.go b/pkg/fanal/image/registry/azure/azure_test.go index ae823b82a65a..bc16d8b243b1 100644 --- a/pkg/fanal/image/registry/azure/azure_test.go +++ b/pkg/fanal/image/registry/azure/azure_test.go @@ -28,7 +28,7 @@ func TestRegistry_CheckOptions(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { r := azure.Registry{} - err := r.CheckOptions(tt.domain, types.RegistryOptions{}) + _, err := r.CheckOptions(tt.domain, types.RegistryOptions{}) if tt.wantErr != "" { assert.EqualError(t, err, tt.wantErr) } else { diff --git a/pkg/fanal/image/registry/ecr/ecr.go b/pkg/fanal/image/registry/ecr/ecr.go index e675ed47afaf..bfc49185b8e7 100644 --- a/pkg/fanal/image/registry/ecr/ecr.go +++ b/pkg/fanal/image/registry/ecr/ecr.go @@ -3,8 +3,10 @@ package ecr import ( "context" "encoding/base64" + "regexp" "strings" + "github.com/aquasecurity/trivy/pkg/fanal/image/registry/intf" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" @@ -12,46 +14,71 @@ import ( "golang.org/x/xerrors" "github.com/aquasecurity/trivy/pkg/fanal/types" + "github.com/aquasecurity/trivy/pkg/log" ) -const ecrURL = "amazonaws.com" - type ecrAPI interface { GetAuthorizationToken(ctx context.Context, params *ecr.GetAuthorizationTokenInput, optFns ...func(*ecr.Options)) (*ecr.GetAuthorizationTokenOutput, error) } type ECR struct { +} + +type ECRClient struct { Client ecrAPI } -func getSession(option types.RegistryOptions) (aws.Config, error) { +func getSession(domain, region string, option types.RegistryOptions) (aws.Config, error) { // create custom credential information if option is valid if option.AWSSecretKey != "" && option.AWSAccessKey != "" && option.AWSRegion != "" { + if region != option.AWSRegion { + log.Logger.Warnf("The region from AWS_REGION (%s) is being overridden. The region from domain (%s) was used.", option.AWSRegion, domain) + } return config.LoadDefaultConfig( context.TODO(), - config.WithRegion(option.AWSRegion), + config.WithRegion(region), config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(option.AWSAccessKey, option.AWSSecretKey, option.AWSSessionToken)), ) } - return config.LoadDefaultConfig(context.TODO()) + return config.LoadDefaultConfig(context.TODO(), config.WithRegion(region)) } -func (e *ECR) CheckOptions(domain string, option types.RegistryOptions) error { - if !strings.HasSuffix(domain, ecrURL) { - return xerrors.Errorf("ECR : %w", types.InvalidURLPattern) +func (e *ECR) CheckOptions(domain string, option types.RegistryOptions) (intf.RegistryClient, error) { + region := determineRegion(domain) + if region == "" { + return nil, xerrors.Errorf("ECR : %w", types.InvalidURLPattern) } - cfg, err := getSession(option) + cfg, err := getSession(domain, region, option) if err != nil { - return err + return nil, err } svc := ecr.NewFromConfig(cfg) - e.Client = svc - return nil + return &ECRClient{Client: svc}, nil +} + +// Endpoints take the form +// .dkr.ecr..amazonaws.com +// .dkr.ecr-fips..amazonaws.com +// .dkr.ecr..amazonaws.com.cn +// .dkr.ecr..sc2s.sgov.gov +// .dkr.ecr..c2s.ic.gov +// see +// - https://docs.aws.amazon.com/general/latest/gr/ecr.html +// - https://docs.amazonaws.cn/en_us/aws/latest/userguide/endpoints-arns.html +// - https://github.com/boto/botocore/blob/1.34.51/botocore/data/endpoints.json +var ecrEndpointMatch = regexp.MustCompile(`^[^.]+\.dkr\.ecr(?:-fips)?\.([^.]+)\.(?:amazonaws\.com(?:\.cn)?|sc2s\.sgov\.gov|c2s\.ic\.gov)$`) + +func determineRegion(domain string) string { + matches := ecrEndpointMatch.FindStringSubmatch(domain) + if matches != nil { + return matches[1] + } + return "" } -func (e *ECR) GetCredential(ctx context.Context) (username, password string, err error) { +func (e *ECRClient) GetCredential(ctx context.Context) (username, password string, err error) { input := &ecr.GetAuthorizationTokenInput{} result, err := e.Client.GetAuthorizationToken(ctx, input) if err != nil { diff --git a/pkg/fanal/image/registry/ecr/ecr_test.go b/pkg/fanal/image/registry/ecr/ecr_test.go index 63ae1858114c..3544c29f0baf 100644 --- a/pkg/fanal/image/registry/ecr/ecr_test.go +++ b/pkg/fanal/image/registry/ecr/ecr_test.go @@ -8,33 +8,91 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/ecr" awstypes "github.com/aws/aws-sdk-go-v2/service/ecr/types" + "github.com/stretchr/testify/require" "github.com/aquasecurity/trivy/pkg/fanal/types" ) +type testECRClient interface { + Options() ecr.Options +} + func TestCheckOptions(t *testing.T) { var tests = map[string]struct { - domain string - wantErr error + domain string + expectedRegion string + wantErr error }{ "InvalidURL": { domain: "alpine:3.9", wantErr: types.InvalidURLPattern, }, "NoOption": { - domain: "xxx.ecr.ap-northeast-1.amazonaws.com", + domain: "xxx.dkr.ecr.ap-northeast-1.amazonaws.com", + expectedRegion: "ap-northeast-1", + }, + "region-1": { + domain: "xxx.dkr.ecr.region-1.amazonaws.com", + expectedRegion: "region-1", + }, + "region-2": { + domain: "xxx.dkr.ecr.region-2.amazonaws.com", + expectedRegion: "region-2", + }, + "fips-region-1": { + domain: "xxx.dkr.ecr-fips.fips-region.amazonaws.com", + expectedRegion: "fips-region", + }, + "cn-region-1": { + domain: "xxx.dkr.ecr.region-1.amazonaws.com.cn", + expectedRegion: "region-1", + }, + "cn-region-2": { + domain: "xxx.dkr.ecr.region-2.amazonaws.com.cn", + expectedRegion: "region-2", + }, + "sc2s-region-1": { + domain: "xxx.dkr.ecr.sc2s-region.sc2s.sgov.gov", + expectedRegion: "sc2s-region", + }, + "c2s-region-1": { + domain: "xxx.dkr.ecr.c2s-region.c2s.ic.gov", + expectedRegion: "c2s-region", + }, + "invalid-ecr": { + domain: "xxx.dkrecr.region-1.amazonaws.com", + wantErr: types.InvalidURLPattern, + }, + "invalid-fips": { + domain: "xxx.dkr.ecrfips.fips-region.amazonaws.com", + wantErr: types.InvalidURLPattern, + }, + "invalid-cn": { + domain: "xxx.dkr.ecr.region-2.amazonaws.cn", + wantErr: types.InvalidURLPattern, + }, + "invalid-sc2s": { + domain: "xxx.dkr.ecr.sc2s-region.sc2s.sgov", + wantErr: types.InvalidURLPattern, + }, + "invalid-cs2": { + domain: "xxx.dkr.ecr.c2s-region.c2s.ic", + wantErr: types.InvalidURLPattern, }, } for testname, v := range tests { a := &ECR{} - err := a.CheckOptions(v.domain, types.RegistryOptions{}) + ecrClient, err := a.CheckOptions(v.domain, types.RegistryOptions{}) if err != nil { if !errors.Is(err, v.wantErr) { t.Errorf("[%s]\nexpected error based on %v\nactual : %v", testname, v.wantErr, err) } continue } + + client := (ecrClient.(*ECRClient)).Client.(testECRClient) + require.Equal(t, v.expectedRegion, client.Options().Region) } } @@ -82,7 +140,7 @@ func TestECRGetCredential(t *testing.T) { } for i, c := range cases { - e := ECR{ + e := ECRClient{ Client: mockedECR{Resp: c.Resp}, } username, password, err := e.GetCredential(context.Background()) diff --git a/pkg/fanal/image/registry/google/google.go b/pkg/fanal/image/registry/google/google.go index f4e7a7414260..eea27f39da06 100644 --- a/pkg/fanal/image/registry/google/google.go +++ b/pkg/fanal/image/registry/google/google.go @@ -7,34 +7,38 @@ import ( "github.com/GoogleCloudPlatform/docker-credential-gcr/config" "github.com/GoogleCloudPlatform/docker-credential-gcr/credhelper" "github.com/GoogleCloudPlatform/docker-credential-gcr/store" + "github.com/aquasecurity/trivy/pkg/fanal/image/registry/intf" "golang.org/x/xerrors" "github.com/aquasecurity/trivy/pkg/fanal/types" ) -type Registry struct { +type GoogleRegistryClient struct { Store store.GCRCredStore domain string } +type Registry struct { +} + // Google container registry const gcrURL = "gcr.io" // Google artifact registry const garURL = "docker.pkg.dev" -func (g *Registry) CheckOptions(domain string, option types.RegistryOptions) error { +func (g *Registry) CheckOptions(domain string, option types.RegistryOptions) (intf.RegistryClient, error) { if !strings.HasSuffix(domain, gcrURL) && !strings.HasSuffix(domain, garURL) { - return xerrors.Errorf("Google registry: %w", types.InvalidURLPattern) + return nil, xerrors.Errorf("Google registry: %w", types.InvalidURLPattern) } - g.domain = domain + client := GoogleRegistryClient{domain: domain} if option.GCPCredPath != "" { - g.Store = store.NewGCRCredStore(option.GCPCredPath) + client.Store = store.NewGCRCredStore(option.GCPCredPath) } - return nil + return &client, nil } -func (g *Registry) GetCredential(_ context.Context) (username, password string, err error) { +func (g *GoogleRegistryClient) GetCredential(_ context.Context) (username, password string, err error) { var credStore store.GCRCredStore if g.Store == nil { credStore, err = store.DefaultGCRCredStore() diff --git a/pkg/fanal/image/registry/google/google_test.go b/pkg/fanal/image/registry/google/google_test.go index 62a2b1f57627..e2b632bc83c7 100644 --- a/pkg/fanal/image/registry/google/google_test.go +++ b/pkg/fanal/image/registry/google/google_test.go @@ -5,8 +5,6 @@ import ( "reflect" "testing" - "github.com/GoogleCloudPlatform/docker-credential-gcr/store" - "github.com/aquasecurity/trivy/pkg/fanal/types" ) @@ -23,21 +21,18 @@ func TestCheckOptions(t *testing.T) { }, "NoOption": { domain: "gcr.io", - gcr: &Registry{domain: "gcr.io"}, + gcr: &Registry{}, }, "CredOption": { domain: "gcr.io", opt: types.RegistryOptions{GCPCredPath: "/path/to/file.json"}, - gcr: &Registry{ - domain: "gcr.io", - Store: store.NewGCRCredStore("/path/to/file.json"), - }, + gcr: &Registry{}, }, } for testname, v := range tests { g := &Registry{} - err := g.CheckOptions(v.domain, v.opt) + _, err := g.CheckOptions(v.domain, v.opt) if v.wantErr != nil { if err == nil { t.Errorf("%s : expected error but no error", testname) diff --git a/pkg/fanal/image/registry/intf/registry.go b/pkg/fanal/image/registry/intf/registry.go new file mode 100644 index 000000000000..da8c5d3c1789 --- /dev/null +++ b/pkg/fanal/image/registry/intf/registry.go @@ -0,0 +1,15 @@ +package intf + +import ( + "context" + + "github.com/aquasecurity/trivy/pkg/fanal/types" +) + +type RegistryClient interface { + GetCredential(ctx context.Context) (string, string, error) +} + +type Registry interface { + CheckOptions(domain string, option types.RegistryOptions) (RegistryClient, error) +} diff --git a/pkg/fanal/image/registry/token.go b/pkg/fanal/image/registry/token.go index 1e51b0fd31a4..de6b4f426adc 100644 --- a/pkg/fanal/image/registry/token.go +++ b/pkg/fanal/image/registry/token.go @@ -8,12 +8,13 @@ import ( "github.com/aquasecurity/trivy/pkg/fanal/image/registry/azure" "github.com/aquasecurity/trivy/pkg/fanal/image/registry/ecr" "github.com/aquasecurity/trivy/pkg/fanal/image/registry/google" + "github.com/aquasecurity/trivy/pkg/fanal/image/registry/intf" "github.com/aquasecurity/trivy/pkg/fanal/log" "github.com/aquasecurity/trivy/pkg/fanal/types" ) var ( - registries []Registry + registries []intf.Registry ) func init() { @@ -22,23 +23,18 @@ func init() { RegisterRegistry(&azure.Registry{}) } -type Registry interface { - CheckOptions(domain string, option types.RegistryOptions) error - GetCredential(ctx context.Context) (string, string, error) -} - -func RegisterRegistry(registry Registry) { +func RegisterRegistry(registry intf.Registry) { registries = append(registries, registry) } func GetToken(ctx context.Context, domain string, opt types.RegistryOptions) (auth authn.Basic) { // check registry which particular to get credential for _, registry := range registries { - err := registry.CheckOptions(domain, opt) + client, err := registry.CheckOptions(domain, opt) if err != nil { continue } - username, password, err := registry.GetCredential(ctx) + username, password, err := client.GetCredential(ctx) if err != nil { // only skip check registry if error occurred log.Logger.Debug(err)