From 4b68de3693e25642acd847d0c8ac393982d00c9b Mon Sep 17 00:00:00 2001 From: Eno Compton Date: Tue, 28 May 2024 12:22:16 -0600 Subject: [PATCH] feat: invalidate cache on failed IP lookup (#812) In #768, we removed an implicit behavior of the Go Connector. If a dial attempt requests a non-existent IP type (e.g., client asks for public IP on a private IP only instance), the Connector would invalidate the cache. But with the cleanup PR, we removed that implicit behavior. In some cases, it might be useful to have this behavior. For example, if a caller starts the Go Connector and tries to connect to a public IP and then later configures public IP, there is no need for a restart. We made that change in the AlloyDB Go Connector mostly because some internal tests depend on that behavior. See GoogleCloudPlatform/alloydb-go-connector#555. Fixes #780 --- dialer.go | 38 +++++++---- dialer_test.go | 122 ++++++++++++++++++++++------------ internal/cloudsql/instance.go | 18 +++++ internal/cloudsql/refresh.go | 11 +-- 4 files changed, 126 insertions(+), 63 deletions(-) diff --git a/dialer.go b/dialer.go index c93dbade..fc6301a3 100644 --- a/dialer.go +++ b/dialer.go @@ -275,12 +275,7 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn c := d.connectionInfoCache(ctx, cn, &cfg.useIAMAuthN) ci, err := c.ConnectionInfo(ctx) if err != nil { - d.lock.Lock() - defer d.lock.Unlock() - d.logger.Debugf(ctx, "[%v] Removing connection info from cache", cn.String()) - // Stop all background refreshes - c.Close() - delete(d.cache, cn) + d.removeCached(ctx, cn, c, err) endInfo(err) return nil, err } @@ -297,12 +292,7 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn // Block on refreshed connection info ci, err = c.ConnectionInfo(ctx) if err != nil { - d.lock.Lock() - defer d.lock.Unlock() - d.logger.Debugf(ctx, "[%v] Removing connection info from cache", cn.String()) - // Stop all background refreshes - c.Close() - delete(d.cache, cn) + d.removeCached(ctx, cn, c, err) return nil, err } } @@ -312,6 +302,7 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn defer func() { connectEnd(err) }() addr, err := ci.Addr(cfg.ipType) if err != nil { + d.removeCached(ctx, cn, c, err) return nil, err } addr = net.JoinHostPort(addr, serverProxyPort) @@ -359,10 +350,31 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn }), nil } +// removeCached stops all background refreshes and deletes the connection +// info cache from the map of caches. +func (d *Dialer) removeCached( + ctx context.Context, + i instance.ConnName, c connectionInfoCache, err error, +) { + d.logger.Debugf( + ctx, + "[%v] Removing connection info from cache: %v", + i.String(), + err, + ) + d.lock.Lock() + defer d.lock.Unlock() + c.Close() + delete(d.cache, i) +} + // validClientCert checks that the ephemeral client certificate retrieved from // the cache is unexpired. The time comparisons strip the monotonic clock value // to ensure an accurate result, even after laptop sleep. -func validClientCert(ctx context.Context, cn instance.ConnName, l debug.ContextLogger, expiration time.Time) bool { +func validClientCert( + ctx context.Context, cn instance.ConnName, + l debug.ContextLogger, expiration time.Time, +) bool { // Use UTC() to strip monotonic clock value to guard against inaccurate // comparisons, especially after laptop sleep. // See the comments on the monotonic clock in the Go documentation for diff --git a/dialer_test.go b/dialer_test.go index 051e9b92..76850489 100644 --- a/dialer_test.go +++ b/dialer_test.go @@ -16,6 +16,8 @@ package cloudsqlconn import ( "context" + "crypto/tls" + "crypto/x509" "errors" "fmt" "io" @@ -622,35 +624,71 @@ func TestDialerRemovesInvalidInstancesFromCache(t *testing.T) { // Populate instance map with connection info cache that will always fail // This allows the test to verify the error case path invoking close. badInstanceConnectionName := "doesntexist:us-central1:doesntexist" - badCN, _ := instance.ParseConnName(badInstanceConnectionName) - spy := &spyConnectionInfoCache{ - connectInfoCalls: []struct { - info cloudsql.ConnectionInfo - err error - }{{ - err: errors.New("connect info failed"), - }}, + tcs := []struct { + desc string + icn string + resp connectionInfoResp + opts []DialOption + }{ + { + desc: "dialing a bad instance URI", + icn: badInstanceConnectionName, + resp: connectionInfoResp{ + err: errors.New("connect info failed"), + }, + }, + { + desc: "specifying an invalid IP type", + icn: "myproject:myregion:myinstance", + resp: connectionInfoResp{ + info: cloudsql.NewConnectionInfo( + instance.ConnName{}, + "", + map[string]string{ + // no public IP + cloudsql.PrivateIP: "10.0.0.1", + }, + nil, + tls.Certificate{Leaf: &x509.Certificate{ + NotAfter: time.Now().Add(time.Hour), + }}, + ), + }, + opts: []DialOption{WithPublicIP()}, + }, } - d.cache[badCN] = monitoredCache{connectionInfoCache: spy} - _, err = d.Dial(context.Background(), badInstanceConnectionName) - if err == nil { - t.Fatal("expected Dial to return error") - } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + // Manually populate the internal cache with a spy + inst, _ := instance.ParseConnName(tc.icn) + spy := &spyConnectionInfoCache{ + connectInfoCalls: []connectionInfoResp{tc.resp}, + } + d.cache[inst] = monitoredCache{ + connectionInfoCache: spy, + } - // Verify that the connection info cache was closed (to prevent - // further failed refresh operations) - if got, want := spy.CloseWasCalled(), true; got != want { - t.Fatal("Close was not called") - } + _, err = d.Dial(context.Background(), tc.icn, tc.opts...) + if err == nil { + t.Fatal("expected Dial to return error") + } + // Verify that the connection info cache was closed (to prevent + // further failed refresh operations) + if got, want := spy.closeWasCalled(), true; got != want { + t.Fatal("Close was not called") + } - // Now verify that bad connection name has been deleted from map. - d.lock.RLock() - _, ok := d.cache[badCN] - d.lock.RUnlock() - if ok { - t.Fatal("bad instance was not removed from the cache") + // Now verify that bad connection name has been deleted from map. + d.lock.RLock() + _, ok := d.cache[inst] + d.lock.RUnlock() + if ok { + t.Fatal("connection info was not removed from cache") + } + }) } + } func TestDialRefreshesExpiredCertificates(t *testing.T) { @@ -665,10 +703,7 @@ func TestDialRefreshesExpiredCertificates(t *testing.T) { icn := "project:region:instance" cn, _ := instance.ParseConnName(icn) spy := &spyConnectionInfoCache{ - connectInfoCalls: []struct { - info cloudsql.ConnectionInfo - err error - }{ + connectInfoCalls: []connectionInfoResp{ // First call returns expired certificate { // Certificate expired 10 hours ago. @@ -690,13 +725,13 @@ func TestDialRefreshesExpiredCertificates(t *testing.T) { } // Verify that the cache was refreshed - if got, want := spy.ForceRefreshWasCalled(), true; got != want { + if got, want := spy.forceRefreshWasCalled(), true; got != want { t.Fatal("ForceRefresh was not called") } // Verify that the connection info cache was closed (to prevent // further failed refresh operations) - if got, want := spy.CloseWasCalled(), true; got != want { + if got, want := spy.closeWasCalled(), true; got != want { t.Fatal("Close was not called") } @@ -710,15 +745,18 @@ func TestDialRefreshesExpiredCertificates(t *testing.T) { } +type connectionInfoResp struct { + info cloudsql.ConnectionInfo + err error +} + type spyConnectionInfoCache struct { mu sync.Mutex connectInfoIndex int - connectInfoCalls []struct { - info cloudsql.ConnectionInfo - err error - } - closeWasCalled bool - forceRefreshWasCalled bool + connectInfoCalls []connectionInfoResp + + closed bool + forceRefreshed bool // embed interface to avoid having to implement irrelevant methods connectionInfoCache } @@ -736,7 +774,7 @@ func (s *spyConnectionInfoCache) ConnectionInfo( func (s *spyConnectionInfoCache) ForceRefresh() { s.mu.Lock() defer s.mu.Unlock() - s.forceRefreshWasCalled = true + s.forceRefreshed = true } func (s *spyConnectionInfoCache) UpdateRefresh(*bool) {} @@ -744,20 +782,20 @@ func (s *spyConnectionInfoCache) UpdateRefresh(*bool) {} func (s *spyConnectionInfoCache) Close() error { s.mu.Lock() defer s.mu.Unlock() - s.closeWasCalled = true + s.closed = true return nil } -func (s *spyConnectionInfoCache) CloseWasCalled() bool { +func (s *spyConnectionInfoCache) closeWasCalled() bool { s.mu.Lock() defer s.mu.Unlock() - return s.closeWasCalled + return s.closed } -func (s *spyConnectionInfoCache) ForceRefreshWasCalled() bool { +func (s *spyConnectionInfoCache) forceRefreshWasCalled() bool { s.mu.Lock() defer s.mu.Unlock() - return s.forceRefreshWasCalled + return s.forceRefreshed } func TestDialerSupportsOneOffDialFunction(t *testing.T) { diff --git a/internal/cloudsql/instance.go b/internal/cloudsql/instance.go index bebce43a..2e0b5d02 100644 --- a/internal/cloudsql/instance.go +++ b/internal/cloudsql/instance.go @@ -178,6 +178,24 @@ type ConnectionInfo struct { addrs map[string]string } +// NewConnectionInfo initializes a ConnectionInfo struct. +func NewConnectionInfo( + cn instance.ConnName, + version string, + ipAddrs map[string]string, + serverCaCert *x509.Certificate, + clientCert tls.Certificate, +) ConnectionInfo { + return ConnectionInfo{ + addrs: ipAddrs, + ServerCaCert: serverCaCert, + ClientCertificate: clientCert, + Expiration: clientCert.Leaf.NotAfter, + DBVersion: version, + ConnectionName: cn, + } +} + // Addr returns the IP address or DNS name for the given IP type. func (c ConnectionInfo) Addr(ipType string) (string, error) { var ( diff --git a/internal/cloudsql/refresh.go b/internal/cloudsql/refresh.go index f13874f2..2448c060 100644 --- a/internal/cloudsql/refresh.go +++ b/internal/cloudsql/refresh.go @@ -340,14 +340,9 @@ func (r refresher) ConnectionInfo( return ConnectionInfo{}, fmt.Errorf("refresh failed: %w", ctx.Err()) } - return ConnectionInfo{ - addrs: md.ipAddrs, - ServerCaCert: md.serverCaCert, - ClientCertificate: ec, - Expiration: ec.Leaf.NotAfter, - DBVersion: md.version, - ConnectionName: cn, - }, nil + return NewConnectionInfo( + cn, md.version, md.ipAddrs, md.serverCaCert, ec, + ), nil } // supportsAutoIAMAuthN checks that the engine support automatic IAM authn. If