Skip to content

Commit

Permalink
Specify a custom dial function per config
Browse files Browse the repository at this point in the history
  • Loading branch information
aaronjheng committed Dec 16, 2023
1 parent 0004702 commit 2d38b27
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
11 changes: 10 additions & 1 deletion connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,16 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
dialsLock.RLock()
dial, ok := dials[mc.cfg.Net]
dialsLock.RUnlock()
if ok {

if c.cfg.DialFunc != nil {
dctx := ctx
if mc.cfg.Timeout > 0 {
var cancel context.CancelFunc
dctx, cancel = context.WithTimeout(ctx, c.cfg.Timeout)
defer cancel()
}
mc.netConn, err = c.cfg.DialFunc(dctx, mc.cfg.Net, mc.cfg.Addr)
} else if ok {
dctx := ctx
if mc.cfg.Timeout > 0 {
var cancel context.CancelFunc
Expand Down
15 changes: 15 additions & 0 deletions dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ package mysql

import (
"bytes"
"context"
"crypto/rsa"
"crypto/tls"
"errors"
Expand Down Expand Up @@ -65,6 +66,15 @@ type Config struct {
MultiStatements bool // Allow multiple statements in one query
ParseTime bool // Parse time values to time.Time
RejectReadOnly bool // Reject read-only connections

// DialFunc specifies the dial function for creating connections.
// If DialFunc is nil, the connector will attempt to find a dial function from the global registry (registered with RegisterDialContext).
// If no dial function is found even after checking the global registry, the net.Dialer will be used as a fallback.
//
// The dial function is responsible for establishing connections. By providing a custom dial function,
// users can flexibly control the process of connection establishment. Custom dial functions can be registered in the global registry
// to tailor connection behavior according to specific requirements.
DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
}

// NewConfig creates a new Config and sets default values.
Expand Down Expand Up @@ -95,6 +105,11 @@ func (cfg *Config) Clone() *Config {
E: cfg.pubKey.E,
}
}

if cp.DialFunc != nil {
cp.DialFunc = cfg.DialFunc
}

return &cp
}

Expand Down

0 comments on commit 2d38b27

Please sign in to comment.