Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions orbit/cmd/orbit/orbit.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ import (
"github.com/fleetdm/fleet/v4/pkg/certificate"
"github.com/fleetdm/fleet/v4/pkg/file"
"github.com/fleetdm/fleet/v4/pkg/fleethttpsig"
"github.com/fleetdm/fleet/v4/pkg/open"
retrypkg "github.com/fleetdm/fleet/v4/pkg/retry"
"github.com/fleetdm/fleet/v4/pkg/secure"
"github.com/fleetdm/fleet/v4/server/fleet"
Expand Down Expand Up @@ -1119,6 +1120,16 @@ func main() {
return fmt.Errorf("error new orbit client: %w", err)
}

// Set the function that will be called to open the SSO window if an enroll
// request returns an "end user authentication required" error.
orbitClient.SetOpenSSOWindowFunc(func() error {
	// Keep the error local to the callback. The callback runs later, from the
	// enroll retry path, so assigning to main's captured err would silently
	// overwrite a variable that the surrounding code also uses.
	if err := open.Browser(fleetURL + "/mdm/sso?initiator=setup_experience&host_uuid=" + orbitHostInfo.HardwareUUID); err != nil {
		return fmt.Errorf("opening browser: %w", err)
	}
	return nil
})

// If the server can't be reached, we want to fail quickly on any blocking network calls
// so that desktop can be launched as soon as possible.
serverIsReachable := orbitClient.Ping() == nil
Expand Down
43 changes: 43 additions & 0 deletions pkg/retry/retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,32 @@ import (
"time"
)

// ErrorOutcome is the verdict an error filter returns to retry.Do for a
// failed attempt. Install a filter with WithErrorFilter.
type ErrorOutcome int

// Verdicts an error filter may return. The zero value is
// ErrorOutcomeNormalRetry.
const (
	// ErrorOutcomeNormalRetry treats the error as retryable: the retry loop
	// carries on under its usual max-attempts rules.
	ErrorOutcomeNormalRetry = ErrorOutcome(iota)
	// ErrorOutcomeResetAttempts sets the attempts counter back to zero
	// before the loop continues. A filter can use this to keep retrying
	// indefinitely until some condition is met.
	ErrorOutcomeResetAttempts
	// ErrorOutcomeIgnore discards the error: the retry loop stops and
	// reports success.
	ErrorOutcomeIgnore
	// ErrorOutcomeDoNotRetry treats the error as final: the retry loop
	// stops and returns the error to the caller.
	ErrorOutcomeDoNotRetry
)

// config collects the settings that Option functions apply before the
// retry loop runs.
type config struct {
// initialInterval is the base wait between attempts.
// NOTE(review): presumably set through WithInterval; confirm in the full file.
initialInterval time.Duration
// backoffMultiplier scales the wait between attempts.
// NOTE(review): exact formula is outside this view; confirm in Do.
backoffMultiplier int
// maxAttempts caps the number of attempts; Do skips the cap when it is 0.
maxAttempts int
// errorFilter, when non-nil, is called by Do with each non-nil error to
// decide how the retry loop reacts (see ErrorOutcome).
errorFilter func(error) ErrorOutcome
}

// Option allows to configure the behavior of retry.Do
Expand Down Expand Up @@ -37,6 +59,14 @@ func WithMaxAttempts(a int) Option {
}
}

// WithErrorFilter installs f as the function that classifies each error
// into an ErrorOutcome. The filter is consulted before the max-attempts
// and backoff handling.
func WithErrorFilter(f func(error) ErrorOutcome) Option {
	return func(cfg *config) {
		cfg.errorFilter = f
	}
}

// Do executes the provided function, if the function returns a
// non-nil error it performs a retry according to the options
// provided.
Expand All @@ -62,6 +92,19 @@ func Do(fn func() error, opts ...Option) error {
if err == nil {
return nil
}
if cfg.errorFilter != nil {
switch cfg.errorFilter(err) {
case ErrorOutcomeIgnore:
return nil
case ErrorOutcomeResetAttempts:
attempts = 0
backoff = 1
case ErrorOutcomeDoNotRetry:
return err
default:
// continue with normal retry
}
}

if cfg.maxAttempts != 0 && attempts >= cfg.maxAttempts {
return err
Expand Down
75 changes: 75 additions & 0 deletions pkg/retry/retry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,79 @@ func TestRetryDo(t *testing.T) {
require.NoError(t, err)
require.Equal(t, maxAttempts, count)
})

t.Run("with error filter (test ignore)", func(t *testing.T) {
count := 0
err := Do(func() error {
count++
if count == 1 {
return errors.New("normal")
}
if count == 2 {
return errors.New("reset")
}
if count == 3 {
return errors.New("ignore")
}
return nil
},
WithInterval(50*time.Millisecond),
// The function is expected to run 3 times. Max attempts is set to 2
// so that the test only passes if the "reset" error really resets the
// attempts counter; without the reset, Do would give up after 2 calls.
WithMaxAttempts(2),
WithErrorFilter(func(err error) ErrorOutcome {
if err.Error() == "normal" {
return ErrorOutcomeNormalRetry
}
if err.Error() == "reset" {
return ErrorOutcomeResetAttempts
}
if err.Error() == "ignore" {
return ErrorOutcomeIgnore
}
return ErrorOutcomeDoNotRetry
}),
)

require.NoError(t, err)
require.Equal(t, 3, count)
})

t.Run("with error filter (test noretry)", func(t *testing.T) {
count := 0
err := Do(func() error {
count++
if count == 1 {
return errors.New("normal")
}
if count == 2 {
return errors.New("reset")
}
if count == 3 {
return errors.New("stop")
}
return nil
},
WithInterval(50*time.Millisecond),
// We should only actually run 3 times, setting this to 10
// tests that the DoNotRetry logic is exercised.
WithMaxAttempts(10),
WithErrorFilter(func(err error) ErrorOutcome {
if err.Error() == "normal" {
return ErrorOutcomeNormalRetry
}
if err.Error() == "reset" {
return ErrorOutcomeResetAttempts
}
if err.Error() == "stop" {
return ErrorOutcomeDoNotRetry
}
return ErrorOutcomeNormalRetry
}),
)

require.ErrorContains(t, err, "stop")
require.Equal(t, 3, count)
})
}
3 changes: 3 additions & 0 deletions server/service/base_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ func (bc *baseClient) parseResponse(verb, path string, response *http.Response,
if strings.Contains(errText, "password reset required") {
return ErrPasswordResetRequired
}
if strings.Contains(errText, "END_USER_AUTH_REQUIRED") {
return ErrEndUserAuthRequired
}
return ErrUnauthenticated
case http.StatusPaymentRequired:
return ErrMissingLicense
Expand Down
3 changes: 3 additions & 0 deletions server/service/base_client_errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ var (
ErrUnauthenticated = errors.New("unauthenticated, or invalid token")
ErrPasswordResetRequired = errors.New("Password reset required. Please sign into the Fleet UI to update your password, then log in again with: fleetctl login.")
ErrMissingLicense = errors.New("missing or invalid license")
// ErrEndUserAuthRequired is returned when an action (such as enrolling a device)
// requires end user authentication
ErrEndUserAuthRequired = errors.New("end user authentication required")
)

type SetupAlreadyErr interface {
Expand Down
69 changes: 50 additions & 19 deletions server/service/orbit_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,13 @@ type OrbitClient struct {
// If set then it will be deleted on HTTP 401 errors from Fleet and it will cause ExecuteConfigReceivers
// to terminate to trigger a restart.
hostIdentityCertPath string

// initiatedIdpAuth is a flag indicating whether a window has been opened
// to the sign-on page for the organization's Identity Provider.
initiatedIdpAuth bool

// openSSOWindow is a function that opens a browser window to the SSO URL.
openSSOWindow func() error
}

// time-to-live for config cache
Expand All @@ -79,6 +86,10 @@ type configCache struct {
err error
}

// SetOpenSSOWindowFunc registers f as the function the client calls to open
// a browser window to the SSO URL. It only stores f in the openSSOWindow
// field; it does not call it.
func (oc *OrbitClient) SetOpenSSOWindowFunc(f func() error) {
oc.openSSOWindow = f
}

// request forwards verb, path, params and resp to requestWithExternal with
// the trailing flag set to false, and returns its error unchanged.
// NOTE(review): the flag presumably selects "external" request handling;
// confirm against requestWithExternal.
func (oc *OrbitClient) request(verb string, path string, params interface{}, resp interface{}) error {
return oc.requestWithExternal(verb, path, params, resp, false)
}
Expand Down Expand Up @@ -537,38 +548,58 @@ func (oc *OrbitClient) getNodeKeyOrEnroll() (string, error) {
default:
return "", fmt.Errorf("read orbit node key file: %w", err)
}
var (
orbitNodeKey_ string
endpointDoesNotExist bool
)
var orbitNodeKey_ string
if err := retry.Do(
func() error {
var err error
orbitNodeKey_, err = oc.enrollAndWriteNodeKeyFile()
switch {
case err == nil:
return nil
case errors.Is(err, notFoundErr{}):
// Do not retry if the endpoint does not exist.
endpointDoesNotExist = true
return nil
default:
logging.LogErrIfEnvNotSet(constant.SilenceEnrollLogErrorEnvVar, err, "enroll failed, retrying")
return err
}
return err
Comment on lines 545 to +555
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All of this logic is now moved into the WithErrorFilter() function, which was implemented so that we could retry indefinitely when waiting for the user to log in to the IdP.

},
// The below configuration means the following retry intervals (exponential backoff):
// 10s, 20s, 40s, 80s, 160s and then return the failure (max attempts = 6)
// thus executing no more than ~6 enroll request failures every ~5 minutes.
retry.WithInterval(orbitEnrollRetryInterval()),
retry.WithMaxAttempts(constant.OrbitEnrollMaxRetries),
retry.WithBackoffMultiplier(constant.OrbitEnrollBackoffMultiplier),
retry.WithErrorFilter(func(err error) (errorOutcome retry.ErrorOutcome) {
log.Info().Err(err).Msg("orbit enroll attempt failed")
switch {
case errors.Is(err, notFoundErr{}):
// Do not retry if the endpoint does not exist.
return retry.ErrorOutcomeDoNotRetry
case errors.Is(err, ErrEndUserAuthRequired):
// If we get an ErrEndUserAuthRequired error, then the user
// needs to authenticate with the identity provider.
//
// Open a browser window to the sign-on page and
// then keep retrying until they authenticate.
log.Debug().Msg("enroll unauthenticated, waiting for end-user to authenticate via SSO")
if !oc.initiatedIdpAuth {
if oc.openSSOWindow == nil {
log.Error().Msg("SSO window open function not set")
return retry.ErrorOutcomeNormalRetry
}
log.Debug().Msg("opening SSO window")
openWindowErr := oc.openSSOWindow()
if openWindowErr != nil {
log.Error().Err(openWindowErr).Msg("opening SSO window")
return retry.ErrorOutcomeNormalRetry
}
oc.initiatedIdpAuth = true
}
// Sleep for 20 seconds, making the total retry interval 30 seconds
time.Sleep(20 * time.Second)
return retry.ErrorOutcomeResetAttempts
default:
logging.LogErrIfEnvNotSet(constant.SilenceEnrollLogErrorEnvVar, err, "enroll failed, retrying")
return retry.ErrorOutcomeNormalRetry
}
}),
); err != nil {
if errors.Is(err, notFoundErr{}) {
return "", errors.New("enroll endpoint does not exist")
}
return "", fmt.Errorf("orbit node key enroll failed, attempts=%d", constant.OrbitEnrollMaxRetries)
}
if endpointDoesNotExist {
return "", errors.New("enroll endpoint does not exist")
}
return orbitNodeKey_, nil
}

Expand Down
10 changes: 10 additions & 0 deletions tools/saml/config.php
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,16 @@
'simplesaml.nameidattribute' => 'email',
);

# Use for local testing of devices on the same network.
$metadata['mdm.host.docker.internal'] = array(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably fine but wanted to make sure you knew you were committing this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes! Meant to comment on it. It makes testing this easier, I added testing instructions which rely on this 😄

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or at least I have a very strong memory of adding those instructions, but I don't see them 😩 . Wonder if I didn't save.

'AssertionConsumerService' => [
'https://host.docker.internal:8080/api/v1/fleet/mdm/sso/callback',
],
# SAML 1.1 e-mail address NameID format identifier.
'NameIDFormat' => 'urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress',
'simplesaml.nameidattribute' => 'email',
);


# Used for testing when sso_settings.entity_id ("sso.test.com") is different than
# server_settings.server_url (usually "https://localhost:8080").
$metadata['sso.test.com'] = array(
Expand Down
Loading