diff --git a/internal/commands/auth.go b/internal/commands/auth.go index 832fa66bc..bdbdb06b1 100644 --- a/internal/commands/auth.go +++ b/internal/commands/auth.go @@ -5,6 +5,7 @@ import ( "log" "github.com/MakeNowJust/heredoc" + "github.com/checkmarx/ast-cli/internal/logger" "github.com/checkmarx/ast-cli/internal/params" "github.com/checkmarx/ast-cli/internal/wrappers" "github.com/google/uuid" @@ -38,7 +39,7 @@ type ClientCreated struct { Secret string `json:"secret"` } -func NewAuthCommand(authWrapper wrappers.AuthWrapper) *cobra.Command { +func NewAuthCommand(authWrapper wrappers.AuthWrapper, telemetryWrapper wrappers.TelemetryWrapper) *cobra.Command { authCmd := &cobra.Command{ Use: "auth", Short: "Validate authentication and create OAuth2 credentials", @@ -110,14 +111,29 @@ func NewAuthCommand(authWrapper wrappers.AuthWrapper) *cobra.Command { `, ), }, - RunE: validLogin(), + RunE: validLogin(telemetryWrapper), } authCmd.AddCommand(createClientCmd, validLoginCmd) return authCmd } -func validLogin() func(cmd *cobra.Command, args []string) error { +func validLogin(telemetryWrapper wrappers.TelemetryWrapper) func(cmd *cobra.Command, args []string) error { return func(cmd *cobra.Command, args []string) error { + defer func() { + logger.PrintIfVerbose("Calling GetUniqueId func") + uniqueID := wrappers.GetUniqueID() + if uniqueID != "" { + logger.PrintIfVerbose("Set unique id: " + uniqueID) + err := telemetryWrapper.SendAIDataToLog(&wrappers.DataForAITelemetry{ + UniqueID: uniqueID, + Type: "authentication", + SubType: "authentication", + }) + if err != nil { + logger.PrintIfVerbose("Failed to send telemetry data: " + err.Error()) + } + } + }() clientID := viper.GetString(params.AccessKeyIDConfigKey) clientSecret := viper.GetString(params.AccessKeySecretConfigKey) apiKey := viper.GetString(params.AstAPIKey) diff --git a/internal/commands/root.go b/internal/commands/root.go index 9c56cb812..dc9587c52 100644 --- a/internal/commands/root.go +++ b/internal/commands/root.go @@ -205,7 +205,7 @@ func NewAstCLI( ) versionCmd := util.NewVersionCommand() - authCmd := NewAuthCommand(authWrapper) + authCmd := NewAuthCommand(authWrapper, telemetryWrapper) utilsCmd := util.NewUtilsCommand( gitHubWrapper, azureWrapper, diff --git a/internal/commands/telemetry.go b/internal/commands/telemetry.go index 3b5bbefe0..cbf3d1f3a 100644 --- a/internal/commands/telemetry.go +++ b/internal/commands/telemetry.go @@ -2,6 +2,7 @@ package commands import ( "github.com/MakeNowJust/heredoc" + "github.com/checkmarx/ast-cli/internal/logger" "github.com/checkmarx/ast-cli/internal/params" "github.com/checkmarx/ast-cli/internal/wrappers" "github.com/pkg/errors" @@ -58,7 +59,8 @@ func runTelemetryAI(telemetryWrapper wrappers.TelemetryWrapper) func(*cobra.Comm scanType, _ := cmd.Flags().GetString("scan-type") status, _ := cmd.Flags().GetString("status") totalCount, _ := cmd.Flags().GetInt("total-count") - + uniqueID := wrappers.GetUniqueID() + logger.PrintIfVerbose("unique id: " + uniqueID) err := telemetryWrapper.SendAIDataToLog(&wrappers.DataForAITelemetry{ AIProvider: aiProvider, ProblemSeverity: problemSeverity, @@ -69,6 +71,7 @@ func runTelemetryAI(telemetryWrapper wrappers.TelemetryWrapper) func(*cobra.Comm ScanType: scanType, Status: status, TotalCount: totalCount, + UniqueID: uniqueID, }) if err != nil { diff --git a/internal/params/envs.go b/internal/params/envs.go index 19dc0f3c7..51d29f632 100644 --- a/internal/params/envs.go +++ b/internal/params/envs.go @@ -80,4 +80,5 @@ const ( RiskManagementPathEnv = "CX_RISK_MANAGEMENT_PATH" ConfigFilePathEnv = "CX_CONFIG_FILE_PATH" RealtimeScannerPathEnv = "CX_REALTIME_SCANNER_PATH" + UniqueIDEnv = "CX_UNIQUE_ID" ) diff --git a/internal/params/keys.go b/internal/params/keys.go index 839b13e53..ef7bd8156 100644 --- a/internal/params/keys.go +++ b/internal/params/keys.go @@ -79,4 +79,5 @@ var ( RiskManagementPathKey = strings.ToLower(RiskManagementPathEnv) ConfigFilePathKey = strings.ToLower(ConfigFilePathEnv) RealtimeScannerPathKey = strings.ToLower(RealtimeScannerPathEnv) + UniqueIDConfigKey = strings.ToLower(UniqueIDEnv) ) diff --git a/internal/wrappers/client.go b/internal/wrappers/client.go index 032eb4e4b..213188be4 100644 --- a/internal/wrappers/client.go +++ b/internal/wrappers/client.go @@ -12,6 +12,7 @@ import ( "net/http/httptrace" "net/url" "os" + "os/user" "runtime" "strings" "sync" @@ -20,11 +21,13 @@ import ( applicationErrors "github.com/checkmarx/ast-cli/internal/constants/errors" "github.com/checkmarx/ast-cli/internal/logger" "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" "github.com/pkg/errors" "github.com/spf13/viper" commonParams "github.com/checkmarx/ast-cli/internal/params" + "github.com/checkmarx/ast-cli/internal/wrappers/configuration" "github.com/checkmarx/ast-cli/internal/wrappers/kerberos" "github.com/checkmarx/ast-cli/internal/wrappers/ntlm" ) @@ -126,12 +129,21 @@ func retryHTTPForIAMRequest(requestFunc func() (*http.Response, error), retries return nil, err } -func setAgentNameAndOrigin(req *http.Request) { +func setAgentNameAndOrigin(req *http.Request, isAuth bool) { agentStr := viper.GetString(commonParams.AgentNameKey) + "/" + commonParams.Version req.Header.Set("User-Agent", agentStr) originStr := viper.GetString(commonParams.OriginKey) req.Header.Set("Cx-Origin", originStr) + logger.PrintIfVerbose("getting unique id") + + if !isAuth { + uniqueID := GetUniqueID() + if uniqueID != "" { + req.Header.Set("UniqueId", uniqueID) + logger.PrintIfVerbose("unique id: " + uniqueID) + } + } } func GetClient(timeout uint) *http.Client { @@ -375,7 +387,7 @@ func SendHTTPRequestByFullURLContentLength( req.ContentLength = contentLength } client := GetClient(timeout) - setAgentNameAndOrigin(req) + setAgentNameAndOrigin(req, false) if auth { enrichWithOath2Credentials(req, accessToken, bearerFormat) } @@ -427,7 +439,7 @@ func SendHTTPRequestPasswordAuth(method string, body io.Reader, timeout uint, us } req, err := http.NewRequest(method, u, body) client := GetClient(timeout) - setAgentNameAndOrigin(req) + setAgentNameAndOrigin(req, true) if err != nil { return nil, err } @@ -464,7 +476,7 @@ func HTTPRequestWithQueryParams( } req, err := http.NewRequest(method, u, body) client := GetClient(timeout) - setAgentNameAndOrigin(req) + setAgentNameAndOrigin(req, false) if err != nil { return nil, err } @@ -512,7 +524,7 @@ func SendHTTPRequestWithJSONContentType(method, path string, body io.Reader, aut } req, err := http.NewRequest(method, fullURL, body) client := GetClient(timeout) - setAgentNameAndOrigin(req) + setAgentNameAndOrigin(req, false) req.Header.Add("Content-Type", jsonContentType) if err != nil { return nil, err @@ -645,7 +657,7 @@ func writeCredentialsToCache(accessToken string) { func getNewToken(credentialsPayload, authServerURI string) (string, error) { payload := strings.NewReader(credentialsPayload) req, err := http.NewRequest(http.MethodPost, authServerURI, payload) - setAgentNameAndOrigin(req) + setAgentNameAndOrigin(req, true) if err != nil { return "", err } @@ -972,3 +984,85 @@ func extractAZPFromToken(astToken string) (string, error) { } return azp, nil } + +func GetUniqueID() string { + var uniqueID string + isAllowed := false + accessToken, err := GetAccessToken() + if err != nil { + logger.PrintIfVerbose("Failed to get access token") + return "" + } + parser := jwt.NewParser(jwt.WithoutClaimsValidation()) + token, _, err := parser.ParseUnverified(accessToken, jwt.MapClaims{}) + if err != nil { + logger.PrintIfVerbose("Failed to parse JWT token " + err.Error()) + return "" + } + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + logger.PrintIfVerbose("Failed to get JWT claims") + return "" + } + + astLicense, ok := claims["ast-license"].(map[string]interface{}) + if !ok { + logger.PrintIfVerbose("Failed to get ast-license from claims") + return "" + } + + licenseData, ok := astLicense["LicenseData"].(map[string]interface{}) + if !ok { + logger.PrintIfVerbose("Failed to get LicenseData from ast-license") + return "" + } + + allowedEngines, ok := licenseData["allowedEngines"].([]interface{}) + if !ok { + logger.PrintIfVerbose("Failed to get allowedEngines from LicenseData") + return "" + } + + for _, engine := range allowedEngines { + engineStr, ok := engine.(string) + if !ok { + continue + } + if strings.EqualFold(engineStr, "Checkmarx Developer Assist") { + isAllowed = true + break + } + } + + if !isAllowed { + logger.PrintIfVerbose("User does not not have permission to standalone dev asists feature") + return "" + } + uniqueID = viper.GetString(commonParams.UniqueIDConfigKey) + if uniqueID != "" { + return uniqueID + } + logger.PrintIfVerbose("Generating new unique id") + currentUser, err := user.Current() + if err != nil { + logger.PrintIfVerbose("Failed to get user: " + err.Error()) + return "" + } + username := currentUser.Username + username = strings.TrimSpace(username) + logger.PrintIfVerbose("Username to be used for unique id: " + username) + if strings.Contains(username, "\\") { + username = strings.Split(username, "\\")[1] + } + uniqueID = uuid.New().String() + "_" + username + + logger.PrintIfVerbose("Unique id: " + uniqueID) + viper.Set(commonParams.UniqueIDConfigKey, uniqueID) + configFilePath, _ := configuration.GetConfigFilePath() + err = configuration.SafeWriteSingleConfigKeyString(configFilePath, commonParams.UniqueIDConfigKey, uniqueID) + if err != nil { + logger.PrintIfVerbose("Failed to write config: " + err.Error()) + return "" + } + return uniqueID +} diff --git a/internal/wrappers/client_test.go b/internal/wrappers/client_test.go index b8a45f0d1..e75e9e9b0 100644 --- a/internal/wrappers/client_test.go +++ b/internal/wrappers/client_test.go @@ -192,7 +192,7 @@ func TestSetAgentNameAndOrigin(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) - setAgentNameAndOrigin(req) + setAgentNameAndOrigin(req, false) userAgent := req.Header.Get("User-Agent") origin := req.Header.Get("origin") diff --git a/internal/wrappers/telemetry.go b/internal/wrappers/telemetry.go index 58e8a5b73..b3e58781f 100644 --- a/internal/wrappers/telemetry.go +++ b/internal/wrappers/telemetry.go @@ -10,6 +10,7 @@ type DataForAITelemetry struct { ScanType string `json:"scanType"` Status string `json:"status"` TotalCount int `json:"totalCount"` + UniqueID string `json:"uniqueId"` } type TelemetryWrapper interface {