From 612b147fe76c39acb39bc56dd285fc0565ee74a3 Mon Sep 17 00:00:00 2001 From: Michael Savigny Date: Fri, 21 Jun 2024 18:39:07 +0000 Subject: [PATCH] Revert "Remove duplicate credshelper code in re-client" This reverts commit 51f4875706f1b0ada37f96a4fb7e3ce6bbf38246. Reason for revert: Causes auth failure on windows when using luci-auth Bug: na Test: failure on windows with luci-auth Change-Id: I1b701eceaf0f2bd82744ee38539a7e3de726cd29 GitOrigin-RevId: 102018bc1e69e0ce3a2c3be5a688dbb8c62b43c5 --- MODULE.bazel | 1 + cmd/bootstrap/main.go | 39 +++-- cmd/reproxy/main.go | 54 ++++-- go.mod | 2 +- internal/pkg/auth/BUILD.bazel | 17 +- internal/pkg/auth/auth.go | 308 ++++++++++++++++++++++++++++++++- internal/pkg/auth/auth_test.go | 264 ++++++++++++++++++++++++++++ internal/pkg/auth/cache.go | 160 +++++++++++++++++ 8 files changed, 812 insertions(+), 33 deletions(-) create mode 100644 internal/pkg/auth/cache.go diff --git a/MODULE.bazel b/MODULE.bazel index e1ee5e3f..79abbb34 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -192,6 +192,7 @@ use_repo( "com_github_googlecloudplatform_protoc_gen_bq_schema", "com_github_gorilla_mux", "com_github_gosuri_uilive", + "com_github_hectane_go_acl", "com_github_karrick_godirwalk", "com_github_microsoft_go_winio", "com_github_pkg_xattr", diff --git a/cmd/bootstrap/main.go b/cmd/bootstrap/main.go index 17972fc2..5bfae041 100644 --- a/cmd/bootstrap/main.go +++ b/cmd/bootstrap/main.go @@ -54,8 +54,10 @@ import ( var bootstrapStart = time.Now() var ( - homeDir, _ = os.UserHomeDir() - logDir = os.TempDir() + homeDir, _ = os.UserHomeDir() + gcertErrMsg = fmt.Sprintf("\nTry restarting the build after running %q\n", "gcert") + gcloudErrMsg = fmt.Sprintf("\nTry restarting the build after running %q\n", "gcloud auth login") + logDir = os.TempDir() ) var ( @@ -141,15 +143,9 @@ func main() { log.Exitf("Failed to determine the token cache file name: %v", err) } var chCreds *credshelper.Credentials + var creds *auth.Credentials var ts *grpcOauth.TokenSource - credsArgs := []string{} if !*remoteDisabled { - if *experimentalCredentialsHelper != "" && *credentialsHelper == "" { - *credentialsHelper = *experimentalCredentialsHelper - *credentialsHelperArgs = *experimentalCredentialsHelperArgs - credsArgs = append(credsArgs, fmt.Sprintf("--%v=%v", credshelper.CredshelperPathFlag, *credentialsHelper)) - credsArgs = append(credsArgs, fmt.Sprintf("--%v=%v", credshelper.CredshelperArgsFlag, *credentialsHelperArgs)) - } if *credentialsHelper != "" { c, err := credshelper.NewExternalCredentials(*credentialsHelper, strings.Fields(*credentialsHelperArgs), cf) if err != nil { @@ -164,12 +160,15 @@ func main() { chCreds = c ts = c.TokenSource() } else { - m := authMechanism() - status, err := auth.UpdateStatus(m) + c := newCreds(cf) + status, err := c.UpdateStatus() if err != nil { log.Errorf("Error obtaining credentials: %v", err) os.Exit(status) } + c.SaveToDisk() + creds = c + ts = c.TokenSource() } } @@ -255,7 +254,6 @@ func main() { currArgs := args[:] if *experimentalCredentialsHelper != "" || *credentialsHelper != "" { currArgs = append(currArgs, "--use_external_auth_token=true") - currArgs = append(currArgs, credsArgs...) } msg, exitCode := bootstrapReproxy(currArgs, bootstrapStart) if exitCode == 0 { @@ -263,6 +261,7 @@ func main() { } else { fmt.Fprintf(os.Stderr, "\nReproxy failed to start:%s\nCredentials cache file was deleted. Please try again. If this continues to fail, please file a bug.\n", msg) chCreds.RemoveFromDisk() + creds.RemoveFromDisk() } log.Flush() os.Exit(exitCode) @@ -371,17 +370,25 @@ func credsFilePath() (string, error) { return cf, nil } -func authMechanism() auth.Mechanism { +func newCreds(cf string) *auth.Credentials { if *experimentalCredentialsHelper != "" { - fmt.Fprintf(os.Stderr, "--experimental_credentials_helper flags are deprecated, please use --credentials_helper flags") - os.Exit(auth.ExitCodeExternalTokenAuth) + creds, err := auth.NewExternalCredentials(*experimentalCredentialsHelper, strings.Fields(*experimentalCredentialsHelperArgs), cf) + if err != nil { + fmt.Fprintf(os.Stderr, "Experimental credentials helper failed. Please try again or use application default credentials:%v", err) + os.Exit(auth.ExitCodeExternalTokenAuth) + } + return creds } m, err := auth.MechanismFromFlags() if err != nil || m == auth.Unknown { log.Errorf("Failed to determine auth mechanism: %v", err) os.Exit(auth.ExitCodeNoAuth) } - return m + c, err := auth.NewCredentials(m, cf) + if err != nil { + log.Exitf("Failed to initialize credentials: %v", err) + } + return c } func parseLogs() ([]*lpb.LogRecord, []*lpb.ProxyInfo) { diff --git a/cmd/reproxy/main.go b/cmd/reproxy/main.go index 548ecf9d..c0417f0e 100644 --- a/cmd/reproxy/main.go +++ b/cmd/reproxy/main.go @@ -107,17 +107,19 @@ var ( idleTimeout = flag.Duration("proxy_idle_timeout", 6*time.Hour, "Inactivity period after which the running reproxy process will be killed. Default is 6 hours. When set to 0, idle timeout is disabled.") depsCacheMaxMb = flag.Int("deps_cache_max_mb", 128, "Maximum size of the deps cache file (for goma input processor only).") // TODO(b/233275188): remove this flag. - _ = flag.Duration("ip_reset_min_delay", 3*time.Minute, "Deprecated. The minimum time after the input processor has been reset before it can be reset again. Negative values disable resetting.") - ipTimeout = flag.Duration("ip_timeout", 10*time.Minute, "The maximum time to wait for an input processor action. Zero and negative values disable timeout.") - metricsProject = flag.String("metrics_project", "", "If set, action and build metrics are exported to Cloud Monitoring in the specified GCP project") - metricsPrefix = flag.String("metrics_prefix", "", "Prefix of metrics exported to Cloud Monitoring") - metricsNamespace = flag.String("metrics_namespace", "", "Namespace of metrics exported to Cloud Monitoring (e.g. RBE project)") - failEarlyMinActionCount = flag.Int64("fail_early_min_action_count", 0, "Minimum number of actions received by reproxy before the fail early mechanism can take effect. 0 indicates fail early is disabled.") - failEarlyMinFallbackRatio = flag.Float64("fail_early_min_fallback_ratio", 0, "Minimum ratio of fallbacks to total actions above which the build terminates early. Ratio is a number in the range [0,1]. 0 indicates fail early is disabled.") - failEarlyWindow = flag.Duration("fail_early_window", 0, "Window of time to consider for fail_early_min_action_count and fail_early_min_fallback_ratio. 0 indicates all datapoints should be used.") - racingBias = flag.Float64("racing_bias", 0.75, "Value between [0,1] to indicate how racing manages the tradeoff of saving bandwidth (0) versus speed (1). The default is to prefer speed over bandwidth.") - racingTmp = flag.String("racing_tmp_dir", "", "DEPRECATED. Use download_tmp_dir instead.") - downloadTmp = flag.String("download_tmp_dir", "", "Directory where reproxy should store outputs temporarily before moving them to the desired location. This should be on the same device as the output directory for the build. The default is outputs will be written to a subdirectory inside the action's working directory. Note that the download_tmp_dir will only be used if the action has racing as its exec strategy or it explicitly sets EnableAtomicDownloads=true. See proxy.proto for details.") + _ = flag.Duration("ip_reset_min_delay", 3*time.Minute, "Deprecated. The minimum time after the input processor has been reset before it can be reset again. Negative values disable resetting.") + ipTimeout = flag.Duration("ip_timeout", 10*time.Minute, "The maximum time to wait for an input processor action. Zero and negative values disable timeout.") + metricsProject = flag.String("metrics_project", "", "If set, action and build metrics are exported to Cloud Monitoring in the specified GCP project") + metricsPrefix = flag.String("metrics_prefix", "", "Prefix of metrics exported to Cloud Monitoring") + metricsNamespace = flag.String("metrics_namespace", "", "Namespace of metrics exported to Cloud Monitoring (e.g. RBE project)") + experimentalCredentialsHelper = flag.String(auth.CredshelperPathFlag, "", "Path to the credentials helper binary. If given execrel://, looks for the `credshelper` binary in the same folder as reproxy") + experimentalCredentialsHelperArgs = flag.String(auth.CredshelperArgsFlag, "", "Arguments for the experimental credentials helper, separated by space.") + failEarlyMinActionCount = flag.Int64("fail_early_min_action_count", 0, "Minimum number of actions received by reproxy before the fail early mechanism can take effect. 0 indicates fail early is disabled.") + failEarlyMinFallbackRatio = flag.Float64("fail_early_min_fallback_ratio", 0, "Minimum ratio of fallbacks to total actions above which the build terminates early. Ratio is a number in the range [0,1]. 0 indicates fail early is disabled.") + failEarlyWindow = flag.Duration("fail_early_window", 0, "Window of time to consider for fail_early_min_action_count and fail_early_min_fallback_ratio. 0 indicates all datapoints should be used.") + racingBias = flag.Float64("racing_bias", 0.75, "Value between [0,1] to indicate how racing manages the tradeoff of saving bandwidth (0) versus speed (1). The default is to prefer speed over bandwidth.") + racingTmp = flag.String("racing_tmp_dir", "", "DEPRECATED. Use download_tmp_dir instead.") + downloadTmp = flag.String("download_tmp_dir", "", "Directory where reproxy should store outputs temporarily before moving them to the desired location. This should be on the same device as the output directory for the build. The default is outputs will be written to a subdirectory inside the action's working directory. Note that the download_tmp_dir will only be used if the action has racing as its exec strategy or it explicitly sets EnableAtomicDownloads=true. See proxy.proto for details.") debugPort = flag.Int("pprof_port", 0, "Enable pprof http server if not zero") cpuProfFile = flag.String("pprof_file", "", "Enable cpu pprof if not empty. Will not work on windows as reproxy shutdowns through an uncatchable sigkill.") @@ -286,6 +288,10 @@ func main() { } defer c.SaveToDisk() ts = c.TokenSource() + } else { + c := mustBuildCredentials() + defer c.SaveToDisk() + ts = c.TokenSource() } } var e *monitoring.Exporter @@ -471,6 +477,32 @@ func formatAuthError(ce *client.InitError) error { return status.Errorf(codes.Unauthenticated, errMsg+"\n%s", ce.Error()) } +// mustBuildCredentials either returns a valid auth.Credentials struct or exits +func mustBuildCredentials() *auth.Credentials { + if *experimentalCredentialsHelper != "" { + creds, err := auth.NewExternalCredentials(*experimentalCredentialsHelper, strings.Fields(*experimentalCredentialsHelperArgs), *credsFile) + if err != nil { + fmt.Fprintf(os.Stderr, "Experimental credentials helper failed. Please try again or use application default credentials:%v", err) + os.Exit(auth.ExitCodeExternalTokenAuth) + } + return creds + } + m, err := auth.MechanismFromFlags() + if err != nil || m == auth.Unknown { + log.Errorf("Failed to determine auth mechanism: %v", err) + os.Exit(auth.ExitCodeNoAuth) + } + c, err := auth.NewCredentials(m, *credsFile) + if err != nil { + log.Errorf("Failed to initialize credentials: %v", err) + if aerr, ok := err.(*auth.Error); ok { + os.Exit(aerr.ExitCode) + } + os.Exit(auth.ExitCodeUnknown) + } + return c +} + func initializeLogger(mi *ignoremismatch.MismatchIgnorer, e *monitoring.Exporter) (*logger.Logger, error) { u := usage.New() if *auxiliaryMetadataPath != "" { diff --git a/go.mod b/go.mod index 37a5503a..93b99bed 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( github.com/google/uuid v1.3.0 github.com/gorilla/mux v1.8.1 github.com/gosuri/uilive v0.0.4 + github.com/hectane/go-acl v0.0.0-20230122075934-ca0b05cb1adb github.com/karrick/godirwalk v1.17.0 github.com/pkg/xattr v0.4.4 github.com/shirou/gopsutil/v3 v3.24.4 @@ -56,7 +57,6 @@ require ( github.com/google/s2a-go v0.1.7 // indirect github.com/googleapis/enterprise-certificate-proxy v0.2.5 // indirect github.com/googleapis/gax-go/v2 v2.12.0 // indirect - github.com/hectane/go-acl v0.0.0-20230122075934-ca0b05cb1adb // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/klauspost/asmfmt v1.3.2 // indirect github.com/klauspost/compress v1.17.8 // indirect diff --git a/internal/pkg/auth/BUILD.bazel b/internal/pkg/auth/BUILD.bazel index 32d09355..25fbd494 100644 --- a/internal/pkg/auth/BUILD.bazel +++ b/internal/pkg/auth/BUILD.bazel @@ -2,11 +2,22 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "auth", - srcs = ["auth.go"], + srcs = [ + "auth.go", + "cache.go", + ], importpath = "github.com/bazelbuild/reclient/internal/pkg/auth", visibility = ["//:__subpackages__"], deps = [ + "//api/auth", + "//internal/pkg/features", + "//internal/pkg/pathtranslator", + "@com_github_bazelbuild_remote_apis_sdks//go/pkg/digest", "@com_github_golang_glog//:glog", + "@com_github_hectane_go_acl//:go-acl", + "@org_golang_google_grpc//credentials/oauth", + "@org_golang_google_protobuf//encoding/prototext", + "@org_golang_google_protobuf//types/known/timestamppb", "@org_golang_x_oauth2//:oauth2", "@org_golang_x_oauth2//google", ], @@ -16,4 +27,8 @@ go_test( name = "auth_test", srcs = ["auth_test.go"], embed = [":auth"], + deps = [ + "@org_golang_google_grpc//credentials/oauth", + "@org_golang_x_oauth2//:oauth2", + ], ) diff --git a/internal/pkg/auth/auth.go b/internal/pkg/auth/auth.go index 10a988d8..464d5ef9 100644 --- a/internal/pkg/auth/auth.go +++ b/internal/pkg/auth/auth.go @@ -16,17 +16,29 @@ package auth import ( + "bytes" "context" "encoding/json" + "errors" "flag" "fmt" + "os" + "os/exec" + "sort" "strconv" "strings" + "sync" "time" + "github.com/bazelbuild/remote-apis-sdks/go/pkg/digest" + + "github.com/bazelbuild/reclient/internal/pkg/features" + "github.com/bazelbuild/reclient/internal/pkg/pathtranslator" + log "github.com/golang/glog" "golang.org/x/oauth2" googleOauth "golang.org/x/oauth2/google" + grpcOauth "google.golang.org/grpc/credentials/oauth" ) // Exit codes to indicate various causes of authentication failure. @@ -52,6 +64,9 @@ const ( // Unknown is an unknown auth mechanism. Unknown Mechanism = iota + // CredentialsHelper is using an externally provided binary to get credentials. + CredentialsHelper + // ADC is GCP's application default credentials authentication mechanism. ADC // GCE is authentication using GCE VM service accounts. @@ -67,6 +82,8 @@ func (m Mechanism) String() string { switch m { case Unknown: return "Unknown" + case CredentialsHelper: + return "CredentialsHelper" case ADC: return "ADC" case GCE: @@ -85,6 +102,9 @@ const ( CredshelperPathFlag = "experimental_credentials_helper" // CredshelperArgsFlag is the flag used to pass in the arguments to the credentials helper binary. CredshelperArgsFlag = "experimental_credentials_helper_args" + + // TODO(b/261172745): define these flags in reproxy rather than in the SDK. + // UseAppDefaultCredsFlag is used to authenticate with application default credentials. UseAppDefaultCredsFlag = "use_application_default_credentials" // UseExternalTokenFlag indicates the user will authenticate with a provided token. @@ -108,6 +128,8 @@ var stringAuthFlags = []string{ CredentialFileFlag, } +var nowFn = time.Now + // Error is an error occured during authenticating or initializing credentials. type Error struct { error @@ -115,6 +137,48 @@ type Error struct { ExitCode int } +type reusableCmd struct { + path string + args []string + digestOnce sync.Once + digest digest.Digest +} + +func newResubaleCmd(binary string, args []string) *reusableCmd { + cmd := exec.Command(binary, args...) + return &reusableCmd{ + path: cmd.Path, + args: args, + } +} + +func (r *reusableCmd) String() string { + return fmt.Sprintf("%s %v", r.path, strings.Join(r.args, " ")) +} + +func (r *reusableCmd) Cmd() *exec.Cmd { + return exec.Command(r.path, r.args...) +} + +func (r *reusableCmd) Digest() digest.Digest { + r.digestOnce.Do(func() { + chCmd := append(r.args, r.path) + sort.Strings(chCmd) + cmdStr := strings.Join(chCmd, ",") + r.digest = digest.NewFromBlob([]byte(cmdStr)) + }) + return r.digest +} + +// Credentials provides auth functionalities with a specific auth mechanism. +type Credentials struct { + m Mechanism + refreshExp time.Time + tokenSource *grpcOauth.TokenSource + credsHelperCmd *reusableCmd + credsFile string +} + // MechanismFromFlags returns an auth Mechanism based on flags currently set. func MechanismFromFlags() (Mechanism, error) { vals := make(map[string]bool, len(boolAuthFlags)+len(stringAuthFlags)) @@ -148,6 +212,17 @@ func MechanismFromFlags() (Mechanism, error) { return Unknown, &Error{fmt.Errorf("couldn't determine auth mechanism from flags %v", vals), ExitCodeNoAuth} } +// Cacheable returns true if this mechanism should be cached to disk +func (m Mechanism) Cacheable() bool { + if !features.GetConfig().EnableCredentialCache { + return false + } + if m == CredentialsHelper { + return true + } + return false +} + func boolFlagVal(flagName string) (bool, error) { if f := flag.Lookup(flagName); f != nil && f.Value.String() != "" { b, err := strconv.ParseBool(f.Value.String()) @@ -159,17 +234,138 @@ func boolFlagVal(flagName string) (bool, error) { return false, nil } -// UpdateStatus updates ADC credentials status if expired -func UpdateStatus(m Mechanism) (int, error) { - if m == ADC { - _, err := checkADCStatus() +// NewCredentials initializes a credentials object. +func NewCredentials(m Mechanism, credsFile string) (*Credentials, error) { + cc, err := loadFromDisk(credsFile) + if err != nil { + log.Warningf("Failed to load credentials cache file from %v: %v", credsFile, err) + return buildCredentials(cachedCredentials{m: m}, credsFile) + } + if cc.m != m { + log.Warningf("Cached mechanism (%v) is not the same as requested mechanism (%v). Will attempt to authenticate using the requested mechanism.", cc.m, m) + return buildCredentials(cachedCredentials{m: m}, credsFile) + } + return buildCredentials(cc, credsFile) +} + +func buildCredentials(baseCreds cachedCredentials, credsFile string) (*Credentials, error) { + if baseCreds.m == Unknown { + return nil, errors.New("cannot initialize credentials with unknown mechanism") + } + c := &Credentials{ + m: baseCreds.m, + refreshExp: baseCreds.refreshExp, + credsFile: credsFile, + } + return c, nil +} + +// build credentials obtained from the credentials helper. +func buildExternalCredentials(baseCreds cachedCredentials, credsFile string, credsHelperCmd *reusableCmd) *Credentials { + c := &Credentials{ + m: CredentialsHelper, + credsFile: credsFile, + credsHelperCmd: credsHelperCmd, + } + baseTs := &externalTokenSource{ + credsHelperCmd: credsHelperCmd, + } + c.tokenSource = &grpcOauth.TokenSource{ + // Wrap the base token source with a ReuseTokenSource so that we only + // generate new credentials when the current one is about to expire. + // This is needed because retrieving the token is expensive and some + // token providers have per hour rate limits. + TokenSource: oauth2.ReuseTokenSourceWithExpiry( + baseCreds.token, + baseTs, + // Refresh tokens 5 mins early to be safe + 5*time.Minute, + ), + } + return c +} + +func loadCredsFromDisk(credsFile string, credsHelperCmd *reusableCmd) (*Credentials, error) { + cc, err := loadFromDisk(credsFile) + if err != nil { + return nil, err + } + cmdDigest := credsHelperCmd.Digest() + if cc.credsHelperCmdDigest != cmdDigest.String() { + return nil, fmt.Errorf("cached credshelper command digest: %s is not the same as requested credshelper command digest: %s", + cc.credsHelperCmdDigest, cmdDigest.String()) + } + isExpired := cc.token != nil && cc.token.Expiry.Before(nowFn()) + if isExpired { + return nil, fmt.Errorf("cached token is expired at %v", cc.token.Expiry) + } + return buildExternalCredentials(cc, credsFile, credsHelperCmd), nil +} + +// SaveToDisk saves credentials to disk. +func (c *Credentials) SaveToDisk() { + if c == nil { + return + } + if !c.m.Cacheable() { + return + } + cc := cachedCredentials{m: c.m, refreshExp: c.refreshExp} + // Since c.tokenSource is always wrapped in a oauth2.ReuseTokenSourceWithExpiry + // this will return a cached credential if one exists. + t, err := c.tokenSource.Token() + if err != nil { + log.Errorf("Failed to get token to persist to disk: %v", err) + return + } + cc.token = t + if c.credsHelperCmd != nil { + cc.credsHelperCmdDigest = c.credsHelperCmd.Digest().String() + } + if err := saveToDisk(cc, c.credsFile); err != nil { + log.Errorf("Failed to save credentials to disk: %v", err) + } +} + +// RemoveFromDisk deletes the credentials cache on disk. +func (c *Credentials) RemoveFromDisk() { + if c == nil { + return + } + if err := os.Remove(c.credsFile); err != nil { + log.Errorf("Failed to remove credentials from disk: %v", err) + } +} + +// UpdateStatus updates the refresh expiry time if it is expired +func (c *Credentials) UpdateStatus() (int, error) { + if !nowFn().Before(c.refreshExp) && c.m == ADC { + exp, err := checkADCStatus() if err != nil { return ExitCodeAppDefCredsAuth, fmt.Errorf("application default credentials were invalid: %v", err) } + c.refreshExp = exp } return 0, nil } +// Mechanism returns the authentication mechanism of the credentials object. +func (c *Credentials) Mechanism() Mechanism { + if c == nil { + return None + } + return c.m +} + +// TokenSource returns a token source for this credentials instance. +// If this credential type does not produce credentials nil will be returned. +func (c *Credentials) TokenSource() *grpcOauth.TokenSource { + if c == nil { + return nil + } + return c.tokenSource +} + func checkADCStatus() (time.Time, error) { ts, err := googleOauth.FindDefaultCredentialsWithParams(context.Background(), googleOauth.CredentialsParams{ Scopes: []string{"https://www.googleapis.com/auth/cloud-platform"}, @@ -209,3 +405,107 @@ func checkADCStatus() (time.Time, error) { } return token.Expiry, nil } + +// externaltokenSource uses a credentialsHelper to obtain gcp oauth tokens. +// This should be wrapped in a "golang.org/x/oauth2".ReuseTokenSource +// to avoid obtaining new tokens each time. +type externalTokenSource struct { + credsHelperCmd *reusableCmd +} + +// Token retrieves an oauth2 token from the external tokensource. +func (ts *externalTokenSource) Token() (*oauth2.Token, error) { + if ts == nil { + return nil, fmt.Errorf("empty tokensource") + } + tk, _, err := runCredsHelperCmd(ts.credsHelperCmd) + if err == nil { + log.Infof("'%s' credentials refreshed at %v, expires at %v", ts.credsHelperCmd, time.Now(), tk.Expiry) + } + return tk, err +} + +// NewExternalCredentials creates credentials obtained from a credshelper. +func NewExternalCredentials(credshelper string, credshelperArgs []string, credsFile string) (*Credentials, error) { + if credshelper == "execrel://" { + credshelperPath, err := pathtranslator.BinaryRelToAbs("credshelper") + if err != nil { + log.Fatalf("Specified %s=execrel:// but `credshelper` was not found in the same directory as `bootstrap` or `reproxy`: %v", CredshelperPathFlag, err) + } + credshelper = credshelperPath + } + credsHelperCmd := newResubaleCmd(credshelper, credshelperArgs) + if credsFile != "" { + creds, err := loadCredsFromDisk(credsFile, credsHelperCmd) + if err == nil { + return creds, nil + } + log.Warningf("Failed to load cached credentials, will fetch fresh credentials: %v", err) + } + tk, rexp, err := runCredsHelperCmd(credsHelperCmd) + if err != nil { + return nil, err + } + return buildExternalCredentials(cachedCredentials{token: tk, refreshExp: rexp}, credsFile, credsHelperCmd), nil +} + +func runCredsHelperCmd(credsHelperCmd *reusableCmd) (*oauth2.Token, time.Time, error) { + log.V(2).Infof("Running %v", credsHelperCmd) + var stdout, stderr bytes.Buffer + cmd := credsHelperCmd.Cmd() + cmd.Stdout = &stdout + cmd.Stderr = &stderr + err := cmd.Run() + out := stdout.String() + if stderr.String() != "" { + log.Errorf("Credentials helper warnings and errors: %v", stderr.String()) + } + if err != nil { + return nil, time.Time{}, err + } + token, expiry, refreshExpiry, err := parseTokenExpiryFromOutput(out) + return &oauth2.Token{ + AccessToken: token, + Expiry: expiry, + }, refreshExpiry, err +} + +// CredsHelperOut is the struct to record the json output from the credshelper. +type CredsHelperOut struct { + Token string `json:"token"` + Expiry string `json:"expiry"` + RefreshExpiry string `json:"refresh_expiry"` +} + +func parseTokenExpiryFromOutput(out string) (string, time.Time, time.Time, error) { + var ( + tk string + exp, rexp time.Time + chOut CredsHelperOut + ) + if err := json.Unmarshal([]byte(out), &chOut); err != nil { + return tk, exp, rexp, + fmt.Errorf("error while decoding credshelper output:%v", err) + } + tk = chOut.Token + if tk == "" { + return tk, exp, rexp, + fmt.Errorf("no token was printed by the credentials helper") + } + if chOut.Expiry != "" { + expiry, err := time.Parse(time.UnixDate, chOut.Expiry) + if err != nil { + return tk, exp, rexp, fmt.Errorf("invalid expiry format: %v (Expected time.UnixDate format)", chOut.Expiry) + } + exp = expiry + rexp = expiry + } + if chOut.RefreshExpiry != "" { + rexpiry, err := time.Parse(time.UnixDate, chOut.RefreshExpiry) + if err != nil { + return tk, exp, rexp, fmt.Errorf("invalid refresh expiry format: %v (Expected time.UnixDate format)", chOut.RefreshExpiry) + } + rexp = rexpiry + } + return tk, exp, rexp, nil +} diff --git a/internal/pkg/auth/auth_test.go b/internal/pkg/auth/auth_test.go index ee934a31..9779837a 100644 --- a/internal/pkg/auth/auth_test.go +++ b/internal/pkg/auth/auth_test.go @@ -16,8 +16,15 @@ package auth import ( "flag" + "fmt" "os" + "path/filepath" + "runtime" "testing" + "time" + + "golang.org/x/oauth2" + grpcOauth "google.golang.org/grpc/credentials/oauth" ) func TestMechanismFromFlags(t *testing.T) { @@ -70,3 +77,260 @@ func TestMechanismFromFlags(t *testing.T) { }) } } + +func TestCredentialsHelperCache(t *testing.T) { + dir, err := os.MkdirTemp("", "test") + if err != nil { + t.Errorf("failed to create the temp directory: %v", err) + } + t.Cleanup(func() { os.RemoveAll(dir) }) + cf := filepath.Join(dir, "reproxy.creds") + err = os.MkdirAll(filepath.Dir(cf), 0755) + if err != nil { + t.Errorf("failed to create dir for credentials file %q: %v", cf, err) + } + credsHelperCmd := newResubaleCmd("echo", []string{`{"token":"testToken", "expiry":"", "refresh_expiry":""}`}) + baseTs := &externalTokenSource{credsHelperCmd: credsHelperCmd} + ts := &grpcOauth.TokenSource{ + TokenSource: oauth2.ReuseTokenSourceWithExpiry( + &oauth2.Token{}, + baseTs, + 5*time.Minute, + ), + } + creds := &Credentials{ + m: CredentialsHelper, + refreshExp: time.Time{}, + tokenSource: ts, + credsFile: cf, + credsHelperCmd: credsHelperCmd, + } + creds.SaveToDisk() + c1, err := loadCredsFromDisk(cf, credsHelperCmd) + if err != nil { + t.Errorf("LoadCredsFromDisk failed: %v", err) + } + // Second load to make sure credentials were not purged. + c2, err := loadCredsFromDisk(cf, credsHelperCmd) + if err != nil { + t.Errorf("LoadCredsFromDisk failed: %v", err) + } + if creds.m != c1.m { + t.Errorf("Mechanism was cached incorrectly, got: %v, want: %v", c1.m, creds.m) + } + if creds.m != c2.m { + t.Errorf("Mechanism was cached incorrectly, got: %v, want: %v", c2.m, creds.m) + } +} + +func TestExternalToken(t *testing.T) { + expiry := time.Now().Truncate(time.Second) + exp := expiry.Format(time.UnixDate) + tk := "testToken" + var ( + credshelper string + credshelperArgs []string + ) + if runtime.GOOS == "windows" { + tf, err := os.CreateTemp("", "testexternaltoken.json") + if err != nil { + t.Fatalf("Unable to create temporary file: %v", err) + } + chJSON := fmt.Sprintf(`{"token":"%v","expiry":"%s","refresh_expiry":""}`, tk, exp) + if _, err := tf.Write([]byte(chJSON)); err != nil { + t.Fatalf("Unable to write to file %v: %v", tf.Name(), err) + } + credshelper = "cmd" + credshelperArgs = []string{ + "/c", + "cat", + tf.Name(), + } + } else { + credshelper = "echo" + credshelperArgs = []string{fmt.Sprintf(`{"token":"%v","expiry":"%s","refresh_expiry":""}`, tk, exp)} + } + + credsHelperCmd := newResubaleCmd(credshelper, credshelperArgs) + ts := &externalTokenSource{ + credsHelperCmd: credsHelperCmd, + } + oauth2tk, err := ts.Token() + if err != nil { + t.Errorf("externalTokenSource.Token() returned an error: %v", err) + } + if oauth2tk.AccessToken != tk { + t.Errorf("externalTokenSource.Token() returned token=%s, want=%s", oauth2tk.AccessToken, tk) + } + if !oauth2tk.Expiry.Equal(expiry) { + t.Errorf("externalTokenSource.Token() returned expiry=%s, want=%s", oauth2tk.Expiry, exp) + } +} + +func TestExternalTokenRefresh(t *testing.T) { + tmp := t.TempDir() + tokenFile := filepath.Join(tmp, "reproxy.creds") + var ( + credshelper string + credshelperArgs []string + ) + if runtime.GOOS == "windows" { + credshelper = "cmd" + credshelperArgs = []string{ + "/c", + "cat", + tokenFile, + } + } else { + credshelper = "cat" + credshelperArgs = []string{ + tokenFile, + } + } + credsHelperCmd := newResubaleCmd(credshelper, credshelperArgs) + ts := &externalTokenSource{ + credsHelperCmd: credsHelperCmd, + } + for _, token := range []string{"testToken", "testTokenRefresh"} { + expiry := time.Now().Truncate(time.Second) + writeTokenFile(t, tokenFile, token, expiry) + + oauth2tk, err := ts.Token() + if err != nil { + t.Errorf("externalTokenSource.Token() returned an error: %v", err) + } + if oauth2tk.AccessToken != token { + t.Errorf("externalTokenSource.Token() returned token=%s, want=%s", oauth2tk.AccessToken, token) + } + if !oauth2tk.Expiry.Equal(expiry) { + t.Errorf("externalTokenSource.Token() returned expiry=%s, want=%s", oauth2tk.Expiry, expiry) + } + } +} + +func writeTokenFile(t *testing.T, path, token string, expiry time.Time) { + t.Helper() + f, err := os.OpenFile(path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) + if err != nil { + t.Fatalf("Unable to open file %v: %v", path, err) + } + defer f.Close() + chJSON := fmt.Sprintf(`{"token":"%v","expiry":"%s","refresh_expiry":""}`, token, expiry.Format(time.UnixDate)) + if _, err := f.Write([]byte(chJSON)); err != nil { + t.Fatalf("Unable to write to file %v: %v", f.Name(), err) + } +} + +func TestNewExternalCredentials(t *testing.T) { + testToken := "token" + exp := time.Now().Add(time.Hour).Truncate(time.Second) + expStr := exp.String() + unixExp := exp.Format(time.UnixDate) + tests := []struct { + name string + wantErr bool + checkExp bool + credshelperOut string + }{{ + name: "No Token", + wantErr: true, + credshelperOut: fmt.Sprintf(`{"token":"","expiry":"","refresh_expiry":""}`), + }, { + name: "Credshelper Command Passed - No Expiry", + credshelperOut: fmt.Sprintf(`{"token":"%v","expiry":"","refresh_expiry":""}`, testToken), + }, { + name: "Credshelper Command Passed - Expiry", + checkExp: true, + credshelperOut: fmt.Sprintf(`{"token":"%v","expiry":"%v","refresh_expiry":""}`, testToken, unixExp), + }, { + name: "Credshelper Command Passed - Refresh Expiry", + checkExp: true, + credshelperOut: fmt.Sprintf(`{"token":"%v","expiry":"%v","refresh_expiry":"%v"}`, testToken, unixExp, unixExp), + }, { + name: "Wrong Expiry Format", + wantErr: true, + credshelperOut: fmt.Sprintf(`{"token":"%v","expiry":"%v","refresh_expiry":"%v"}`, testToken, expStr, expStr), + }} + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var ( + credshelper string + credshelperArgs []string + ) + if runtime.GOOS == "windows" { + tf, err := os.CreateTemp("", "testnewexternalcreds.json") + if err != nil { + t.Fatalf("Unable to create temporary file: %v", err) + } + if _, err := tf.Write([]byte(test.credshelperOut)); err != nil { + t.Fatalf("Unable to write to file %v: %v", tf.Name(), err) + } + credshelper = "cmd" + credshelperArgs = []string{ + "/c", + "cat", + tf.Name(), + } + } else { + credshelper = "echo" + credshelperArgs = []string{test.credshelperOut} + } + + c, err := NewExternalCredentials(credshelper, credshelperArgs, "") + if test.wantErr && err == nil { + t.Fatalf("NewExternalCredentials did not return an error.") + } + if !test.wantErr { + if err != nil { + t.Fatalf("NewExternalCredentials returned an error: %v", err) + } + if c.m != CredentialsHelper { + t.Errorf("NewExternalCredentials returned credentials with mechanism=%v, want=%v", c.m, CredentialsHelper) + } + if c.tokenSource == nil { + t.Fatalf("NewExternalCredentials returned credentials with a nil tokensource.") + } + tk, err := c.tokenSource.Token() + if err != nil { + t.Fatalf("tokensource.Token() call failed: %v", err) + } + if tk.AccessToken != testToken { + t.Fatalf("tokensource.Token() gave token=%s, want=%s", + tk.AccessToken, testToken) + } + if test.checkExp && !exp.Equal(tk.Expiry) { + t.Fatalf("tokensource.Token() gave expiry=%v, want=%v", + tk.Expiry, exp) + } + } + }) + } +} + +func TestReusableCmd(t *testing.T) { + binary := "echo" + args := []string{"hello"} + cmd := newResubaleCmd(binary, args) + + output, err := cmd.Cmd().CombinedOutput() + if err != nil { + t.Errorf("Command failed: %v", err) + } else if string(output) != "hello\n" { + t.Errorf("Command returned unexpected output: %s", output) + } + + output, err = cmd.Cmd().CombinedOutput() + if err != nil { + t.Errorf("Command failed second time: %v", err) + } else if string(output) != "hello\n" { + t.Errorf("Command returned unexpected output second time: %s", output) + } +} + +func TestReusableCmdDigest(t *testing.T) { + cmd1 := newResubaleCmd("echo", []string{"Hello"}) + cmd2 := newResubaleCmd("echo", []string{"Bye"}) + if cmd1.Digest() == cmd2.Digest() { + t.Errorf("`%s` and `%s` have the same digest", cmd1, cmd2) + } +} diff --git a/internal/pkg/auth/cache.go b/internal/pkg/auth/cache.go new file mode 100644 index 00000000..0e93f7c1 --- /dev/null +++ b/internal/pkg/auth/cache.go @@ -0,0 +1,160 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package auth + +import ( + "os" + "time" + + apb "github.com/bazelbuild/reclient/api/auth" + + log "github.com/golang/glog" + "github.com/hectane/go-acl" + "golang.org/x/oauth2" + "google.golang.org/protobuf/encoding/prototext" + tspb "google.golang.org/protobuf/types/known/timestamppb" +) + +// CachedCredentials are the credentials cached to disk. +type cachedCredentials struct { + m Mechanism + refreshExp time.Time + token *oauth2.Token + credsHelperCmdDigest string +} + +func loadFromDisk(tf string) (cachedCredentials, error) { + if tf == "" { + return cachedCredentials{}, nil + } + blob, err := os.ReadFile(tf) + if err != nil { + return cachedCredentials{}, err + } + cPb := &apb.Credentials{} + if err := prototext.Unmarshal(blob, cPb); err != nil { + return cachedCredentials{}, err + } + accessToken := cPb.GetToken() + exp := TimeFromProto(cPb.GetExpiry()) + var token *oauth2.Token + if accessToken != "" && !exp.IsZero() { + token = &oauth2.Token{ + AccessToken: accessToken, + Expiry: exp, + } + } + c := cachedCredentials{ + m: protoToMechanism(cPb.GetMechanism()), + token: token, + refreshExp: TimeFromProto(cPb.GetRefreshExpiry()), + credsHelperCmdDigest: cPb.GetCredsHelperCmdDigest(), + } + if !c.m.Cacheable() { + log.Infof("Purging credentials, non-cacheable mechanism: %v", c.m) + // Purge non cacheable credentials from disk. + if err := os.Remove(tf); err != nil { + log.Warningf("Unable to remove cached credentials file %q, err=%v", tf, err) + } + // TODO(b/2028466): Do not use the non-cacheable mechanism even for the + // current run. + } + log.Infof("Loaded cached credentials of type %v, expires at %v", c.m, exp) + return c, nil +} + +func saveToDisk(c cachedCredentials, tf string) error { + if tf == "" { + return nil + } + cPb := &apb.Credentials{} + cPb.Mechanism = mechanismToProto(c.m) + if c.token != nil { + cPb.Token = c.token.AccessToken + cPb.Expiry = TimeToProto(c.token.Expiry) + cPb.CredsHelperCmdDigest = c.credsHelperCmdDigest + } + if !c.refreshExp.IsZero() { + cPb.RefreshExpiry = TimeToProto(c.refreshExp) + } + f, err := os.Create(tf) + if err != nil { + return err + } + // Only owner can read/write the credential cache. + // This is consistent with gcloud's credentials.db. + // os.OpenFile(..., 0600) is not used because it does not properly set ACLs on windows. + if err := acl.Chmod(tf, 0600); err != nil { + return err + } + defer f.Close() + f.WriteString(prototext.Format(cPb)) + log.Infof("Saved cached credentials of type %v, expires at %v to %v", c.m, cPb.Expiry, tf) + return nil +} + +func mechanismToProto(m Mechanism) apb.AuthMechanism_Value { + switch m { + case Unknown: + return apb.AuthMechanism_UNSPECIFIED + case CredentialsHelper: + return apb.AuthMechanism_CREDENTIALSHELPER + case ADC: + return apb.AuthMechanism_ADC + case GCE: + return apb.AuthMechanism_GCE + case CredentialFile: + return apb.AuthMechanism_CREDENTIAL_FILE + case None: + return apb.AuthMechanism_NONE + default: + return apb.AuthMechanism_UNSPECIFIED + } +} + +func protoToMechanism(p apb.AuthMechanism_Value) Mechanism { + switch p { + case apb.AuthMechanism_UNSPECIFIED: + return Unknown + case apb.AuthMechanism_CREDENTIALSHELPER: + return CredentialsHelper + case apb.AuthMechanism_ADC: + return ADC + case apb.AuthMechanism_GCE: + return GCE + case apb.AuthMechanism_NONE: + return None + case apb.AuthMechanism_CREDENTIAL_FILE: + return CredentialFile + default: + return Unknown + } +} + +// TimeToProto converts a valid time.Time into a proto Timestamp. +func TimeToProto(t time.Time) *tspb.Timestamp { + if t.IsZero() { + return nil + } + return tspb.New(t) +} + +// TimeFromProto converts a valid Timestamp proto into a time.Time. +func TimeFromProto(tPb *tspb.Timestamp) time.Time { + if tPb == nil { + return time.Time{} + } + return tPb.AsTime() +}