diff --git a/internal/flink/command.go b/internal/flink/command.go index 12f507b843..d48f7f081c 100644 --- a/internal/flink/command.go +++ b/internal/flink/command.go @@ -28,6 +28,11 @@ func New(cfg *config.Config, prerunner pcmd.PreRunner) *cobra.Command { c := &command{pcmd.NewAuthenticatedCLICommand(cmd, prerunner)} + cmd.PersistentFlags().Bool("insecure-skip-verify", false, "Skip TLS certificate verification for Flink gateway connections.") + cmd.PersistentFlags().String("certificate-authority-path", "", "Self-signed certificate chain in PEM format.") + cobra.CheckErr(cmd.PersistentFlags().MarkHidden("insecure-skip-verify")) + cobra.CheckErr(cmd.PersistentFlags().MarkHidden("certificate-authority-path")) + // On-prem commands are able to run with or without login. Accordingly, set the pre-runner. if cfg.IsOnPremLogin() { c = &command{pcmd.NewAuthenticatedWithMDSCLICommand(cmd, prerunner)} diff --git a/internal/flink/command_shell.go b/internal/flink/command_shell.go index 7204ae8083..b1047ad712 100644 --- a/internal/flink/command_shell.go +++ b/internal/flink/command_shell.go @@ -1,6 +1,7 @@ package flink import ( + "crypto/tls" "net/url" "strings" @@ -18,6 +19,7 @@ import ( "github.com/confluentinc/cli/v4/pkg/jwt" "github.com/confluentinc/cli/v4/pkg/log" ppanic "github.com/confluentinc/cli/v4/pkg/panic-recovery" + "github.com/confluentinc/cli/v4/pkg/utils" ) func (c *command) newShellCommand(prerunner pcmd.PreRunner, cfg *config.Config) *cobra.Command { @@ -231,7 +233,23 @@ func (c *command) startFlinkSqlClient(prerunner pcmd.PreRunner, cmd *cobra.Comma LSPBaseUrl: lspBaseUrl, } - return client.StartApp(flinkGatewayClient, c.authenticated(prerunner.Authenticated(c.AuthenticatedCLICommand), cmd, jwtValidator), opts, reportUsage(cmd, c.Config, unsafeTrace)) + insecureSkipVerify, err := c.Flags().GetBool("insecure-skip-verify") + if err != nil { + return err + } + caCertPath, err := c.Flags().GetString("certificate-authority-path") + if err != nil { + return err + } + caCertPool, err := utils.GetEnrichedCACertPool(caCertPath) + if err != nil { + return err + } + + log.CliLogger.Debugf("Insecure skip verify: %t\n", insecureSkipVerify) + tlsClientConfig := &tls.Config{InsecureSkipVerify: insecureSkipVerify, RootCAs: caCertPool} + + return client.StartApp(flinkGatewayClient, c.authenticated(prerunner.Authenticated(c.AuthenticatedCLICommand), cmd, jwtValidator), opts, reportUsage(cmd, c.Config, unsafeTrace), tlsClientConfig) } func (c *command) startFlinkSqlClientOnPrem(prerunner pcmd.PreRunner, cmd *cobra.Command) error { @@ -302,10 +320,26 @@ func (c *command) startWithLocalMode(configKeys, configValues []string) error { return err } - gatewayClient := ccloudv2.NewFlinkGatewayClient(appOptions.GetGatewayUrl(), c.Version.UserAgent, appOptions.GetUnsafeTrace(), "authToken") + insecureSkipVerify, err := c.Flags().GetBool("insecure-skip-verify") + if err != nil { + return err + } + caCertPath, err := c.Flags().GetString("certificate-authority-path") + if err != nil { + return err + } + caCertPool, err := utils.GetEnrichedCACertPool(caCertPath) + if err != nil { + return err + } + + log.CliLogger.Debugf("Insecure skip verify: %t\n", insecureSkipVerify) + tlsClientConfig := &tls.Config{InsecureSkipVerify: insecureSkipVerify, RootCAs: caCertPool} + + gatewayClient := ccloudv2.NewFlinkGatewayClient(appOptions.GetGatewayUrl(), c.Version.UserAgent, appOptions.GetUnsafeTrace(), "authToken", tlsClientConfig) appOptions.Context = c.Context - return client.StartApp(gatewayClient, func() error { return nil }, *appOptions, func() {}) + return client.StartApp(gatewayClient, func() error { return nil }, *appOptions, func() {}, tlsClientConfig) } func (c *command) getFlinkLanguageServiceUrl(gatewayClient *ccloudv2.FlinkGatewayClient) (string, error) { diff --git a/pkg/ccloudv2/flink_gateway.go b/pkg/ccloudv2/flink_gateway.go index 9575168579..9b20a29922 100644 --- a/pkg/ccloudv2/flink_gateway.go +++ b/pkg/ccloudv2/flink_gateway.go @@ -2,6 +2,7 @@ package ccloudv2 import ( "context" + "crypto/tls" "fmt" "net/http" @@ -33,10 +34,10 @@ type FlinkGatewayClient struct { AuthToken string } -func NewFlinkGatewayClient(url, userAgent string, unsafeTrace bool, authToken string) *FlinkGatewayClient { +func NewFlinkGatewayClient(url, userAgent string, unsafeTrace bool, authToken string, tlsClientConfig *tls.Config) *FlinkGatewayClient { cfg := flinkgatewayv1.NewConfiguration() cfg.Debug = unsafeTrace - cfg.HTTPClient = NewRetryableHttpClientWithRedirect(unsafeTrace, checkRedirect) + cfg.HTTPClient = NewRetryableHttpClientWithRedirect(unsafeTrace, tlsClientConfig, checkRedirect) cfg.Servers = flinkgatewayv1.ServerConfigurations{{URL: url}} cfg.UserAgent = userAgent diff --git a/pkg/ccloudv2/utils.go b/pkg/ccloudv2/utils.go index 3a5794f1b8..01e823fe74 100644 --- a/pkg/ccloudv2/utils.go +++ b/pkg/ccloudv2/utils.go @@ -2,6 +2,7 @@ package ccloudv2 import ( "context" + "crypto/tls" "fmt" "net/http" "net/url" @@ -84,8 +85,12 @@ func NewRetryableHttpClient(cfg *config.Config, unsafeTrace bool) *http.Client { return client.StandardClient() } -func NewRetryableHttpClientWithRedirect(unsafeTrace bool, checkRedirect func(*http.Request, []*http.Request) error) *http.Client { +func NewRetryableHttpClientWithRedirect(unsafeTrace bool, tlsClientConfig *tls.Config, checkRedirect func(*http.Request, []*http.Request) error) *http.Client { client := retryablehttp.NewClient() + transport := &http.Transport{ + TLSClientConfig: tlsClientConfig, + } + client.HTTPClient.Transport = transport client.Logger = plog.NewLeveledLogger(unsafeTrace) client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) { if resp == nil { diff --git a/pkg/cmd/authenticated_cli_command.go b/pkg/cmd/authenticated_cli_command.go index aab6b390e3..e00b65236c 100644 --- a/pkg/cmd/authenticated_cli_command.go +++ b/pkg/cmd/authenticated_cli_command.go @@ -1,6 +1,7 @@ package cmd import ( + "crypto/tls" "fmt" purl "net/url" "os" @@ -86,13 +87,29 @@ func (c *AuthenticatedCLICommand) GetFlinkGatewayClient(computePoolOnly bool) (* return nil, err } + insecureSkipVerify, err := c.Flags().GetBool("insecure-skip-verify") + if err != nil { + return nil, err + } + caCertPath, err := c.Flags().GetString("certificate-authority-path") + if err != nil { + return nil, err + } + caCertPool, err := utils.GetEnrichedCACertPool(caCertPath) + if err != nil { + return nil, err + } + dataplaneToken, err := auth.GetDataplaneToken(c.Context) if err != nil { return nil, err } + log.CliLogger.Debugf("Insecure skip verify: %t\n", insecureSkipVerify) + tlsClientConfig := &tls.Config{InsecureSkipVerify: insecureSkipVerify, RootCAs: caCertPool} + log.CliLogger.Debugf("The final url used for setting up Flink dataplane client is: %s\n", url) - c.flinkGatewayClient = ccloudv2.NewFlinkGatewayClient(url, c.Version.UserAgent, unsafeTrace, dataplaneToken) + c.flinkGatewayClient = ccloudv2.NewFlinkGatewayClient(url, c.Version.UserAgent, unsafeTrace, dataplaneToken, tlsClientConfig) } return c.flinkGatewayClient, nil diff --git a/pkg/flink/app/application.go b/pkg/flink/app/application.go index ea5e9a312b..525a4fb76c 100644 --- a/pkg/flink/app/application.go +++ b/pkg/flink/app/application.go @@ -1,6 +1,7 @@ package app import ( + "crypto/tls" "sync" "time" @@ -47,7 +48,7 @@ func synchronizedTokenRefresh(tokenRefreshFunc func() error) func() error { } } -func StartApp(gatewayClient ccloudv2.GatewayClientInterface, tokenRefreshFunc func() error, appOptions types.ApplicationOptions, reportUsageFunc func()) error { +func StartApp(gatewayClient ccloudv2.GatewayClientInterface, tokenRefreshFunc func() error, appOptions types.ApplicationOptions, reportUsageFunc func(), tlsClientConfig *tls.Config) error { synchronizedTokenRefreshFunc := synchronizedTokenRefresh(tokenRefreshFunc) getAuthToken := func() string { if authErr := synchronizedTokenRefreshFunc(); authErr != nil { @@ -70,7 +71,7 @@ func StartApp(gatewayClient ccloudv2.GatewayClientInterface, tokenRefreshFunc fu // Instantiate LSP handlerCh := make(chan *jsonrpc2.Request) // This is the channel used for the messages received by the language to be passed through to the input controller - lspClient, _, err := lsp.NewInitializedLspClient(getAuthToken, appOptions.GetLSPBaseUrl(), appOptions.GetOrganizationId(), appOptions.GetEnvironmentId(), handlerCh) + lspClient, _, err := lsp.NewInitializedLspClient(getAuthToken, appOptions.GetLSPBaseUrl(), appOptions.GetOrganizationId(), appOptions.GetEnvironmentId(), tlsClientConfig, handlerCh) if err != nil { log.CliLogger.Errorf("Failed to connect to the language service. Check your network."+ " If you're using private networking, you might still be able to submit queries. If that's the case and you"+ diff --git a/pkg/flink/internal/store/store_test.go b/pkg/flink/internal/store/store_test.go index 699cae8218..421b894d26 100644 --- a/pkg/flink/internal/store/store_test.go +++ b/pkg/flink/internal/store/store_test.go @@ -2,6 +2,7 @@ package store import ( "context" + "crypto/tls" "fmt" "net/http" "reflect" @@ -60,7 +61,8 @@ func TestStoreProcessLocalStatement(t *testing.T) { stores := make([]types.StoreInterface, 2) // Cloud store - client := ccloudv2.NewFlinkGatewayClient("url", "userAgent", false, "authToken") + tlsClientConfig := &tls.Config{} + client := ccloudv2.NewFlinkGatewayClient("url", "userAgent", false, "authToken", tlsClientConfig) mockAppController := mock.NewMockApplicationControllerInterface(gomock.NewController(t)) appOptions := types.ApplicationOptions{ OrganizationId: "orgId", diff --git a/pkg/flink/internal/store/store_utils_test.go b/pkg/flink/internal/store/store_utils_test.go index 6c925c76ba..4a3c973797 100644 --- a/pkg/flink/internal/store/store_utils_test.go +++ b/pkg/flink/internal/store/store_utils_test.go @@ -1,6 +1,7 @@ package store import ( + "crypto/tls" "fmt" "testing" @@ -38,7 +39,8 @@ func TestRemoveStatementTerminator(t *testing.T) { func TestProcessSetStatement(t *testing.T) { // Create a new store - client := ccloudv2.NewFlinkGatewayClient("url", "userAgent", false, "authToken") + tlsClientConfig := &tls.Config{} + client := ccloudv2.NewFlinkGatewayClient("url", "userAgent", false, "authToken", tlsClientConfig) appOptions := &types.ApplicationOptions{ Cloud: true, EnvironmentName: "env-123", @@ -133,7 +135,8 @@ func TestProcessSetStatement(t *testing.T) { func TestProcessResetStatement(t *testing.T) { // Create a new store - client := ccloudv2.NewFlinkGatewayClient("url", "userAgent", false, "authToken") + tlsClientConfig := &tls.Config{} + client := ccloudv2.NewFlinkGatewayClient("url", "userAgent", false, "authToken", tlsClientConfig) appOptions := types.ApplicationOptions{ Cloud: true, OrganizationId: "orgId", @@ -210,7 +213,8 @@ func TestProcessResetStatement(t *testing.T) { func TestProcessUseStatement(t *testing.T) { // Create a new store - client := ccloudv2.NewFlinkGatewayClient("url", "userAgent", false, "authToken") + tlsClientConfig := &tls.Config{} + client := ccloudv2.NewFlinkGatewayClient("url", "userAgent", false, "authToken", tlsClientConfig) appOptions := types.ApplicationOptions{ OrganizationId: "orgId", EnvironmentName: "envName", diff --git a/pkg/flink/lsp/lsp_completer_ws.go b/pkg/flink/lsp/lsp_completer_ws.go index 1e46d3631b..fa45325634 100644 --- a/pkg/flink/lsp/lsp_completer_ws.go +++ b/pkg/flink/lsp/lsp_completer_ws.go @@ -2,6 +2,7 @@ package lsp import ( "context" + "crypto/tls" "fmt" "net/http" "sync" @@ -17,13 +18,14 @@ import ( type WebsocketLSPClient struct { sync.Mutex - baseUrl string - getAuthToken func() string - organizationId string - environmentId string - handlerCh chan *jsonrpc2.Request - conn *jsonrpc2.Conn - lspClient LspInterface + baseUrl string + getAuthToken func() string + organizationId string + environmentId string + handlerCh chan *jsonrpc2.Request + conn *jsonrpc2.Conn + lspClient LspInterface + tlsClientConfig *tls.Config } func (w *WebsocketLSPClient) Initialize() (*lsp.InitializeResult, error) { @@ -73,7 +75,7 @@ func (w *WebsocketLSPClient) refreshWebsocketConnection() { } // we only update client and conn if there was no error, otherwise we leave them as is - if lspClient, conn, err := NewLSPConnection(w.baseUrl, w.getAuthToken(), w.organizationId, w.environmentId, w.handlerCh); err == nil { + if lspClient, conn, err := NewLSPConnection(w.baseUrl, w.getAuthToken(), w.organizationId, w.environmentId, w.tlsClientConfig, w.handlerCh); err == nil { w.lspClient = lspClient w.conn = conn _, err := InitLspClient(w.lspClient) @@ -86,8 +88,8 @@ func (w *WebsocketLSPClient) refreshWebsocketConnection() { } } -func NewInitializedLspClient(getAuthToken func() string, baseUrl, organizationId, environmentId string, handlerCh chan *jsonrpc2.Request) (LspInterface, string, error) { - client, _, err := NewLSPClient(getAuthToken, baseUrl, organizationId, environmentId, handlerCh) +func NewInitializedLspClient(getAuthToken func() string, baseUrl, organizationId, environmentId string, tlsClientConfig *tls.Config, handlerCh chan *jsonrpc2.Request) (LspInterface, string, error) { + client, _, err := NewLSPClient(getAuthToken, baseUrl, organizationId, environmentId, tlsClientConfig, handlerCh) if err != nil { return nil, "", err } @@ -116,27 +118,28 @@ func InitLspClient(client LspInterface) (string, error) { return string(docUri), nil } -func NewLSPClient(getAuthToken func() string, baseUrl, organizationId, environmentId string, handlerCh chan *jsonrpc2.Request) (*WebsocketLSPClient, *jsonrpc2.Conn, error) { - lspClient, conn, err := NewLSPConnection(baseUrl, getAuthToken(), organizationId, environmentId, handlerCh) +func NewLSPClient(getAuthToken func() string, baseUrl, organizationId, environmentId string, tlsClientConfig *tls.Config, handlerCh chan *jsonrpc2.Request) (*WebsocketLSPClient, *jsonrpc2.Conn, error) { + lspClient, conn, err := NewLSPConnection(baseUrl, getAuthToken(), organizationId, environmentId, tlsClientConfig, handlerCh) if err != nil { return nil, conn, err } websocketClient := &WebsocketLSPClient{ - baseUrl: baseUrl, - getAuthToken: getAuthToken, - organizationId: organizationId, - environmentId: environmentId, - handlerCh: handlerCh, - lspClient: lspClient, - conn: conn, + baseUrl: baseUrl, + getAuthToken: getAuthToken, + organizationId: organizationId, + environmentId: environmentId, + handlerCh: handlerCh, + lspClient: lspClient, + conn: conn, + tlsClientConfig: tlsClientConfig, } return websocketClient, websocketClient.conn, nil } -func NewLSPConnection(baseUrl, authToken, organizationId, environmentId string, handlerCh chan *jsonrpc2.Request) (*LSPClient, *jsonrpc2.Conn, error) { - stream, err := NewWSObjectStream(baseUrl, authToken, organizationId, environmentId) +func NewLSPConnection(baseUrl, authToken, organizationId, environmentId string, tlsClientConfig *tls.Config, handlerCh chan *jsonrpc2.Request) (*LSPClient, *jsonrpc2.Conn, error) { + stream, err := NewWSObjectStream(baseUrl, authToken, organizationId, environmentId, tlsClientConfig) if err != nil { log.CliLogger.Debugf("Error dialing websocket: %v\n", err) return nil, nil, err @@ -155,12 +158,15 @@ func NewLSPConnection(baseUrl, authToken, organizationId, environmentId string, return lspClient, conn, nil } -func NewWSObjectStream(socketUrl, authToken, organizationId, environmentId string) (jsonrpc2.ObjectStream, error) { +func NewWSObjectStream(socketUrl, authToken, organizationId, environmentId string, tlsClientConfig *tls.Config) (jsonrpc2.ObjectStream, error) { requestHeaders := http.Header{} requestHeaders.Add("Authorization", fmt.Sprintf("Bearer %s", authToken)) requestHeaders.Add("Organization-ID", organizationId) requestHeaders.Add("Environment-ID", environmentId) - conn, _, err := websocket.DefaultDialer.Dial(socketUrl, requestHeaders) + dialer := &websocket.Dialer{ + TLSClientConfig: tlsClientConfig, + } + conn, _, err := dialer.Dial(socketUrl, requestHeaders) if err != nil { return nil, err } diff --git a/pkg/utils/cert_utils.go b/pkg/utils/cert_utils.go index ea5cc88a88..7d1d88a932 100644 --- a/pkg/utils/cert_utils.go +++ b/pkg/utils/cert_utils.go @@ -154,6 +154,53 @@ func SelfSignedCertClient(caCertReader io.Reader, clientCert tls.Certificate) (* return client, nil } +// Could refactor the above CustomCAAndClientCertClient to use this, but for now leaving it separate to avoid breaking changes +func GetEnrichedCACertPool(caCertPath string) (*x509.CertPool, error) { + // Load system certs (or initialize a new one if unable to load system) as a certificate pool + caCertPool, err := x509.SystemCertPool() + if err != nil { + log.CliLogger.Warnf("Unable to load system certificates; continuing with custom certificates only") + } + log.CliLogger.Tracef("Loaded certificate pool from system") + if caCertPool == nil { + log.CliLogger.Tracef("(System certificate pool was blank)") + caCertPool = x509.NewCertPool() + } + + // If the provided path is not empty, and is a valid file, add it to the certificate pool + if caCertPath == "" { + log.CliLogger.Tracef("No custom CA certificate specified, using system certs only") + return caCertPool, nil + } + + // Validate and read the custom certificate file + absPath, err := filepath.Abs(caCertPath) + if err != nil { + return nil, fmt.Errorf("failed to resolve certificate path: %w", err) + } + + log.CliLogger.Debugf("Attempting to load certificate from absolute path %s", absPath) + caCertFile, err := os.Open(absPath) + if err != nil { + return nil, fmt.Errorf("failed to open certificate file: %w", err) + } + defer caCertFile.Close() + + customCaCerts, err := io.ReadAll(caCertFile) + if err != nil { + return nil, fmt.Errorf("failed to read certificate: %w", err) + } + log.CliLogger.Tracef("Successfully read CA certificate") + + // Append custom certs to the system pool + if ok := caCertPool.AppendCertsFromPEM(customCaCerts); !ok { + return nil, fmt.Errorf("no valid certificates found in file: %s", absPath) + } + log.CliLogger.Tracef("Successfully appended new certificate to the pool") + + return caCertPool, nil +} + func isEmptyClientCert(cert tls.Certificate) bool { return cert.Certificate == nil && cert.Leaf == nil && cert.OCSPStaple == nil && cert.PrivateKey == nil && cert.SignedCertificateTimestamps == nil && cert.SupportedSignatureAlgorithms == nil }