Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions internal/flink/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ func New(cfg *config.Config, prerunner pcmd.PreRunner) *cobra.Command {

c := &command{pcmd.NewAuthenticatedCLICommand(cmd, prerunner)}

cmd.PersistentFlags().Bool("insecure-skip-verify", false, "Skip TLS certificate verification for Flink gateway connections.")
cmd.PersistentFlags().String("certificate-authority-path", "", "Self-signed certificate chain in PEM format.")
cobra.CheckErr(cmd.PersistentFlags().MarkHidden("insecure-skip-verify"))
cobra.CheckErr(cmd.PersistentFlags().MarkHidden("certificate-authority-path"))

// On-prem commands are able to run with or without login. Accordingly, set the pre-runner.
if cfg.IsOnPremLogin() {
c = &command{pcmd.NewAuthenticatedWithMDSCLICommand(cmd, prerunner)}
Expand Down
40 changes: 37 additions & 3 deletions internal/flink/command_shell.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package flink

import (
"crypto/tls"
"net/url"
"strings"

Expand All @@ -18,6 +19,7 @@ import (
"github.com/confluentinc/cli/v4/pkg/jwt"
"github.com/confluentinc/cli/v4/pkg/log"
ppanic "github.com/confluentinc/cli/v4/pkg/panic-recovery"
"github.com/confluentinc/cli/v4/pkg/utils"
)

func (c *command) newShellCommand(prerunner pcmd.PreRunner, cfg *config.Config) *cobra.Command {
Expand Down Expand Up @@ -231,7 +233,23 @@ func (c *command) startFlinkSqlClient(prerunner pcmd.PreRunner, cmd *cobra.Comma
LSPBaseUrl: lspBaseUrl,
}

return client.StartApp(flinkGatewayClient, c.authenticated(prerunner.Authenticated(c.AuthenticatedCLICommand), cmd, jwtValidator), opts, reportUsage(cmd, c.Config, unsafeTrace))
insecureSkipVerify, err := c.Flags().GetBool("insecure-skip-verify")
if err != nil {
return err
}
caCertPath, err := c.Flags().GetString("certificate-authority-path")
if err != nil {
return err
}
caCertPool, err := utils.GetEnrichedCACertPool(caCertPath)
if err != nil {
return err
}

log.CliLogger.Debugf("Insecure skip verify: %t\n", insecureSkipVerify)
tlsClientConfig := &tls.Config{InsecureSkipVerify: insecureSkipVerify, RootCAs: caCertPool}

return client.StartApp(flinkGatewayClient, c.authenticated(prerunner.Authenticated(c.AuthenticatedCLICommand), cmd, jwtValidator), opts, reportUsage(cmd, c.Config, unsafeTrace), tlsClientConfig)
}

func (c *command) startFlinkSqlClientOnPrem(prerunner pcmd.PreRunner, cmd *cobra.Command) error {
Expand Down Expand Up @@ -302,10 +320,26 @@ func (c *command) startWithLocalMode(configKeys, configValues []string) error {
return err
}

gatewayClient := ccloudv2.NewFlinkGatewayClient(appOptions.GetGatewayUrl(), c.Version.UserAgent, appOptions.GetUnsafeTrace(), "authToken")
insecureSkipVerify, err := c.Flags().GetBool("insecure-skip-verify")
if err != nil {
return err
}
caCertPath, err := c.Flags().GetString("certificate-authority-path")
if err != nil {
return err
}
caCertPool, err := utils.GetEnrichedCACertPool(caCertPath)
if err != nil {
return err
}

log.CliLogger.Debugf("Insecure skip verify: %t\n", insecureSkipVerify)
tlsClientConfig := &tls.Config{InsecureSkipVerify: insecureSkipVerify, RootCAs: caCertPool}

gatewayClient := ccloudv2.NewFlinkGatewayClient(appOptions.GetGatewayUrl(), c.Version.UserAgent, appOptions.GetUnsafeTrace(), "authToken", tlsClientConfig)

appOptions.Context = c.Context
return client.StartApp(gatewayClient, func() error { return nil }, *appOptions, func() {})
return client.StartApp(gatewayClient, func() error { return nil }, *appOptions, func() {}, tlsClientConfig)
}

func (c *command) getFlinkLanguageServiceUrl(gatewayClient *ccloudv2.FlinkGatewayClient) (string, error) {
Expand Down
5 changes: 3 additions & 2 deletions pkg/ccloudv2/flink_gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package ccloudv2

import (
"context"
"crypto/tls"
"fmt"
"net/http"

Expand Down Expand Up @@ -33,10 +34,10 @@ type FlinkGatewayClient struct {
AuthToken string
}

func NewFlinkGatewayClient(url, userAgent string, unsafeTrace bool, authToken string) *FlinkGatewayClient {
func NewFlinkGatewayClient(url, userAgent string, unsafeTrace bool, authToken string, tlsClientConfig *tls.Config) *FlinkGatewayClient {
cfg := flinkgatewayv1.NewConfiguration()
cfg.Debug = unsafeTrace
cfg.HTTPClient = NewRetryableHttpClientWithRedirect(unsafeTrace, checkRedirect)
cfg.HTTPClient = NewRetryableHttpClientWithRedirect(unsafeTrace, tlsClientConfig, checkRedirect)
cfg.Servers = flinkgatewayv1.ServerConfigurations{{URL: url}}
cfg.UserAgent = userAgent

Expand Down
7 changes: 6 additions & 1 deletion pkg/ccloudv2/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package ccloudv2

import (
"context"
"crypto/tls"
"fmt"
"net/http"
"net/url"
Expand Down Expand Up @@ -84,8 +85,12 @@ func NewRetryableHttpClient(cfg *config.Config, unsafeTrace bool) *http.Client {
return client.StandardClient()
}

func NewRetryableHttpClientWithRedirect(unsafeTrace bool, checkRedirect func(*http.Request, []*http.Request) error) *http.Client {
func NewRetryableHttpClientWithRedirect(unsafeTrace bool, tlsClientConfig *tls.Config, checkRedirect func(*http.Request, []*http.Request) error) *http.Client {
client := retryablehttp.NewClient()
transport := &http.Transport{
TLSClientConfig: tlsClientConfig,
}
client.HTTPClient.Transport = transport
client.Logger = plog.NewLeveledLogger(unsafeTrace)
client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) {
if resp == nil {
Expand Down
19 changes: 18 additions & 1 deletion pkg/cmd/authenticated_cli_command.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cmd

import (
"crypto/tls"
"fmt"
purl "net/url"
"os"
Expand Down Expand Up @@ -86,13 +87,29 @@ func (c *AuthenticatedCLICommand) GetFlinkGatewayClient(computePoolOnly bool) (*
return nil, err
}

insecureSkipVerify, err := c.Flags().GetBool("insecure-skip-verify")
if err != nil {
return nil, err
}
caCertPath, err := c.Flags().GetString("certificate-authority-path")
if err != nil {
return nil, err
}
caCertPool, err := utils.GetEnrichedCACertPool(caCertPath)
if err != nil {
return nil, err
}

dataplaneToken, err := auth.GetDataplaneToken(c.Context)
if err != nil {
return nil, err
}

log.CliLogger.Debugf("Insecure skip verify: %t\n", insecureSkipVerify)
tlsClientConfig := &tls.Config{InsecureSkipVerify: insecureSkipVerify, RootCAs: caCertPool}

log.CliLogger.Debugf("The final url used for setting up Flink dataplane client is: %s\n", url)
c.flinkGatewayClient = ccloudv2.NewFlinkGatewayClient(url, c.Version.UserAgent, unsafeTrace, dataplaneToken)
c.flinkGatewayClient = ccloudv2.NewFlinkGatewayClient(url, c.Version.UserAgent, unsafeTrace, dataplaneToken, tlsClientConfig)
}

return c.flinkGatewayClient, nil
Expand Down
5 changes: 3 additions & 2 deletions pkg/flink/app/application.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package app

import (
"crypto/tls"
"sync"
"time"

Expand Down Expand Up @@ -47,7 +48,7 @@ func synchronizedTokenRefresh(tokenRefreshFunc func() error) func() error {
}
}

func StartApp(gatewayClient ccloudv2.GatewayClientInterface, tokenRefreshFunc func() error, appOptions types.ApplicationOptions, reportUsageFunc func()) error {
func StartApp(gatewayClient ccloudv2.GatewayClientInterface, tokenRefreshFunc func() error, appOptions types.ApplicationOptions, reportUsageFunc func(), tlsClientConfig *tls.Config) error {
synchronizedTokenRefreshFunc := synchronizedTokenRefresh(tokenRefreshFunc)
getAuthToken := func() string {
if authErr := synchronizedTokenRefreshFunc(); authErr != nil {
Expand All @@ -70,7 +71,7 @@ func StartApp(gatewayClient ccloudv2.GatewayClientInterface, tokenRefreshFunc fu

// Instantiate LSP
handlerCh := make(chan *jsonrpc2.Request) // This is the channel used for the messages received by the language to be passed through to the input controller
lspClient, _, err := lsp.NewInitializedLspClient(getAuthToken, appOptions.GetLSPBaseUrl(), appOptions.GetOrganizationId(), appOptions.GetEnvironmentId(), handlerCh)
lspClient, _, err := lsp.NewInitializedLspClient(getAuthToken, appOptions.GetLSPBaseUrl(), appOptions.GetOrganizationId(), appOptions.GetEnvironmentId(), tlsClientConfig, handlerCh)
if err != nil {
log.CliLogger.Errorf("Failed to connect to the language service. Check your network."+
" If you're using private networking, you might still be able to submit queries. If that's the case and you"+
Expand Down
4 changes: 3 additions & 1 deletion pkg/flink/internal/store/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package store

import (
"context"
"crypto/tls"
"fmt"
"net/http"
"reflect"
Expand Down Expand Up @@ -60,7 +61,8 @@ func TestStoreProcessLocalStatement(t *testing.T) {
stores := make([]types.StoreInterface, 2)

// Cloud store
client := ccloudv2.NewFlinkGatewayClient("url", "userAgent", false, "authToken")
tlsClientConfig := &tls.Config{}
client := ccloudv2.NewFlinkGatewayClient("url", "userAgent", false, "authToken", tlsClientConfig)
mockAppController := mock.NewMockApplicationControllerInterface(gomock.NewController(t))
appOptions := types.ApplicationOptions{
OrganizationId: "orgId",
Expand Down
10 changes: 7 additions & 3 deletions pkg/flink/internal/store/store_utils_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package store

import (
"crypto/tls"
"fmt"
"testing"

Expand Down Expand Up @@ -38,7 +39,8 @@ func TestRemoveStatementTerminator(t *testing.T) {

func TestProcessSetStatement(t *testing.T) {
// Create a new store
client := ccloudv2.NewFlinkGatewayClient("url", "userAgent", false, "authToken")
tlsClientConfig := &tls.Config{}
client := ccloudv2.NewFlinkGatewayClient("url", "userAgent", false, "authToken", tlsClientConfig)
appOptions := &types.ApplicationOptions{
Cloud: true,
EnvironmentName: "env-123",
Expand Down Expand Up @@ -133,7 +135,8 @@ func TestProcessSetStatement(t *testing.T) {

func TestProcessResetStatement(t *testing.T) {
// Create a new store
client := ccloudv2.NewFlinkGatewayClient("url", "userAgent", false, "authToken")
tlsClientConfig := &tls.Config{}
client := ccloudv2.NewFlinkGatewayClient("url", "userAgent", false, "authToken", tlsClientConfig)
appOptions := types.ApplicationOptions{
Cloud: true,
OrganizationId: "orgId",
Expand Down Expand Up @@ -210,7 +213,8 @@ func TestProcessResetStatement(t *testing.T) {

func TestProcessUseStatement(t *testing.T) {
// Create a new store
client := ccloudv2.NewFlinkGatewayClient("url", "userAgent", false, "authToken")
tlsClientConfig := &tls.Config{}
client := ccloudv2.NewFlinkGatewayClient("url", "userAgent", false, "authToken", tlsClientConfig)
appOptions := types.ApplicationOptions{
OrganizationId: "orgId",
EnvironmentName: "envName",
Expand Down
52 changes: 29 additions & 23 deletions pkg/flink/lsp/lsp_completer_ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package lsp

import (
"context"
"crypto/tls"
"fmt"
"net/http"
"sync"
Expand All @@ -17,13 +18,14 @@ import (
type WebsocketLSPClient struct {
sync.Mutex

baseUrl string
getAuthToken func() string
organizationId string
environmentId string
handlerCh chan *jsonrpc2.Request
conn *jsonrpc2.Conn
lspClient LspInterface
baseUrl string
getAuthToken func() string
organizationId string
environmentId string
handlerCh chan *jsonrpc2.Request
conn *jsonrpc2.Conn
lspClient LspInterface
tlsClientConfig *tls.Config
}

func (w *WebsocketLSPClient) Initialize() (*lsp.InitializeResult, error) {
Expand Down Expand Up @@ -73,7 +75,7 @@ func (w *WebsocketLSPClient) refreshWebsocketConnection() {
}

// we only update client and conn if there was no error, otherwise we leave them as is
if lspClient, conn, err := NewLSPConnection(w.baseUrl, w.getAuthToken(), w.organizationId, w.environmentId, w.handlerCh); err == nil {
if lspClient, conn, err := NewLSPConnection(w.baseUrl, w.getAuthToken(), w.organizationId, w.environmentId, w.tlsClientConfig, w.handlerCh); err == nil {
w.lspClient = lspClient
w.conn = conn
_, err := InitLspClient(w.lspClient)
Expand All @@ -86,8 +88,8 @@ func (w *WebsocketLSPClient) refreshWebsocketConnection() {
}
}

func NewInitializedLspClient(getAuthToken func() string, baseUrl, organizationId, environmentId string, handlerCh chan *jsonrpc2.Request) (LspInterface, string, error) {
client, _, err := NewLSPClient(getAuthToken, baseUrl, organizationId, environmentId, handlerCh)
func NewInitializedLspClient(getAuthToken func() string, baseUrl, organizationId, environmentId string, tlsClientConfig *tls.Config, handlerCh chan *jsonrpc2.Request) (LspInterface, string, error) {
client, _, err := NewLSPClient(getAuthToken, baseUrl, organizationId, environmentId, tlsClientConfig, handlerCh)
if err != nil {
return nil, "", err
}
Expand Down Expand Up @@ -116,27 +118,28 @@ func InitLspClient(client LspInterface) (string, error) {
return string(docUri), nil
}

func NewLSPClient(getAuthToken func() string, baseUrl, organizationId, environmentId string, handlerCh chan *jsonrpc2.Request) (*WebsocketLSPClient, *jsonrpc2.Conn, error) {
lspClient, conn, err := NewLSPConnection(baseUrl, getAuthToken(), organizationId, environmentId, handlerCh)
func NewLSPClient(getAuthToken func() string, baseUrl, organizationId, environmentId string, tlsClientConfig *tls.Config, handlerCh chan *jsonrpc2.Request) (*WebsocketLSPClient, *jsonrpc2.Conn, error) {
lspClient, conn, err := NewLSPConnection(baseUrl, getAuthToken(), organizationId, environmentId, tlsClientConfig, handlerCh)
if err != nil {
return nil, conn, err
}

websocketClient := &WebsocketLSPClient{
baseUrl: baseUrl,
getAuthToken: getAuthToken,
organizationId: organizationId,
environmentId: environmentId,
handlerCh: handlerCh,
lspClient: lspClient,
conn: conn,
baseUrl: baseUrl,
getAuthToken: getAuthToken,
organizationId: organizationId,
environmentId: environmentId,
handlerCh: handlerCh,
lspClient: lspClient,
conn: conn,
tlsClientConfig: tlsClientConfig,
}

return websocketClient, websocketClient.conn, nil
}

func NewLSPConnection(baseUrl, authToken, organizationId, environmentId string, handlerCh chan *jsonrpc2.Request) (*LSPClient, *jsonrpc2.Conn, error) {
stream, err := NewWSObjectStream(baseUrl, authToken, organizationId, environmentId)
func NewLSPConnection(baseUrl, authToken, organizationId, environmentId string, tlsClientConfig *tls.Config, handlerCh chan *jsonrpc2.Request) (*LSPClient, *jsonrpc2.Conn, error) {
stream, err := NewWSObjectStream(baseUrl, authToken, organizationId, environmentId, tlsClientConfig)
if err != nil {
log.CliLogger.Debugf("Error dialing websocket: %v\n", err)
return nil, nil, err
Expand All @@ -155,12 +158,15 @@ func NewLSPConnection(baseUrl, authToken, organizationId, environmentId string,
return lspClient, conn, nil
}

func NewWSObjectStream(socketUrl, authToken, organizationId, environmentId string) (jsonrpc2.ObjectStream, error) {
func NewWSObjectStream(socketUrl, authToken, organizationId, environmentId string, tlsClientConfig *tls.Config) (jsonrpc2.ObjectStream, error) {
requestHeaders := http.Header{}
requestHeaders.Add("Authorization", fmt.Sprintf("Bearer %s", authToken))
requestHeaders.Add("Organization-ID", organizationId)
requestHeaders.Add("Environment-ID", environmentId)
conn, _, err := websocket.DefaultDialer.Dial(socketUrl, requestHeaders)
dialer := &websocket.Dialer{
TLSClientConfig: tlsClientConfig,
}
conn, _, err := dialer.Dial(socketUrl, requestHeaders)
if err != nil {
return nil, err
}
Expand Down
Loading