Skip to content

Commit

Permalink
feat: invalidate cache on failed IP lookup (#812)
Browse files Browse the repository at this point in the history
In #768, we removed an implicit behavior of the Go Connector. If a dial
attempt requests a non-existent IP type (e.g., client asks for public IP
on a private IP only instance), the Connector would invalidate the
cache. But with the cleanup PR, we removed that implicit behavior.

In some cases, it might be useful to have this behavior. For example, if
a caller starts the Go Connector and tries to connect to a public IP and
then later configures public IP, there is no need for a restart.

We made that change in the AlloyDB Go Connector mostly because some
internal tests depend on that behavior. See
GoogleCloudPlatform/alloydb-go-connector#555.

Fixes #780
  • Loading branch information
enocom authored May 28, 2024
1 parent 164badf commit 4b68de3
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 63 deletions.
38 changes: 25 additions & 13 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,12 +275,7 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
c := d.connectionInfoCache(ctx, cn, &cfg.useIAMAuthN)
ci, err := c.ConnectionInfo(ctx)
if err != nil {
d.lock.Lock()
defer d.lock.Unlock()
d.logger.Debugf(ctx, "[%v] Removing connection info from cache", cn.String())
// Stop all background refreshes
c.Close()
delete(d.cache, cn)
d.removeCached(ctx, cn, c, err)
endInfo(err)
return nil, err
}
Expand All @@ -297,12 +292,7 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
// Block on refreshed connection info
ci, err = c.ConnectionInfo(ctx)
if err != nil {
d.lock.Lock()
defer d.lock.Unlock()
d.logger.Debugf(ctx, "[%v] Removing connection info from cache", cn.String())
// Stop all background refreshes
c.Close()
delete(d.cache, cn)
d.removeCached(ctx, cn, c, err)
return nil, err
}
}
Expand All @@ -312,6 +302,7 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
defer func() { connectEnd(err) }()
addr, err := ci.Addr(cfg.ipType)
if err != nil {
d.removeCached(ctx, cn, c, err)
return nil, err
}
addr = net.JoinHostPort(addr, serverProxyPort)
Expand Down Expand Up @@ -359,10 +350,31 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
}), nil
}

// removeCached stops all background refreshes and deletes the connection
// info cache from the map of caches.
func (d *Dialer) removeCached(
ctx context.Context,
i instance.ConnName, c connectionInfoCache, err error,
) {
d.logger.Debugf(
ctx,
"[%v] Removing connection info from cache: %v",
i.String(),
err,
)
d.lock.Lock()
defer d.lock.Unlock()
c.Close()
delete(d.cache, i)
}

// validClientCert checks that the ephemeral client certificate retrieved from
// the cache is unexpired. The time comparisons strip the monotonic clock value
// to ensure an accurate result, even after laptop sleep.
func validClientCert(ctx context.Context, cn instance.ConnName, l debug.ContextLogger, expiration time.Time) bool {
func validClientCert(
ctx context.Context, cn instance.ConnName,
l debug.ContextLogger, expiration time.Time,
) bool {
// Use UTC() to strip monotonic clock value to guard against inaccurate
// comparisons, especially after laptop sleep.
// See the comments on the monotonic clock in the Go documentation for
Expand Down
122 changes: 80 additions & 42 deletions dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ package cloudsqlconn

import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -622,35 +624,71 @@ func TestDialerRemovesInvalidInstancesFromCache(t *testing.T) {
// Populate instance map with connection info cache that will always fail
// This allows the test to verify the error case path invoking close.
badInstanceConnectionName := "doesntexist:us-central1:doesntexist"
badCN, _ := instance.ParseConnName(badInstanceConnectionName)
spy := &spyConnectionInfoCache{
connectInfoCalls: []struct {
info cloudsql.ConnectionInfo
err error
}{{
err: errors.New("connect info failed"),
}},
tcs := []struct {
desc string
icn string
resp connectionInfoResp
opts []DialOption
}{
{
desc: "dialing a bad instance URI",
icn: badInstanceConnectionName,
resp: connectionInfoResp{
err: errors.New("connect info failed"),
},
},
{
desc: "specifying an invalid IP type",
icn: "myproject:myregion:myinstance",
resp: connectionInfoResp{
info: cloudsql.NewConnectionInfo(
instance.ConnName{},
"",
map[string]string{
// no public IP
cloudsql.PrivateIP: "10.0.0.1",
},
nil,
tls.Certificate{Leaf: &x509.Certificate{
NotAfter: time.Now().Add(time.Hour),
}},
),
},
opts: []DialOption{WithPublicIP()},
},
}
d.cache[badCN] = monitoredCache{connectionInfoCache: spy}

_, err = d.Dial(context.Background(), badInstanceConnectionName)
if err == nil {
t.Fatal("expected Dial to return error")
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
// Manually populate the internal cache with a spy
inst, _ := instance.ParseConnName(tc.icn)
spy := &spyConnectionInfoCache{
connectInfoCalls: []connectionInfoResp{tc.resp},
}
d.cache[inst] = monitoredCache{
connectionInfoCache: spy,
}

// Verify that the connection info cache was closed (to prevent
// further failed refresh operations)
if got, want := spy.CloseWasCalled(), true; got != want {
t.Fatal("Close was not called")
}
_, err = d.Dial(context.Background(), tc.icn, tc.opts...)
if err == nil {
t.Fatal("expected Dial to return error")
}
// Verify that the connection info cache was closed (to prevent
// further failed refresh operations)
if got, want := spy.closeWasCalled(), true; got != want {
t.Fatal("Close was not called")
}

// Now verify that bad connection name has been deleted from map.
d.lock.RLock()
_, ok := d.cache[badCN]
d.lock.RUnlock()
if ok {
t.Fatal("bad instance was not removed from the cache")
// Now verify that bad connection name has been deleted from map.
d.lock.RLock()
_, ok := d.cache[inst]
d.lock.RUnlock()
if ok {
t.Fatal("connection info was not removed from cache")
}
})
}

}

func TestDialRefreshesExpiredCertificates(t *testing.T) {
Expand All @@ -665,10 +703,7 @@ func TestDialRefreshesExpiredCertificates(t *testing.T) {
icn := "project:region:instance"
cn, _ := instance.ParseConnName(icn)
spy := &spyConnectionInfoCache{
connectInfoCalls: []struct {
info cloudsql.ConnectionInfo
err error
}{
connectInfoCalls: []connectionInfoResp{
// First call returns expired certificate
{
// Certificate expired 10 hours ago.
Expand All @@ -690,13 +725,13 @@ func TestDialRefreshesExpiredCertificates(t *testing.T) {
}

// Verify that the cache was refreshed
if got, want := spy.ForceRefreshWasCalled(), true; got != want {
if got, want := spy.forceRefreshWasCalled(), true; got != want {
t.Fatal("ForceRefresh was not called")
}

// Verify that the connection info cache was closed (to prevent
// further failed refresh operations)
if got, want := spy.CloseWasCalled(), true; got != want {
if got, want := spy.closeWasCalled(), true; got != want {
t.Fatal("Close was not called")
}

Expand All @@ -710,15 +745,18 @@ func TestDialRefreshesExpiredCertificates(t *testing.T) {

}

type connectionInfoResp struct {
info cloudsql.ConnectionInfo
err error
}

type spyConnectionInfoCache struct {
mu sync.Mutex
connectInfoIndex int
connectInfoCalls []struct {
info cloudsql.ConnectionInfo
err error
}
closeWasCalled bool
forceRefreshWasCalled bool
connectInfoCalls []connectionInfoResp

closed bool
forceRefreshed bool
// embed interface to avoid having to implement irrelevant methods
connectionInfoCache
}
Expand All @@ -736,28 +774,28 @@ func (s *spyConnectionInfoCache) ConnectionInfo(
func (s *spyConnectionInfoCache) ForceRefresh() {
s.mu.Lock()
defer s.mu.Unlock()
s.forceRefreshWasCalled = true
s.forceRefreshed = true
}

func (s *spyConnectionInfoCache) UpdateRefresh(*bool) {}

func (s *spyConnectionInfoCache) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
s.closeWasCalled = true
s.closed = true
return nil
}

func (s *spyConnectionInfoCache) CloseWasCalled() bool {
func (s *spyConnectionInfoCache) closeWasCalled() bool {
s.mu.Lock()
defer s.mu.Unlock()
return s.closeWasCalled
return s.closed
}

func (s *spyConnectionInfoCache) ForceRefreshWasCalled() bool {
func (s *spyConnectionInfoCache) forceRefreshWasCalled() bool {
s.mu.Lock()
defer s.mu.Unlock()
return s.forceRefreshWasCalled
return s.forceRefreshed
}

func TestDialerSupportsOneOffDialFunction(t *testing.T) {
Expand Down
18 changes: 18 additions & 0 deletions internal/cloudsql/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,24 @@ type ConnectionInfo struct {
addrs map[string]string
}

// NewConnectionInfo initializes a ConnectionInfo struct.
func NewConnectionInfo(
cn instance.ConnName,
version string,
ipAddrs map[string]string,
serverCaCert *x509.Certificate,
clientCert tls.Certificate,
) ConnectionInfo {
return ConnectionInfo{
addrs: ipAddrs,
ServerCaCert: serverCaCert,
ClientCertificate: clientCert,
Expiration: clientCert.Leaf.NotAfter,
DBVersion: version,
ConnectionName: cn,
}
}

// Addr returns the IP address or DNS name for the given IP type.
func (c ConnectionInfo) Addr(ipType string) (string, error) {
var (
Expand Down
11 changes: 3 additions & 8 deletions internal/cloudsql/refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -340,14 +340,9 @@ func (r refresher) ConnectionInfo(
return ConnectionInfo{}, fmt.Errorf("refresh failed: %w", ctx.Err())
}

return ConnectionInfo{
addrs: md.ipAddrs,
ServerCaCert: md.serverCaCert,
ClientCertificate: ec,
Expiration: ec.Leaf.NotAfter,
DBVersion: md.version,
ConnectionName: cn,
}, nil
return NewConnectionInfo(
cn, md.version, md.ipAddrs, md.serverCaCert, ec,
), nil
}

// supportsAutoIAMAuthN checks that the engine support automatic IAM authn. If
Expand Down

0 comments on commit 4b68de3

Please sign in to comment.