Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions internal/commands/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"log"

"github.com/MakeNowJust/heredoc"
"github.com/checkmarx/ast-cli/internal/logger"
"github.com/checkmarx/ast-cli/internal/params"
"github.com/checkmarx/ast-cli/internal/wrappers"
"github.com/google/uuid"
Expand Down Expand Up @@ -38,7 +39,7 @@ type ClientCreated struct {
Secret string `json:"secret"`
}

func NewAuthCommand(authWrapper wrappers.AuthWrapper) *cobra.Command {
func NewAuthCommand(authWrapper wrappers.AuthWrapper, telemetryWrapper wrappers.TelemetryWrapper) *cobra.Command {
authCmd := &cobra.Command{
Use: "auth",
Short: "Validate authentication and create OAuth2 credentials",
Expand Down Expand Up @@ -110,14 +111,29 @@ func NewAuthCommand(authWrapper wrappers.AuthWrapper) *cobra.Command {
`,
),
},
RunE: validLogin(),
RunE: validLogin(telemetryWrapper),
}
authCmd.AddCommand(createClientCmd, validLoginCmd)
return authCmd
}

func validLogin() func(cmd *cobra.Command, args []string) error {
func validLogin(telemetryWrapper wrappers.TelemetryWrapper) func(cmd *cobra.Command, args []string) error {
return func(cmd *cobra.Command, args []string) error {
defer func() {
logger.PrintIfVerbose("Calling GetUniqueId func")
uniqueID := wrappers.GetUniqueID()
if uniqueID != "" {
logger.PrintIfVerbose("Set unique id: " + uniqueID)
err := telemetryWrapper.SendAIDataToLog(&wrappers.DataForAITelemetry{
UniqueID: uniqueID,
Type: "authentication",
SubType: "authentication",
})
if err != nil {
logger.PrintIfVerbose("Failed to send telemetry data: " + err.Error())
}
}
}()
clientID := viper.GetString(params.AccessKeyIDConfigKey)
clientSecret := viper.GetString(params.AccessKeySecretConfigKey)
apiKey := viper.GetString(params.AstAPIKey)
Expand Down
2 changes: 1 addition & 1 deletion internal/commands/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ func NewAstCLI(
)

versionCmd := util.NewVersionCommand()
authCmd := NewAuthCommand(authWrapper)
authCmd := NewAuthCommand(authWrapper, telemetryWrapper)
utilsCmd := util.NewUtilsCommand(
gitHubWrapper,
azureWrapper,
Expand Down
5 changes: 4 additions & 1 deletion internal/commands/telemetry.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package commands

import (
"github.com/MakeNowJust/heredoc"
"github.com/checkmarx/ast-cli/internal/logger"
"github.com/checkmarx/ast-cli/internal/params"
"github.com/checkmarx/ast-cli/internal/wrappers"
"github.com/pkg/errors"
Expand Down Expand Up @@ -58,7 +59,8 @@ func runTelemetryAI(telemetryWrapper wrappers.TelemetryWrapper) func(*cobra.Comm
scanType, _ := cmd.Flags().GetString("scan-type")
status, _ := cmd.Flags().GetString("status")
totalCount, _ := cmd.Flags().GetInt("total-count")

uniqueID := wrappers.GetUniqueID()
logger.PrintIfVerbose("unique id: " + uniqueID)
err := telemetryWrapper.SendAIDataToLog(&wrappers.DataForAITelemetry{
AIProvider: aiProvider,
ProblemSeverity: problemSeverity,
Expand All @@ -69,6 +71,7 @@ func runTelemetryAI(telemetryWrapper wrappers.TelemetryWrapper) func(*cobra.Comm
ScanType: scanType,
Status: status,
TotalCount: totalCount,
UniqueID: uniqueID,
})

if err != nil {
Expand Down
1 change: 1 addition & 0 deletions internal/params/envs.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,4 +80,5 @@ const (
RiskManagementPathEnv = "CX_RISK_MANAGEMENT_PATH"
ConfigFilePathEnv = "CX_CONFIG_FILE_PATH"
RealtimeScannerPathEnv = "CX_REALTIME_SCANNER_PATH"
UniqueIDEnv = "CX_UNIQUE_ID"
)
1 change: 1 addition & 0 deletions internal/params/keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,4 +79,5 @@ var (
RiskManagementPathKey = strings.ToLower(RiskManagementPathEnv)
ConfigFilePathKey = strings.ToLower(ConfigFilePathEnv)
RealtimeScannerPathKey = strings.ToLower(RealtimeScannerPathEnv)
UniqueIDConfigKey = strings.ToLower(UniqueIDEnv)
)
106 changes: 100 additions & 6 deletions internal/wrappers/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"net/http/httptrace"
"net/url"
"os"
"os/user"
"runtime"
"strings"
"sync"
Expand All @@ -20,11 +21,13 @@ import (
applicationErrors "github.com/checkmarx/ast-cli/internal/constants/errors"
"github.com/checkmarx/ast-cli/internal/logger"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"

"github.com/pkg/errors"
"github.com/spf13/viper"

commonParams "github.com/checkmarx/ast-cli/internal/params"
"github.com/checkmarx/ast-cli/internal/wrappers/configuration"
"github.com/checkmarx/ast-cli/internal/wrappers/kerberos"
"github.com/checkmarx/ast-cli/internal/wrappers/ntlm"
)
Expand Down Expand Up @@ -126,12 +129,21 @@ func retryHTTPForIAMRequest(requestFunc func() (*http.Response, error), retries
return nil, err
}

func setAgentNameAndOrigin(req *http.Request) {
func setAgentNameAndOrigin(req *http.Request, isAuth bool) {
agentStr := viper.GetString(commonParams.AgentNameKey) + "/" + commonParams.Version
req.Header.Set("User-Agent", agentStr)

originStr := viper.GetString(commonParams.OriginKey)
req.Header.Set("Cx-Origin", originStr)
logger.PrintIfVerbose("getting unique id")

if !isAuth {
uniqueID := GetUniqueID()
if uniqueID != "" {
req.Header.Set("UniqueId", uniqueID)
logger.PrintIfVerbose("unique id: " + uniqueID)
}
}
}

func GetClient(timeout uint) *http.Client {
Expand Down Expand Up @@ -375,7 +387,7 @@ func SendHTTPRequestByFullURLContentLength(
req.ContentLength = contentLength
}
client := GetClient(timeout)
setAgentNameAndOrigin(req)
setAgentNameAndOrigin(req, false)
if auth {
enrichWithOath2Credentials(req, accessToken, bearerFormat)
}
Expand Down Expand Up @@ -427,7 +439,7 @@ func SendHTTPRequestPasswordAuth(method string, body io.Reader, timeout uint, us
}
req, err := http.NewRequest(method, u, body)
client := GetClient(timeout)
setAgentNameAndOrigin(req)
setAgentNameAndOrigin(req, true)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -464,7 +476,7 @@ func HTTPRequestWithQueryParams(
}
req, err := http.NewRequest(method, u, body)
client := GetClient(timeout)
setAgentNameAndOrigin(req)
setAgentNameAndOrigin(req, false)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -512,7 +524,7 @@ func SendHTTPRequestWithJSONContentType(method, path string, body io.Reader, aut
}
req, err := http.NewRequest(method, fullURL, body)
client := GetClient(timeout)
setAgentNameAndOrigin(req)
setAgentNameAndOrigin(req, false)
req.Header.Add("Content-Type", jsonContentType)
if err != nil {
return nil, err
Expand Down Expand Up @@ -645,7 +657,7 @@ func writeCredentialsToCache(accessToken string) {
func getNewToken(credentialsPayload, authServerURI string) (string, error) {
payload := strings.NewReader(credentialsPayload)
req, err := http.NewRequest(http.MethodPost, authServerURI, payload)
setAgentNameAndOrigin(req)
setAgentNameAndOrigin(req, true)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -972,3 +984,85 @@ func extractAZPFromToken(astToken string) (string, error) {
}
return azp, nil
}

func GetUniqueID() string {
var uniqueID string
isAllowed := false
accessToken, err := GetAccessToken()
if err != nil {
logger.PrintIfVerbose("Failed to get access token")
return ""
}
parser := jwt.NewParser(jwt.WithoutClaimsValidation())
token, _, err := parser.ParseUnverified(accessToken, jwt.MapClaims{})
if err != nil {
logger.PrintIfVerbose("Failed to parse JWT token " + err.Error())
return ""
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
logger.PrintIfVerbose("Failed to get JWT claims")
return ""
}

astLicense, ok := claims["ast-license"].(map[string]interface{})
if !ok {
logger.PrintIfVerbose("Failed to get ast-license from claims")
return ""
}

licenseData, ok := astLicense["LicenseData"].(map[string]interface{})
if !ok {
logger.PrintIfVerbose("Failed to get LicenseData from ast-license")
return ""
}

allowedEngines, ok := licenseData["allowedEngines"].([]interface{})
if !ok {
logger.PrintIfVerbose("Failed to get allowedEngines from LicenseData")
return ""
}

for _, engine := range allowedEngines {
engineStr, ok := engine.(string)
if !ok {
continue
}
if strings.EqualFold(engineStr, "Checkmarx Developer Assist") {
isAllowed = true
break
}
}

if !isAllowed {
logger.PrintIfVerbose("User does not not have permission to standalone dev asists feature")
return ""
}
uniqueID = viper.GetString(commonParams.UniqueIDConfigKey)
if uniqueID != "" {
return uniqueID
}
logger.PrintIfVerbose("Generating new unique id")
currentUser, err := user.Current()
if err != nil {
logger.PrintIfVerbose("Failed to get user: " + err.Error())
return ""
}
username := currentUser.Username
username = strings.TrimSpace(username)
logger.PrintIfVerbose("Username to be used for unique id: " + username)
if strings.Contains(username, "\\") {
username = strings.Split(username, "\\")[1]
}
uniqueID = uuid.New().String() + "_" + username

logger.PrintIfVerbose("Unique id: " + uniqueID)
viper.Set(commonParams.UniqueIDConfigKey, uniqueID)
configFilePath, _ := configuration.GetConfigFilePath()
err = configuration.SafeWriteSingleConfigKeyString(configFilePath, commonParams.UniqueIDConfigKey, uniqueID)
if err != nil {
logger.PrintIfVerbose("Failed to write config: " + err.Error())
return ""
}
return uniqueID
}
2 changes: 1 addition & 1 deletion internal/wrappers/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ func TestSetAgentNameAndOrigin(t *testing.T) {

req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)

setAgentNameAndOrigin(req)
setAgentNameAndOrigin(req, false)

userAgent := req.Header.Get("User-Agent")
origin := req.Header.Get("origin")
Expand Down
1 change: 1 addition & 0 deletions internal/wrappers/telemetry.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ type DataForAITelemetry struct {
ScanType string `json:"scanType"`
Status string `json:"status"`
TotalCount int `json:"totalCount"`
UniqueID string `json:"uniqueId"`
}

type TelemetryWrapper interface {
Expand Down
Loading