diff --git a/README.md b/README.md index ef77c41b..3837c45d 100644 --- a/README.md +++ b/README.md @@ -234,7 +234,8 @@ func connect() { // ... etc } ``` -### Using DNS to identify an instance + +### Using DNS domain names to identify instances The connector can be configured to use DNS to look up an instance. This would allow you to configure your application to connect to a database instance, and @@ -292,6 +293,40 @@ func connect() { } ``` +### Automatic fail-over using DNS domain names + +When the connector is configured using a domain name, the connector will +periodically check if the DNS record for an instance changes. When the connector +detects that the domain name refers to a different instance, the connector will +close all open connections to the old instance. Subsequent connection attempts +will be directed to the new instance. + +For example: suppose an application is configured to connect using the +domain name `prod-db.mycompany.example.com`. Initially the corporate DNS +zone has a TXT record with the value `my-project:region:my-instance`. The +application establishes connections to the `my-project:region:my-instance` +Cloud SQL instance. + +Then, to reconfigure the application to use a different database +instance, change the value of the `prod-db.mycompany.example.com` DNS record +from `my-project:region:my-instance` to `my-project:other-region:my-instance-2`. + +The connector inside the application detects the change to this +DNS record. Now, when the application connects to its database using the +domain name `prod-db.mycompany.example.com`, it will connect to the +`my-project:other-region:my-instance-2` Cloud SQL instance. + +The connector will automatically close all existing connections to +`my-project:region:my-instance`. This will force the connection pools to +establish new connections. Also, it may cause database queries in progress +to fail. + +The connector will poll for changes to the DNS name every 30 seconds by default. 
+You may configure the frequency of these checks using the option +`WithFailoverPeriod(d time.Duration)`. When this is set to 0, the connector will +disable polling and only check if the DNS record changed when it is +creating a new connection. + ### Using Options diff --git a/dialer.go b/dialer.go index 8a62a997..ba380eed 100644 --- a/dialer.go +++ b/dialer.go @@ -110,12 +110,11 @@ type connectionInfoCache interface { io.Closer } -// monitoredCache is a wrapper around a connectionInfoCache that tracks the -// number of connections to the associated instance. -type monitoredCache struct { - openConns *uint64 - - connectionInfoCache +type cacheKey struct { + domainName string + project string + region string + name string } // A Dialer is used to create connections to Cloud SQL instances. @@ -123,7 +122,7 @@ type monitoredCache struct { // Use NewDialer to initialize a Dialer. type Dialer struct { lock sync.RWMutex - cache map[instance.ConnName]monitoredCache + cache map[cacheKey]*monitoredCache keyGenerator *keyGenerator refreshTimeout time.Duration // closed reports if the dialer has been closed. @@ -155,7 +154,8 @@ type Dialer struct { iamTokenSource oauth2.TokenSource // resolver converts instance names into DNS names. 
- resolver instance.ConnectionNameResolver + resolver instance.ConnectionNameResolver + failoverPeriod time.Duration } var ( @@ -179,6 +179,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) { logger: nullLogger{}, useragents: []string{userAgent}, serviceUniverse: "googleapis.com", + failoverPeriod: cloudsql.FailoverPeriod, } for _, opt := range opts { opt(cfg) @@ -192,6 +193,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) { if cfg.setIAMAuthNTokenSource && !cfg.useIAMAuthN { return nil, errUseTokenSource } + // Add this to the end to make sure it's not overridden cfg.sqladminOpts = append(cfg.sqladminOpts, option.WithUserAgent(strings.Join(cfg.useragents, " "))) @@ -263,7 +265,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) { d := &Dialer{ closed: make(chan struct{}), - cache: make(map[instance.ConnName]monitoredCache), + cache: make(map[cacheKey]*monitoredCache), lazyRefresh: cfg.lazyRefresh, keyGenerator: g, refreshTimeout: cfg.refreshTimeout, @@ -274,7 +276,9 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) { iamTokenSource: cfg.iamLoginTokenSource, dialFunc: cfg.dialFunc, resolver: r, + failoverPeriod: cfg.failoverPeriod, } + return d, nil } @@ -301,6 +305,10 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn if err != nil { return nil, err } + // Log if resolver changed the instance name input string. 
+ if cn.String() != icn { + d.logger.Debugf(ctx, "resolved instance %s to %s", icn, cn) + } cfg := d.defaultDialConfig for _, opt := range opts { @@ -380,15 +388,24 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn latency := time.Since(startTime).Milliseconds() go func() { - n := atomic.AddUint64(c.openConns, 1) + n := atomic.AddUint64(c.openConnsCount, 1) trace.RecordOpenConnections(ctx, int64(n), d.dialerID, cn.String()) trace.RecordDialLatency(ctx, icn, d.dialerID, latency) }() - return newInstrumentedConn(tlsConn, func() { - n := atomic.AddUint64(c.openConns, ^uint64(0)) + iConn := newInstrumentedConn(tlsConn, func() { + n := atomic.AddUint64(c.openConnsCount, ^uint64(0)) trace.RecordOpenConnections(context.Background(), int64(n), d.dialerID, cn.String()) - }, d.dialerID, cn.String()), nil + }, d.dialerID, cn.String()) + + // If this connection was opened using a Domain Name, then store it for later + // in case it needs to be forcibly closed. + if cn.HasDomainName() { + c.mu.Lock() + c.openConns = append(c.openConns, iConn) + c.mu.Unlock() + } + return iConn, nil } // removeCached stops all background refreshes and deletes the connection @@ -406,7 +423,7 @@ func (d *Dialer) removeCached( d.lock.Lock() defer d.lock.Unlock() c.Close() - delete(d.cache, i) + delete(d.cache, createKey(i)) } // validClientCert checks that the ephemeral client certificate retrieved from @@ -448,7 +465,7 @@ func (d *Dialer) EngineVersion(ctx context.Context, icn string) (string, error) } ci, err := c.ConnectionInfo(ctx) if err != nil { - d.removeCached(ctx, cn, c, err) + d.removeCached(ctx, cn, c.connectionInfoCache, err) return "", err } return ci.DBVersion, nil @@ -472,7 +489,7 @@ func (d *Dialer) Warmup(ctx context.Context, icn string, opts ...DialOption) err } _, err = c.ConnectionInfo(ctx) if err != nil { - d.removeCached(ctx, cn, c, err) + d.removeCached(ctx, cn, c.connectionInfoCache, err) } return err } @@ -493,6 +510,8 @@ func 
newInstrumentedConn(conn net.Conn, closeFunc func(), dialerID, connName str type instrumentedConn struct { net.Conn closeFunc func() + mu sync.RWMutex + closed bool dialerID string connName string } @@ -517,9 +536,19 @@ func (i *instrumentedConn) Write(b []byte) (int, error) { return bytesWritten, err } +// isClosed returns true if this connection is closing or is already closed. +func (i *instrumentedConn) isClosed() bool { + i.mu.RLock() + defer i.mu.RUnlock() + return i.closed +} + // Close delegates to the underlying net.Conn interface and reports the close // to the provided closeFunc only when Close returns no error. func (i *instrumentedConn) Close() error { + i.mu.Lock() + defer i.mu.Unlock() + i.closed = true err := i.Conn.Close() if err != nil { return err @@ -546,55 +575,81 @@ func (d *Dialer) Close() error { return nil } +// createKey creates a key for the cache from an instance.ConnName. +// An instance.ConnName uniquely identifies a connection using +// project:region:instance + domainName. However, in the dialer cache, +// we want to to identify entries either by project:region:instance, or +// by domainName, but not the combination of the two. +func createKey(cn instance.ConnName) cacheKey { + if cn.HasDomainName() { + return cacheKey{domainName: cn.DomainName()} + } + return cacheKey{ + name: cn.Name(), + project: cn.Project(), + region: cn.Region(), + } +} + // connectionInfoCache is a helper function for returning the appropriate // connection info Cache in a threadsafe way. It will create a new cache, // modify the existing one, or leave it unchanged as needed. 
func (d *Dialer) connectionInfoCache( ctx context.Context, cn instance.ConnName, useIAMAuthN *bool, -) (monitoredCache, error) { +) (*monitoredCache, error) { + k := createKey(cn) + d.lock.RLock() - c, ok := d.cache[cn] + c, ok := d.cache[k] d.lock.RUnlock() - if !ok { - d.lock.Lock() - defer d.lock.Unlock() - // Recheck to ensure instance wasn't created or changed between locks - c, ok = d.cache[cn] - if !ok { - var useIAMAuthNDial bool - if useIAMAuthN != nil { - useIAMAuthNDial = *useIAMAuthN - } - d.logger.Debugf(ctx, "[%v] Connection info added to cache", cn.String()) - k, err := d.keyGenerator.rsaKey() - if err != nil { - return monitoredCache{}, err - } - var cache connectionInfoCache - if d.lazyRefresh { - cache = cloudsql.NewLazyRefreshCache( - cn, - d.logger, - d.sqladmin, k, - d.refreshTimeout, d.iamTokenSource, - d.dialerID, useIAMAuthNDial, - ) - } else { - cache = cloudsql.NewRefreshAheadCache( - cn, - d.logger, - d.sqladmin, k, - d.refreshTimeout, d.iamTokenSource, - d.dialerID, useIAMAuthNDial, - ) - } - var count uint64 - c = monitoredCache{openConns: &count, connectionInfoCache: cache} - d.cache[cn] = c - } + + if ok && !c.isClosed() { + c.UpdateRefresh(useIAMAuthN) + return c, nil + } + + d.lock.Lock() + defer d.lock.Unlock() + + // Recheck to ensure instance wasn't created or changed between locks + c, ok = d.cache[k] + + // c exists and is not closed + if ok && !c.isClosed() { + c.UpdateRefresh(useIAMAuthN) + return c, nil } - c.UpdateRefresh(useIAMAuthN) + // Create a new instance of monitoredCache + var useIAMAuthNDial bool + if useIAMAuthN != nil { + useIAMAuthNDial = *useIAMAuthN + } + d.logger.Debugf(ctx, "[%v] Connection info added to cache", cn.String()) + rsaKey, err := d.keyGenerator.rsaKey() + if err != nil { + return nil, err + } + var cache connectionInfoCache + if d.lazyRefresh { + cache = cloudsql.NewLazyRefreshCache( + cn, + d.logger, + d.sqladmin, rsaKey, + d.refreshTimeout, d.iamTokenSource, + d.dialerID, useIAMAuthNDial, + ) + 
} else { + cache = cloudsql.NewRefreshAheadCache( + cn, + d.logger, + d.sqladmin, rsaKey, + d.refreshTimeout, d.iamTokenSource, + d.dialerID, useIAMAuthNDial, + ) + } + c = newMonitoredCache(ctx, cache, cn, d.failoverPeriod, d.resolver, d.logger) + d.cache[k] = c return c, nil } diff --git a/dialer_test.go b/dialer_test.go index 2e640af4..07f18bdd 100644 --- a/dialer_test.go +++ b/dialer_test.go @@ -25,6 +25,7 @@ import ( "os" "strings" "sync" + "sync/atomic" "testing" "time" @@ -476,9 +477,7 @@ func TestEngineVersionRemovesInvalidInstancesFromCache(t *testing.T) { spy := &spyConnectionInfoCache{ connectInfoCalls: []connectionInfoResp{tc.resp}, } - d.cache[inst] = monitoredCache{ - connectionInfoCache: spy, - } + d.cache[createKey(inst)] = newMonitoredCache(nil, spy, inst, 0, nil, nil) _, err = d.EngineVersion(context.Background(), tc.icn) if err == nil { @@ -492,7 +491,7 @@ func TestEngineVersionRemovesInvalidInstancesFromCache(t *testing.T) { // Now verify that bad connection name has been deleted from map. d.lock.RLock() - _, ok := d.cache[inst] + _, ok := d.cache[createKey(inst)] d.lock.RUnlock() if ok { t.Fatal("connection info was not removed from cache") @@ -626,9 +625,7 @@ func TestWarmupRemovesInvalidInstancesFromCache(t *testing.T) { spy := &spyConnectionInfoCache{ connectInfoCalls: []connectionInfoResp{tc.resp}, } - d.cache[inst] = monitoredCache{ - connectionInfoCache: spy, - } + d.cache[createKey(inst)] = newMonitoredCache(nil, spy, inst, 0, nil, nil) err = d.Warmup(context.Background(), tc.icn, tc.opts...) if err == nil { @@ -642,7 +639,7 @@ func TestWarmupRemovesInvalidInstancesFromCache(t *testing.T) { // Now verify that bad connection name has been deleted from map. 
d.lock.RLock() - _, ok := d.cache[inst] + _, ok := d.cache[createKey(inst)] d.lock.RUnlock() if ok { t.Fatal("connection info was not removed from cache") @@ -802,9 +799,7 @@ func TestDialerRemovesInvalidInstancesFromCache(t *testing.T) { spy := &spyConnectionInfoCache{ connectInfoCalls: []connectionInfoResp{tc.resp}, } - d.cache[inst] = monitoredCache{ - connectionInfoCache: spy, - } + d.cache[createKey(inst)] = newMonitoredCache(nil, spy, inst, 0, nil, nil) _, err = d.Dial(context.Background(), tc.icn, tc.opts...) if err == nil { @@ -818,7 +813,7 @@ func TestDialerRemovesInvalidInstancesFromCache(t *testing.T) { // Now verify that bad connection name has been deleted from map. d.lock.RLock() - _, ok := d.cache[inst] + _, ok := d.cache[createKey(inst)] d.lock.RUnlock() if ok { t.Fatal("connection info was not removed from cache") @@ -854,7 +849,7 @@ func TestDialRefreshesExpiredCertificates(t *testing.T) { }, }, } - d.cache[cn] = monitoredCache{connectionInfoCache: spy} + d.cache[createKey(cn)] = newMonitoredCache(nil, spy, cn, 0, nil, nil) _, err = d.Dial(context.Background(), icn) if !errors.Is(err, sentinel) { @@ -874,7 +869,7 @@ func TestDialRefreshesExpiredCertificates(t *testing.T) { // Now verify that bad connection name has been deleted from map. 
d.lock.RLock() - _, ok := d.cache[cn] + _, ok := d.cache[createKey(cn)] d.lock.RUnlock() if ok { t.Fatal("bad instance was not removed from the cache") @@ -1015,7 +1010,7 @@ func TestDialerInitializesLazyCache(t *testing.T) { t.Fatal(err) } - c, ok := d.cache[cn] + c, ok := d.cache[createKey(cn)] if !ok { t.Fatal("cache was not populated") } @@ -1028,16 +1023,13 @@ func TestDialerInitializesLazyCache(t *testing.T) { } type fakeResolver struct { - domainName string - instanceName instance.ConnName + entries map[string]instance.ConnName } func (r *fakeResolver) Resolve(_ context.Context, name string) (instance.ConnName, error) { - // For TestDialerSuccessfullyDialsDnsTxtRecord - if name == r.domainName { - return r.instanceName, nil + if val, ok := r.entries[name]; ok { + return val, nil } - // TestDialerFailsDnsTxtRecordMissing return instance.ConnName{}, fmt.Errorf("no resolution for %q", name) } @@ -1045,18 +1037,23 @@ func TestDialerSuccessfullyDialsDnsTxtRecord(t *testing.T) { inst := mock.NewFakeCSQLInstance( "my-project", "my-region", "my-instance", ) - wantName, _ := instance.ParseConnName("my-project:my-region:my-instance") + wantName, _ := instance.ParseConnNameWithDomainName("my-project:my-region:my-instance", "db.example.com") + wantName2, _ := instance.ParseConnNameWithDomainName("my-project:my-region:my-instance", "db2.example.com") + // This will create 2 separate connectionInfoCache entries, one for + // each DNS name. 
d := setupDialer(t, setupConfig{ testInstance: inst, reqs: []*mock.Request{ - mock.InstanceGetSuccess(inst, 1), - mock.CreateEphemeralSuccess(inst, 1), + mock.InstanceGetSuccess(inst, 2), + mock.CreateEphemeralSuccess(inst, 2), }, dialerOptions: []Option{ WithTokenSource(mock.EmptyTokenSource{}), WithResolver(&fakeResolver{ - domainName: "db.example.com", - instanceName: wantName, + entries: map[string]instance.ConnName{ + "db.example.com": wantName, + "db2.example.com": wantName2, + }, }), }, }) @@ -1065,6 +1062,10 @@ func TestDialerSuccessfullyDialsDnsTxtRecord(t *testing.T) { context.Background(), t, d, "db.example.com", ) + testSuccessfulDial( + context.Background(), t, d, + "db2.example.com", + ) } func TestDialerFailsDnsTxtRecordMissing(t *testing.T) { @@ -1085,3 +1086,84 @@ func TestDialerFailsDnsTxtRecordMissing(t *testing.T) { t.Fatalf("want = %v, got = %v", wantMsg, err) } } + +type changingResolver struct { + stage *int32 +} + +func (r *changingResolver) Resolve(_ context.Context, name string) (instance.ConnName, error) { + // For TestDialerFailoverOnInstanceChange + if name == "update.example.com" { + if atomic.LoadInt32(r.stage) == 0 { + return instance.ParseConnNameWithDomainName("my-project:my-region:my-instance", "update.example.com") + } + return instance.ParseConnNameWithDomainName("my-project:my-region:my-instance2", "update.example.com") + } + // TestDialerFailsDnsSrvRecordMissing + return instance.ConnName{}, fmt.Errorf("no resolution for %q", name) +} + +func TestDialerUpdatesAutomaticallyAfterDnsChange(t *testing.T) { + // At first, the resolver will resolve + // update.example.com to "my-instance" + // Then, the resolver will resolve the same domain name to + // "my-instance2". + // This shows that on every call to Dial(), the dialer will resolve the + // SRV record and connect to the correct instance. 
+ inst := mock.NewFakeCSQLInstance( + "my-project", "my-region", "my-instance", + ) + inst2 := mock.NewFakeCSQLInstance( + "my-project", "my-region", "my-instance2", + ) + r := &changingResolver{ + stage: new(int32), + } + + d := setupDialer(t, setupConfig{ + skipServer: true, + reqs: []*mock.Request{ + mock.InstanceGetSuccess(inst, 1), + mock.CreateEphemeralSuccess(inst, 1), + mock.InstanceGetSuccess(inst2, 1), + mock.CreateEphemeralSuccess(inst2, 1), + }, + dialerOptions: []Option{ + WithFailoverPeriod(10 * time.Millisecond), + WithResolver(r), + WithTokenSource(mock.EmptyTokenSource{}), + }, + }) + + // Start the proxy for instance 1 + stop1 := mock.StartServerProxy(t, inst) + t.Cleanup(func() { + stop1() + }) + + testSuccessfulDial( + context.Background(), t, d, + "update.example.com", + ) + stop1() + atomic.StoreInt32(r.stage, 1) + + time.Sleep(1 * time.Second) + instCn, _ := instance.ParseConnNameWithDomainName("my-project:my-region:my-instance", "update.example.com") + c, _ := d.cache[createKey(instCn)] + if !c.isClosed() { + t.Fatal("Expected monitoredCache to be closed after domain name changed. monitoredCache was not closed.") + } + + // Start the proxy for instance 2 + stop2 := mock.StartServerProxy(t, inst2) + t.Cleanup(func() { + stop2() + }) + + testSucessfulDialWithInstanceName( + context.Background(), t, d, + "update.example.com", "my-instance2", + ) + +} diff --git a/instance/conn_name.go b/instance/conn_name.go index 2dd3de73..01f72c7c 100644 --- a/instance/conn_name.go +++ b/instance/conn_name.go @@ -32,12 +32,16 @@ var ( // ConnName represents the "instance connection name", in the format // "project:region:name". 
type ConnName struct { - project string - region string - name string + project string + region string + name string + domainName string } func (c *ConnName) String() string { + if c.domainName != "" { + return fmt.Sprintf("%s -> %s:%s:%s", c.domainName, c.project, c.region, c.name) + } return fmt.Sprintf("%s:%s:%s", c.project, c.region, c.name) } @@ -56,8 +60,24 @@ func (c *ConnName) Name() string { return c.name } +// DomainName returns the domain name for this instance, or "" if none is set. +func (c *ConnName) DomainName() string { + return c.domainName +} + +// HasDomainName reports whether this ConnName has a non-empty domain name. +func (c *ConnName) HasDomainName() bool { + return c.domainName != "" +} + // ParseConnName initializes a new ConnName struct. func ParseConnName(cn string) (ConnName, error) { + return ParseConnNameWithDomainName(cn, "") +} + +// ParseConnNameWithDomainName initializes a new ConnName struct, +// also setting the domain name. +func ParseConnNameWithDomainName(cn string, dn string) (ConnName, error) { b := []byte(cn) m := connNameRegex.FindSubmatch(b) if m == nil { @@ -69,9 +89,10 @@ func ParseConnName(cn string) (ConnName, error) { } c := ConnName{ - project: string(m[1]), - region: string(m[3]), - name: string(m[4]), + project: string(m[1]), + region: string(m[3]), + name: string(m[4]), + domainName: dn, } return c, nil } diff --git a/instance/conn_name_test.go b/instance/conn_name_test.go index 315dec4d..e07f759a 100644 --- a/instance/conn_name_test.go +++ b/instance/conn_name_test.go @@ -23,11 +23,11 @@ func TestParseConnName(t *testing.T) { }{ { "project:region:instance", - ConnName{"project", "region", "instance"}, + ConnName{project: "project", region: "region", name: "instance"}, }, { "google.com:project:region:instance", - ConnName{"google.com:project", "region", "instance"}, + ConnName{project: "google.com:project", region: "region", name: "instance"}, }, { "project:instance", // missing region diff --git a/internal/cloudsql/instance.go b/internal/cloudsql/instance.go 
index f8e44b3b..bc25e672 100644 --- a/internal/cloudsql/instance.go +++ b/internal/cloudsql/instance.go @@ -45,6 +45,11 @@ const ( // refreshInterval. RefreshTimeout = 60 * time.Second + // FailoverPeriod is the frequency with which the dialer will check + // if the DNS record has changed for connections configured using + // a DNS name. + FailoverPeriod = 30 * time.Second + // refreshBurst is the initial burst allowed by the rate limiter. refreshBurst = 2 ) diff --git a/monitored_cache.go b/monitored_cache.go new file mode 100644 index 00000000..b3929b53 --- /dev/null +++ b/monitored_cache.go @@ -0,0 +1,146 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cloudsqlconn + +import ( + "context" + "sync" + "sync/atomic" + "time" + + "cloud.google.com/go/cloudsqlconn/debug" + "cloud.google.com/go/cloudsqlconn/instance" +) + +// monitoredCache is a wrapper around a connectionInfoCache that tracks the +// number of connections to the associated instance. +type monitoredCache struct { + openConnsCount *uint64 + cn instance.ConnName + resolver instance.ConnectionNameResolver + logger debug.ContextLogger + + // domainNameTicker periodically checks any domain names to see if they + // changed. 
+ domainNameTicker *time.Ticker + closedCh chan struct{} + + mu sync.Mutex + openConns []*instrumentedConn + closed bool + + connectionInfoCache +} + +func newMonitoredCache( + ctx context.Context, + cache connectionInfoCache, + cn instance.ConnName, + failoverPeriod time.Duration, + resolver instance.ConnectionNameResolver, + logger debug.ContextLogger) *monitoredCache { + + c := &monitoredCache{ + openConnsCount: new(uint64), + closedCh: make(chan struct{}), + cn: cn, + resolver: resolver, + logger: logger, + connectionInfoCache: cache, + } + if cn.HasDomainName() && failoverPeriod > 0 { // time.NewTicker panics on a period <= 0; 0 disables polling. + c.domainNameTicker = time.NewTicker(failoverPeriod) + go func() { + for { + select { + case <-c.domainNameTicker.C: + c.purgeClosedConns() + c.checkDomainName(ctx) + case <-c.closedCh: + return + } + } + }() + + } + + return c +} +func (c *monitoredCache) isClosed() bool { + c.mu.Lock() + defer c.mu.Unlock() + return c.closed +} + +func (c *monitoredCache) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return nil + } + + c.closed = true + close(c.closedCh) + + if c.domainNameTicker != nil { + c.domainNameTicker.Stop() + } + + if atomic.LoadUint64(c.openConnsCount) > 0 { + for _, socket := range c.openConns { + if !socket.isClosed() { + _ = socket.Close() // force socket closed, ok to ignore error. + } + } + atomic.StoreUint64(c.openConnsCount, 0) + } + + return c.connectionInfoCache.Close() +} + +func (c *monitoredCache) purgeClosedConns() { + c.mu.Lock() + defer c.mu.Unlock() + + var open []*instrumentedConn + for _, s := range c.openConns { + if !s.isClosed() { + open = append(open, s) + } + } + c.openConns = open +} + +func (c *monitoredCache) checkDomainName(ctx context.Context) { + if !c.cn.HasDomainName() { + return + } + newCn, err := c.resolver.Resolve(ctx, c.cn.DomainName()) + if err != nil { + // The domain name could not be resolved. 
+ c.logger.Debugf(ctx, "domain name %s for instance %s did not resolve, "+ + "closing all connections: %v", + c.cn.DomainName(), c.cn.Name(), err) + c.Close() + } + if err == nil && newCn != c.cn { + // Resolved successfully, but to a different instance. + c.logger.Debugf(ctx, "domain name %s changed from %s to %s, "+ + "closing all connections.", + c.cn.DomainName(), c.cn.Name(), newCn.Name()) + c.Close() + } + +} diff --git a/monitored_cache_test.go b/monitored_cache_test.go new file mode 100644 index 00000000..0fa42a58 --- /dev/null +++ b/monitored_cache_test.go @@ -0,0 +1,180 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cloudsqlconn + +import ( + "context" + "net" + "net/netip" + "sync/atomic" + "testing" + "time" + + "cloud.google.com/go/cloudsqlconn/instance" +) + +type testLog struct { + t *testing.T +} + +func (l *testLog) Debugf(_ context.Context, f string, args ...interface{}) { + l.t.Logf(f, args...) 
+} + +func TestMonitoredCache_purgeClosedConns(t *testing.T) { + cn, _ := instance.ParseConnNameWithDomainName("my-project:my-region:my-instance", "db.example.com") + c := newMonitoredCache(context.TODO(), + &spyConnectionInfoCache{}, + cn, + 10*time.Millisecond, + &fakeResolver{entries: map[string]instance.ConnName{"db.example.com": cn}}, + &testLog{t: t}, + ) + + // Add connections + c.mu.Lock() + c.openConns = []*instrumentedConn{ + &instrumentedConn{closed: false}, + &instrumentedConn{closed: true}, + } + c.mu.Unlock() + + // wait for the resolver to run + time.Sleep(100 * time.Millisecond) + c.mu.Lock() + if got := len(c.openConns); got != 1 { + t.Fatalf("got %d, want 1. Expected openConns to only contain open connections", got) + } + c.mu.Unlock() + +} + +func TestMonitoredCache_checkDomainName_instanceChanged(t *testing.T) { + cn, _ := instance.ParseConnNameWithDomainName("my-project:my-region:my-instance", "update.example.com") + r := &changingResolver{ + stage: new(int32), + } + c := newMonitoredCache(context.TODO(), + &spyConnectionInfoCache{}, + cn, + 10*time.Millisecond, + r, + &testLog{t: t}, + ) + + // Dont' change the instance yet. Check that the connection is open. 
+ // wait for the resolver to run + time.Sleep(100 * time.Millisecond) + if c.isClosed() { + t.Fatal("got cache closed, want cache open") + } + // update the domain name + atomic.StoreInt32(r.stage, 1) + + // wait for the resolver to run + time.Sleep(100 * time.Millisecond) + if !c.isClosed() { + t.Fatal("got cache open, want cache closed") + } + +} + +func TestMonitoredCache_Close(t *testing.T) { + cn, _ := instance.ParseConnNameWithDomainName("my-project:my-region:my-instance", "update.example.com") + var closeFuncCalls int32 + + r := &changingResolver{ + stage: new(int32), + } + + c := newMonitoredCache(context.TODO(), + &spyConnectionInfoCache{}, + cn, + 10*time.Millisecond, + r, + &testLog{t: t}, + ) + inc := func() { + atomic.AddInt32(&closeFuncCalls, 1) + } + + c.mu.Lock() + // set up the state as if there were 2 open connections. + c.openConns = []*instrumentedConn{ + { + closed: false, + closeFunc: inc, + Conn: &mockConn{}, + }, + { + closed: false, + closeFunc: inc, + Conn: &mockConn{}, + }, + { + closed: true, + closeFunc: inc, + Conn: &mockConn{}, + }, + } + *c.openConnsCount = 2 + c.mu.Unlock() + + c.Close() + if !c.isClosed() { + t.Fatal("got cache open, want cache closed") + } + // wait for closeFunc() to be called. 
+ time.Sleep(100 * time.Millisecond) + if got := atomic.LoadInt32(&closeFuncCalls); got != 2 { + t.Fatalf("got %d, want 2", got) + } + +} + +type mockConn struct { +} + +func (m *mockConn) Read(_ []byte) (int, error) { + return 0, nil +} + +func (m *mockConn) Write(_ []byte) (int, error) { + return 0, nil +} + +func (m *mockConn) Close() error { + return nil +} + +func (m *mockConn) LocalAddr() net.Addr { + return net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:3307")) +} + +func (m *mockConn) RemoteAddr() net.Addr { + return net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:3307")) +} + +func (m *mockConn) SetDeadline(_ time.Time) error { + return nil +} + +func (m *mockConn) SetReadDeadline(_ time.Time) error { + return nil +} + +func (m *mockConn) SetWriteDeadline(_ time.Time) error { + return nil +} diff --git a/options.go b/options.go index c21fcd2a..a719eca9 100644 --- a/options.go +++ b/options.go @@ -54,6 +54,7 @@ type dialerConfig struct { setTokenSource bool setIAMAuthNTokenSource bool resolver instance.ConnectionNameResolver + failoverPeriod time.Duration // err tracks any dialer options that may have failed. err error } @@ -271,6 +272,16 @@ func WithDNSResolver() Option { } } +// WithFailoverPeriod will cause the connector to periodically check the DNS +// records of instances configured using DNS names. By default, this is 30 +// seconds. If this is set to 0, the connector will only check for domain name +// changes when establishing a new connection. +func WithFailoverPeriod(f time.Duration) Option { + return func(d *dialerConfig) { + d.failoverPeriod = f + } +} + type debugLoggerWithoutContext struct { logger debug.Logger }