66 " context"
77 {{- if ne .GoV2Package " " }}
88 " errors"
9+ " reflect"
910 {{- end }}
1011 " fmt"
1112 " maps"
@@ -14,29 +15,29 @@ import (
1415 {{- end }}
1516 " os"
1617 " path/filepath"
17- " reflect"
1818 " strings"
1919 " testing"
2020
2121 {{ if ne .GoV1Package " " }}
22- {{ if .ImportAWS_V1 }}
2322 aws_sdkv1 " github.com/aws/aws-sdk-go/aws"
24- {{ end - }}
25- {{ if eq .GoV2Package " " }}" github.com/aws/aws-sdk-go/aws/endpoints" {{ end }}
23+ {{- if eq .GoV2Package " " }}
24+ " github.com/aws/aws-sdk-go/aws/endpoints"
25+ {{- end }}
2626 {{ .GoV1Package }}_sdkv1 " github.com/aws/aws-sdk-go/service/{{ .GoV1Package }}"
2727 {{- end }}
2828 {{- if ne .V1AlternateInputPackage " " }}
2929 {{ .V1AlternateInputPackage }}_sdkv1 " github.com/aws/aws-sdk-go/service/{{ .V1AlternateInputPackage }}"
3030 {{- end - }}
3131 {{- if ne .GoV2Package " " }}
3232 aws_sdkv2 " github.com/aws/aws-sdk-go-v2/aws"
33+ awsmiddleware " github.com/aws/aws-sdk-go-v2/aws/middleware"
3334 {{ .GoV2Package }}_sdkv2 " github.com/aws/aws-sdk-go-v2/service/{{ .GoV2Package }}"
3435 {{- if .ImportAwsTypes }}
3536 awstypes " github.com/aws/aws-sdk-go-v2/service/{{ .GoV2Package }}/types"
3637 {{- end }}
37- {{- end }}
3838 " github.com/aws/smithy-go/middleware"
3939 smithyhttp " github.com/aws/smithy-go/transport/http"
40+ {{- end }}
4041 " github.com/google/go-cmp/cmp"
4142 " github.com/hashicorp/aws-sdk-go-base/v2/servicemocks"
4243 {{- if gt (len .Aliases ) 0 }}
@@ -70,11 +71,17 @@ type configFile struct {
7071type caseExpectations struct {
7172 diags diag.Diagnostics
7273 endpoint string
74+ region string
75+ }
76+
77+ type apiCallParams struct {
78+ endpoint string
79+ region string
7380}
7481
7582type setupFunc func(setup * caseSetup)
7683
77- type callFunc func(ctx context.Context , t * testing. T, meta * conns.AWSClient ) string
84+ type callFunc func(ctx context.Context , t * testing. T, meta * conns.AWSClient ) apiCallParams
7885
7986const (
8087 packageNameConfigEndpoint = " https://packagename-config.endpoint.test/"
@@ -109,13 +116,24 @@ const (
109116 {{ end }}
110117)
111118
119+ const (
120+ expectedCallRegion = {{ if .OverrideRegion }}" {{ .OverrideRegion }}" {{ else }}" {{ .Region }}" {{ end }} // lintignore:AWSAT003
121+ )
122+
112123func TestEndpointConfiguration (t * testing. T) { // nolint:paralleltest // uses t.Setenv
113- const region = " {{ .Region }}" // lintignore:AWSAT003
124+ const providerRegion = " {{ .Region }}" // lintignore:AWSAT003
125+ {{ if .OverrideRegionRegionalEndpoint - }}
126+ // {{ .HumanFriendly }} uses a regional endpoint but is only available in one region or a limited number of regions.
127+ // The provider overrides the region for {{ .HumanFriendly }}, but the AWS SDK's endpoint resolution returns one for the current region.
128+ const expectedEndpointRegion = " {{ .OverrideRegion }}" // lintignore:AWSAT003
129+ {{ else - }}
130+ const expectedEndpointRegion = providerRegion
131+ {{ end }}
114132
115133 testcases := map[string]endpointTestCase{
116134 " no config" : {
117135 with : []setupFunc{withNoConfig},
118- expected : expectDefaultEndpoint(region ),
136+ expected : expectDefaultEndpoint(expectedEndpointRegion ),
119137 },
120138
121139 // Package name endpoint on Config
@@ -456,7 +474,7 @@ func TestEndpointConfiguration(t *testing.T) { //nolint:paralleltest // uses t.S
456474 with : []setupFunc{
457475 withUseFIPSInConfig,
458476 },
459- expected : expectDefaultFIPSEndpoint(region ),
477+ expected : expectDefaultFIPSEndpoint(expectedEndpointRegion ),
460478 },
461479
462480 " use fips config with package name endpoint config" : {
@@ -474,7 +492,7 @@ func TestEndpointConfiguration(t *testing.T) { //nolint:paralleltest // uses t.S
474492 testcase := testcase
475493
476494 t.Run (name, func(t * testing. T) {
477- testEndpointCase(t, region , testcase, callServiceV1)
495+ testEndpointCase(t, providerRegion , testcase, callServiceV1)
478496 })
479497 }
480498 })
@@ -484,7 +502,7 @@ func TestEndpointConfiguration(t *testing.T) { //nolint:paralleltest // uses t.S
484502 testcase := testcase
485503
486504 t.Run (name, func(t * testing. T) {
487- testEndpointCase(t, region , testcase, callServiceV2)
505+ testEndpointCase(t, providerRegion , testcase, callServiceV2)
488506 })
489507 }
490508 })
@@ -493,7 +511,7 @@ func TestEndpointConfiguration(t *testing.T) { //nolint:paralleltest // uses t.S
493511 testcase := testcase
494512
495513 t.Run (name, func(t * testing. T) {
496- testEndpointCase(t, region , testcase, callService)
514+ testEndpointCase(t, providerRegion , testcase, callService)
497515 })
498516 }
499517 {{ end - }}
@@ -577,19 +595,20 @@ func defaultFIPSEndpoint(region string) string {
577595}
578596
579597{{ if ne .GoV2Package " " }}
580- func callService{{ if ne .GoV1Package " " }}V2 {{ end }}(ctx context.Context , t * testing. T, meta * conns.AWSClient ) string {
598+ func callService{{ if ne .GoV1Package " " }}V2 {{ end }}(ctx context.Context , t * testing. T, meta * conns.AWSClient ) apiCallParams {
581599 t.Helper ()
582600
583- var endpoint string
584-
585601 client := meta. {{ .ProviderNameUpper }}Client (ctx)
586602
603+ var result apiCallParams
604+
587605 _, err := client. {{ .APICall }}(ctx, &{{ .GoV2Package }}_sdkv2. {{ .APICall }}Input {
588606 {{ if ne .APICallParams " " }}{{ .APICallParams }},{{ end }}
589607 },
590608 func(opts * {{ .GoV2Package }}_sdkv2. Options ) {
591609 opts.APIOptions = append(opts.APIOptions ,
592- addRetrieveEndpointURLMiddleware(t, &endpoint),
610+ addRetrieveEndpointURLMiddleware(t, &result. endpoint),
611+ addRetrieveRegionMiddleware(&result. region),
593612 addCancelRequestMiddleware(),
594613 )
595614 },
@@ -600,12 +619,12 @@ func callService{{ if ne .GoV1Package "" }}V2{{ end }}(ctx context.Context, t *t
600619 t.Fatalf (" Unexpected error: %s" , err)
601620 }
602621
603- return endpoint
622+ return result
604623}
605624{{ end }}
606625
607626{{ if ne .GoV1Package " " }}
608- func callService{{ if ne .GoV2Package " " }}V1 {{ end }}(ctx context.Context , t * testing. T, meta * conns.AWSClient ) string {
627+ func callService{{ if ne .GoV2Package " " }}V1 {{ end }}(ctx context.Context , t * testing. T, meta * conns.AWSClient ) apiCallParams {
609628 t.Helper ()
610629
611630 client := meta. {{ .ProviderNameUpper }}Conn (ctx)
@@ -619,9 +638,10 @@ func callService{{ if ne .GoV2Package "" }}V1{{ end }}(ctx context.Context, t *t
619638
620639 req.HTTPRequest.URL.Path = " /"
621640
622- endpoint := req.HTTPRequest.URL.String ()
623-
624- return endpoint
641+ return apiCallParams{
642+ endpoint : req.HTTPRequest.URL.String (),
643+ region : aws_sdkv1.StringValue (client.Config.Region ),
644+ }
625645}
626646{{ end }}
627647
@@ -699,38 +719,44 @@ func withUseFIPSInConfig(setup *caseSetup) {
699719func expectDefaultEndpoint(region string) caseExpectations {
700720 return caseExpectations{
701721 endpoint : defaultEndpoint(region),
722+ region : expectedCallRegion,
702723 }
703724}
704725
705726func expectDefaultFIPSEndpoint(region string) caseExpectations {
706727 return caseExpectations{
707728 endpoint : defaultFIPSEndpoint(region),
729+ region : expectedCallRegion,
708730 }
709731}
710732
711733func expectPackageNameConfigEndpoint() caseExpectations {
712734 return caseExpectations{
713735 endpoint : packageNameConfigEndpoint,
736+ region : expectedCallRegion,
714737 }
715738}
716739
717740{{ range $i, $alias := .Aliases }}
718741func expectAliasName{{ $i }}ConfigEndpoint () caseExpectations {
719742 return caseExpectations{
720743 endpoint : aliasName{{ $i }}ConfigEndpoint ,
744+ region : expectedCallRegion,
721745 }
722746}
723747{{ end }}
724748
725749func expectAwsEnvVarEndpoint() caseExpectations {
726750 return caseExpectations{
727751 endpoint : awsServiceEnvvarEndpoint,
752+ region : expectedCallRegion,
728753 }
729754}
730755
731756func expectBaseEnvVarEndpoint() caseExpectations {
732757 return caseExpectations{
733758 endpoint : baseEnvvarEndpoint,
759+ region : expectedCallRegion,
734760 }
735761}
736762
@@ -741,6 +767,7 @@ func expectTfAwsEnvVarEndpoint() caseExpectations {
741767 diags : diag.Diagnostics {
742768 provider.DeprecatedEnvVarDiag (tfAwsEnvVar, awsEnvVar),
743769 },
770+ region : expectedCallRegion,
744771 }
745772}
746773{{ end }}
@@ -752,19 +779,22 @@ func expectDeprecatedEnvVarEndpoint() caseExpectations {
752779 diags : diag.Diagnostics {
753780 provider.DeprecatedEnvVarDiag (deprecatedEnvVar, awsEnvVar),
754781 },
782+ region : expectedCallRegion,
755783 }
756784}
757785{{ end }}
758786
759787func expectServiceConfigFileEndpoint() caseExpectations {
760788 return caseExpectations{
761789 endpoint : serviceConfigFileEndpoint,
790+ region : expectedCallRegion,
762791 }
763792}
764793
765794func expectBaseConfigFileEndpoint() caseExpectations {
766795 return caseExpectations{
767796 endpoint : baseConfigFileEndpoint,
797+ region : expectedCallRegion,
768798 }
769799}
770800
@@ -828,13 +858,18 @@ func testEndpointCase(t *testing.T, region string, testcase endpointTestCase, ca
828858
829859 meta := p.Meta (). (* conns.AWSClient )
830860
831- endpoint := callF(ctx, t, meta)
861+ callParams := callF(ctx, t, meta)
862+
863+ if e, a := testcase. expected. endpoint, callParams. endpoint; e != a {
864+ t.Errorf (" expected endpoint %q, got %q" , e, a)
865+ }
832866
833- if endpoint ! = testcase. expected. endpoint {
834- t.Errorf (" expected endpoint %q, got %q" , testcase . expected . endpoint, endpoint )
867+ if e, a : = testcase. expected. region, callParams . region; e != a {
868+ t.Errorf (" expected region %q, got %q" , e, a )
835869 }
836870}
837871
872+ {{ if ne .GoV2Package " " }}
838873func addRetrieveEndpointURLMiddleware(t * testing. T, endpoint * string) func(* middleware.Stack ) error {
839874 return func(stack * middleware.Stack ) error {
840875 return stack.Finalize.Add (
@@ -865,6 +900,26 @@ func retrieveEndpointURLMiddleware(t *testing.T, endpoint *string) middleware.Fi
865900 })
866901}
867902
903+ func addRetrieveRegionMiddleware(region * string) func(* middleware.Stack ) error {
904+ return func(stack * middleware.Stack ) error {
905+ return stack.Serialize.Add (
906+ retrieveRegionMiddleware(region),
907+ middleware.After ,
908+ )
909+ }
910+ }
911+
912+ func retrieveRegionMiddleware(region * string) middleware.SerializeMiddleware {
913+ return middleware.SerializeMiddlewareFunc (
914+ " Test: Retrieve Region" ,
915+ func(ctx context.Context , in middleware.SerializeInput , next middleware.SerializeHandler ) (middleware.SerializeOutput , middleware.Metadata , error) {
916+ * region = awsmiddleware.GetRegion (ctx)
917+
918+ return next.HandleSerialize (ctx, in )
919+ },
920+ )
921+ }
922+
868923var errCancelOperation = fmt.Errorf (" Test: Canceling request" )
869924
870925func addCancelRequestMiddleware() func(* middleware.Stack ) error {
@@ -897,6 +952,7 @@ func fullValueTypeName(v reflect.Value) string {
897952 requestType := v.Type ()
898953 return fmt.Sprintf (" %s.%s" , requestType.PkgPath (), requestType.Name ())
899954}
955+ {{ end }}
900956
901957func generateSharedConfigFile(config configFile) string {
902958 var buf strings.Builder
0 commit comments