Skip to content

Commit

Permalink
feat: Automatically reset connection when the DNS record changes. (#868)
Browse files Browse the repository at this point in the history
When a connection is configured using a DNS domain name, the connector will poll the DNS TXT record every 30 seconds. If the value of the DNS record changes, the connector will close all connections to the old instance, and direct new connections to the updated instance.

Fixes of #842
  • Loading branch information
hessjcg authored Sep 20, 2024
1 parent baf4575 commit 4d7abd8
Show file tree
Hide file tree
Showing 9 changed files with 626 additions and 91 deletions.
37 changes: 36 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,8 @@ func connect() {
// ... etc
}
```
### Using DNS to identify an instance

### Using DNS domain names to identify instances

The connector can be configured to use DNS to look up an instance. This would
allow you to configure your application to connect to a database instance, and
Expand Down Expand Up @@ -292,6 +293,40 @@ func connect() {
}
```

### Automatic fail-over using DNS domain names

When the connector is configured using a domain name, the connector will
periodically check if the DNS record for an instance changes. When the connector
detects that the domain name refers to a different instance, the connector will
close all open connections to the old instance. Subsequent connection attempts
will be directed to the new instance.

For example: suppose application is configured to connect using the
domain name `prod-db.mycompany.example.com`. Initially the corporate DNS
zone has a TXT record with the value `my-project:region:my-instance`. The
application establishes connections to the `my-project:region:my-instance`
Cloud SQL instance.

Then, to reconfigure the application to use a different database
instance, change the value of the `prod-db.mycompany.example.com` DNS record
from `my-project:region:my-instance` to `my-project:other-region:my-instance-2`

The connector inside the application detects the change to this
DNS record. Now, when the application connects to its database using the
domain name `prod-db.mycompany.example.com`, it will connect to the
`my-project:other-region:my-instance-2` Cloud SQL instance.

The connector will automatically close all existing connections to
`my-project:region:my-instance`. This will force the connection pools to
establish new connections. Also, it may cause database queries in progress
to fail.

The connector will poll for changes to the DNS name every 30 seconds by default.
You may configure the frequency of the connections using the option
`WithFailoverPeriod(d time.Duration)`. When this is set to 0, the connector will
disable polling and only check if the DNS record changed when it is
creating a new connection.


### Using Options

Expand Down
167 changes: 111 additions & 56 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,20 +110,19 @@ type connectionInfoCache interface {
io.Closer
}

// monitoredCache is a wrapper around a connectionInfoCache that tracks the
// number of connections to the associated instance.
type monitoredCache struct {
openConns *uint64

connectionInfoCache
type cacheKey struct {
domainName string
project string
region string
name string
}

// A Dialer is used to create connections to Cloud SQL instances.
//
// Use NewDialer to initialize a Dialer.
type Dialer struct {
lock sync.RWMutex
cache map[instance.ConnName]monitoredCache
cache map[cacheKey]*monitoredCache
keyGenerator *keyGenerator
refreshTimeout time.Duration
// closed reports if the dialer has been closed.
Expand Down Expand Up @@ -155,7 +154,8 @@ type Dialer struct {
iamTokenSource oauth2.TokenSource

// resolver converts instance names into DNS names.
resolver instance.ConnectionNameResolver
resolver instance.ConnectionNameResolver
failoverPeriod time.Duration
}

var (
Expand All @@ -179,6 +179,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
logger: nullLogger{},
useragents: []string{userAgent},
serviceUniverse: "googleapis.com",
failoverPeriod: cloudsql.FailoverPeriod,
}
for _, opt := range opts {
opt(cfg)
Expand All @@ -192,6 +193,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
if cfg.setIAMAuthNTokenSource && !cfg.useIAMAuthN {
return nil, errUseTokenSource
}

// Add this to the end to make sure it's not overridden
cfg.sqladminOpts = append(cfg.sqladminOpts, option.WithUserAgent(strings.Join(cfg.useragents, " ")))

Expand Down Expand Up @@ -263,7 +265,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {

d := &Dialer{
closed: make(chan struct{}),
cache: make(map[instance.ConnName]monitoredCache),
cache: make(map[cacheKey]*monitoredCache),
lazyRefresh: cfg.lazyRefresh,
keyGenerator: g,
refreshTimeout: cfg.refreshTimeout,
Expand All @@ -274,7 +276,9 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
iamTokenSource: cfg.iamLoginTokenSource,
dialFunc: cfg.dialFunc,
resolver: r,
failoverPeriod: cfg.failoverPeriod,
}

return d, nil
}

Expand All @@ -301,6 +305,10 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
if err != nil {
return nil, err
}
// Log if resolver changed the instance name input string.
if cn.String() != icn {
d.logger.Debugf(ctx, "resolved instance %s to %s", icn, cn)
}

cfg := d.defaultDialConfig
for _, opt := range opts {
Expand Down Expand Up @@ -380,15 +388,24 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn

latency := time.Since(startTime).Milliseconds()
go func() {
n := atomic.AddUint64(c.openConns, 1)
n := atomic.AddUint64(c.openConnsCount, 1)
trace.RecordOpenConnections(ctx, int64(n), d.dialerID, cn.String())
trace.RecordDialLatency(ctx, icn, d.dialerID, latency)
}()

return newInstrumentedConn(tlsConn, func() {
n := atomic.AddUint64(c.openConns, ^uint64(0))
iConn := newInstrumentedConn(tlsConn, func() {
n := atomic.AddUint64(c.openConnsCount, ^uint64(0))
trace.RecordOpenConnections(context.Background(), int64(n), d.dialerID, cn.String())
}, d.dialerID, cn.String()), nil
}, d.dialerID, cn.String())

// If this connection was opened using a Domain Name, then store it for later
// in case it needs to be forcibly closed.
if cn.HasDomainName() {
c.mu.Lock()
c.openConns = append(c.openConns, iConn)
c.mu.Unlock()
}
return iConn, nil
}

// removeCached stops all background refreshes and deletes the connection
Expand All @@ -406,7 +423,7 @@ func (d *Dialer) removeCached(
d.lock.Lock()
defer d.lock.Unlock()
c.Close()
delete(d.cache, i)
delete(d.cache, createKey(i))
}

// validClientCert checks that the ephemeral client certificate retrieved from
Expand Down Expand Up @@ -448,7 +465,7 @@ func (d *Dialer) EngineVersion(ctx context.Context, icn string) (string, error)
}
ci, err := c.ConnectionInfo(ctx)
if err != nil {
d.removeCached(ctx, cn, c, err)
d.removeCached(ctx, cn, c.connectionInfoCache, err)
return "", err
}
return ci.DBVersion, nil
Expand All @@ -472,7 +489,7 @@ func (d *Dialer) Warmup(ctx context.Context, icn string, opts ...DialOption) err
}
_, err = c.ConnectionInfo(ctx)
if err != nil {
d.removeCached(ctx, cn, c, err)
d.removeCached(ctx, cn, c.connectionInfoCache, err)
}
return err
}
Expand All @@ -493,6 +510,8 @@ func newInstrumentedConn(conn net.Conn, closeFunc func(), dialerID, connName str
type instrumentedConn struct {
net.Conn
closeFunc func()
mu sync.RWMutex
closed bool
dialerID string
connName string
}
Expand All @@ -517,9 +536,19 @@ func (i *instrumentedConn) Write(b []byte) (int, error) {
return bytesWritten, err
}

// isClosed returns true if this connection is closing or is already closed.
func (i *instrumentedConn) isClosed() bool {
i.mu.RLock()
defer i.mu.RUnlock()
return i.closed
}

// Close delegates to the underlying net.Conn interface and reports the close
// to the provided closeFunc only when Close returns no error.
func (i *instrumentedConn) Close() error {
i.mu.Lock()
defer i.mu.Unlock()
i.closed = true
err := i.Conn.Close()
if err != nil {
return err
Expand All @@ -546,55 +575,81 @@ func (d *Dialer) Close() error {
return nil
}

// createKey creates a key for the cache from an instance.ConnName.
// An instance.ConnName uniquely identifies a connection using
// project:region:instance + domainName. However, in the dialer cache,
// we want to to identify entries either by project:region:instance, or
// by domainName, but not the combination of the two.
func createKey(cn instance.ConnName) cacheKey {
if cn.HasDomainName() {
return cacheKey{domainName: cn.DomainName()}
}
return cacheKey{
name: cn.Name(),
project: cn.Project(),
region: cn.Region(),
}
}

// connectionInfoCache is a helper function for returning the appropriate
// connection info Cache in a threadsafe way. It will create a new cache,
// modify the existing one, or leave it unchanged as needed.
func (d *Dialer) connectionInfoCache(
ctx context.Context, cn instance.ConnName, useIAMAuthN *bool,
) (monitoredCache, error) {
) (*monitoredCache, error) {
k := createKey(cn)

d.lock.RLock()
c, ok := d.cache[cn]
c, ok := d.cache[k]
d.lock.RUnlock()
if !ok {
d.lock.Lock()
defer d.lock.Unlock()
// Recheck to ensure instance wasn't created or changed between locks
c, ok = d.cache[cn]
if !ok {
var useIAMAuthNDial bool
if useIAMAuthN != nil {
useIAMAuthNDial = *useIAMAuthN
}
d.logger.Debugf(ctx, "[%v] Connection info added to cache", cn.String())
k, err := d.keyGenerator.rsaKey()
if err != nil {
return monitoredCache{}, err
}
var cache connectionInfoCache
if d.lazyRefresh {
cache = cloudsql.NewLazyRefreshCache(
cn,
d.logger,
d.sqladmin, k,
d.refreshTimeout, d.iamTokenSource,
d.dialerID, useIAMAuthNDial,
)
} else {
cache = cloudsql.NewRefreshAheadCache(
cn,
d.logger,
d.sqladmin, k,
d.refreshTimeout, d.iamTokenSource,
d.dialerID, useIAMAuthNDial,
)
}
var count uint64
c = monitoredCache{openConns: &count, connectionInfoCache: cache}
d.cache[cn] = c
}

if ok && !c.isClosed() {
c.UpdateRefresh(useIAMAuthN)
return c, nil
}

d.lock.Lock()
defer d.lock.Unlock()

// Recheck to ensure instance wasn't created or changed between locks
c, ok = d.cache[k]

// c exists and is not closed
if ok && !c.isClosed() {
c.UpdateRefresh(useIAMAuthN)
return c, nil
}

c.UpdateRefresh(useIAMAuthN)
// Create a new instance of monitoredCache
var useIAMAuthNDial bool
if useIAMAuthN != nil {
useIAMAuthNDial = *useIAMAuthN
}
d.logger.Debugf(ctx, "[%v] Connection info added to cache", cn.String())
rsaKey, err := d.keyGenerator.rsaKey()
if err != nil {
return nil, err
}
var cache connectionInfoCache
if d.lazyRefresh {
cache = cloudsql.NewLazyRefreshCache(
cn,
d.logger,
d.sqladmin, rsaKey,
d.refreshTimeout, d.iamTokenSource,
d.dialerID, useIAMAuthNDial,
)
} else {
cache = cloudsql.NewRefreshAheadCache(
cn,
d.logger,
d.sqladmin, rsaKey,
d.refreshTimeout, d.iamTokenSource,
d.dialerID, useIAMAuthNDial,
)
}
c = newMonitoredCache(ctx, cache, cn, d.failoverPeriod, d.resolver, d.logger)
d.cache[k] = c

return c, nil
}
Loading

0 comments on commit 4d7abd8

Please sign in to comment.