diff --git a/cmd/root.go b/cmd/root.go index 61f0fc5..e7553de 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -71,19 +71,25 @@ func getKernelClient(cmd *cobra.Command) kernel.Client { return util.GetKernelClient(cmd) } -// isAuthExempt returns true if the command or any of its parents should skip auth. +// isAuthExempt returns true if the command should skip auth. func isAuthExempt(cmd *cobra.Command) bool { - // bare root command does not need auth + // Root command doesn't need auth if cmd == rootCmd { return true } - for c := cmd; c != nil; c = c.Parent() { - switch c.Name() { - case "login", "logout", "auth", "help", "completion", - "create": - return true - } + + // Walk up to find the top-level command (direct child of rootCmd) + topLevel := cmd + for topLevel.Parent() != nil && topLevel.Parent() != rootCmd { + topLevel = topLevel.Parent() + } + + // Check if the top-level command is in the exempt list + switch topLevel.Name() { + case "login", "logout", "auth", "help", "completion", "create": + return true } + return false } diff --git a/cmd/root_test.go b/cmd/root_test.go new file mode 100644 index 0000000..381e8ac --- /dev/null +++ b/cmd/root_test.go @@ -0,0 +1,79 @@ +package cmd + +import ( + "testing" + + "github.com/spf13/cobra" + "github.com/stretchr/testify/assert" +) + +func TestIsAuthExempt(t *testing.T) { + tests := []struct { + name string + cmd *cobra.Command + expected bool + }{ + { + name: "root command is exempt", + cmd: rootCmd, + expected: true, + }, + { + name: "login command is exempt", + cmd: loginCmd, + expected: true, + }, + { + name: "logout command is exempt", + cmd: logoutCmd, + expected: true, + }, + { + name: "top-level create command is exempt", + cmd: createCmd, + expected: true, + }, + { + name: "browser-pools create subcommand requires auth", + cmd: browserPoolsCreateCmd, + expected: false, + }, + { + name: "browsers create subcommand requires auth", + cmd: browsersCreateCmd, + expected: false, + }, + { + name: "profiles create subcommand requires auth", + cmd: profilesCreateCmd, + expected: false, + }, + { + name: "browser-pools list requires auth", + cmd: browserPoolsListCmd, + expected: false, + }, + { + name: "browsers list requires auth", + cmd: browsersListCmd, + expected: false, + }, + { + name: "deploy command requires auth", + cmd: deployCmd, + expected: false, + }, + { + name: "invoke command requires auth", + cmd: invokeCmd, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isAuthExempt(tt.cmd) + assert.Equal(t, tt.expected, result, "isAuthExempt(%s) = %v, want %v", tt.cmd.Name(), result, tt.expected) + }) + } +}