Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support optional resultset metadata #1150

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ Tan Jinhua <312841925 at qq.com>
Thomas Wodarek <wodarekwebpage at gmail.com>
Tim Ruffles <timruffles at gmail.com>
Tom Jenkinson <tom at tjenkinson.me>
Tzu-Chiao Yeh <su3g4284zo6y7 at gmail.com>
Vladimir Kovpak <cn007b at gmail.com>
Vladyslav Zhelezniak <zhvladi at gmail.com>
Xiangyu Hu <xiangyu.hu at outlook.com>
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,7 @@ Examples:
* `autocommit=1`: `SET autocommit=1`
* [`time_zone=%27Europe%2FParis%27`](https://dev.mysql.com/doc/refman/5.5/en/time-zone-support.html): `SET time_zone='Europe/Paris'`
* [`transaction_isolation=%27REPEATABLE-READ%27`](https://dev.mysql.com/doc/refman/5.7/en/server-system-variables.html#sysvar_transaction_isolation): `SET transaction_isolation='REPEATABLE-READ'`
* metata=none`](https://dev.mysql.com/doc/refman/8.0/en/server-system-variables.html#sysvar_resultset_metadata): `SET resultset_metadata=none` (note that this is only applicable to MySQL 8.0+ versions).


#### Examples
Expand Down
49 changes: 35 additions & 14 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,22 @@ import (
)

type mysqlConn struct {
buf buffer
netConn net.Conn
rawConn net.Conn // underlying connection when netConn is TLS connection.
affectedRows uint64
insertId uint64
cfg *Config
maxAllowedPacket int
maxWriteSize int
writeTimeout time.Duration
flags clientFlag
status statusFlag
sequence uint8
parseTime bool
reset bool // set when the Go SQL package calls ResetSession
buf buffer
netConn net.Conn
rawConn net.Conn // underlying connection when netConn is TLS connection.
affectedRows uint64
insertId uint64
cfg *Config
maxAllowedPacket int
maxWriteSize int
writeTimeout time.Duration
flags clientFlag
status statusFlag
sequence uint8
parseTime bool
reset bool // set when the Go SQL package calls ResetSession
optionalResultSetMetadata bool
resultSetMetadata uint8

// for context support (Go 1.8+)
watching bool
Expand Down Expand Up @@ -392,6 +394,10 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
}
}

if mc.optionalResultSetMetadata && mc.resultSetMetadata == resultSetMetadataNone {
return mc.readIgnoreColumns(rows, resLen)
}

// Columns
rows.rs.columns, err = mc.readColumns(resLen)
return rows, err
Expand All @@ -400,6 +406,21 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
return nil, mc.markBadConn(err)
}

func (mc *mysqlConn) readIgnoreColumns(rows *textRows, resLen int) (*textRows, error) {
data, err := mc.readPacket()
if err != nil {
errLog.Print(err)
return nil, err
}
// Expected an EOF packet
if data[0] == iEOF && (len(data) == 5 || len(data) == 1) {
// Set empty columnNames, we will first read these columnNames via rows.Columns().
rows.rs.columnNames = make([]string, resLen)
return rows, nil
}
return nil, ErrOptionalResultSetMetadataPkt
}

// Gets the value of the given MySQL System Variable
// The returned byte slice is only valid until the next read
func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
Expand Down
15 changes: 15 additions & 0 deletions connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"context"
"database/sql/driver"
"net"
"strings"
)

type connector struct {
Expand Down Expand Up @@ -88,6 +89,20 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
plugin = defaultAuthPlugin
}

// Set the optionalResultSetMetadata ahead to set the client capability flag.
if resultSetMetadata, ok := mc.cfg.Params["resultset_metadata"]; ok {
upperVal := strings.ToUpper(resultSetMetadata)
switch upperVal {
case resultSetMetadataSysVarNone:
mc.optionalResultSetMetadata = true
mc.resultSetMetadata = resultSetMetadataNone
case resultSetMetadataSysVarFull:
mc.optionalResultSetMetadata = true
mc.resultSetMetadata = resultSetMetadataFull
}
// To be consistent with other params, in case the param is passed wrongly still send to MySQL to let the server side rejects it.
}

// Send Client Authentication Packet
authResp, err := mc.auth(authData, plugin)
if err != nil {
Expand Down
14 changes: 14 additions & 0 deletions const.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ const (
clientCanHandleExpiredPasswords
clientSessionTrack
clientDeprecateEOF
clientOptionalResultSetMetadata
)

const (
Expand Down Expand Up @@ -172,3 +173,16 @@ const (
cachingSha2PasswordFastAuthSuccess = 3
cachingSha2PasswordPerformFullAuthentication = 4
)

const (
// One-byte metadata flag
// https://dev.mysql.com/worklog/task/?id=8134
resultSetMetadataNone uint8 = iota
resultSetMetadataFull
)

const (
// ResultSet Metadata system var
resultSetMetadataSysVarNone = "NONE"
resultSetMetadataSysVarFull = "FULL"
)
45 changes: 45 additions & 0 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ var (
prot string
addr string
dbname string
vendor string
dsn string
netAddr string
available bool
Expand Down Expand Up @@ -202,6 +203,7 @@ func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows)
func maybeSkip(t *testing.T, err error, skipErrno uint16) {
mySQLErr, ok := err.(*MySQLError)
if !ok {
errLog.Print("non match")
return
}

Expand Down Expand Up @@ -1345,6 +1347,49 @@ func TestFoundRows(t *testing.T) {
})
}

func TestOptionalResultSetMetadata(t *testing.T) {
runTests(t, dsn+"&resultset_metadata=none", func(dbt *DBTest) {
_, err := dbt.db.Exec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)")
if err == ErrNoOptionalResultMetadataSet {
t.Skip("server does not support resultset metadata")
} else if err != nil {
dbt.Fatal(err)
}
dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)")

row := dbt.db.QueryRow("SELECT id, data FROM test WHERE id = 1")
id, data := 0, 0
err = row.Scan(&id, &data)
if err != nil {
dbt.Fatal(err)
}

if id != 1 && data != 0 {
dbt.Fatal("invalid result")
}
})
runTests(t, dsn+"&resultset_metadata=full", func(dbt *DBTest) {
_, err := dbt.db.Exec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)")
if err == ErrNoOptionalResultMetadataSet {
t.Skip("server does not support resultset metadata")
} else if err != nil {
dbt.Fatal(err)
}
dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)")

row := dbt.db.QueryRow("SELECT id, data FROM test WHERE id = 1")
id, data := 0, 0
err = row.Scan(&id, &data)
if err != nil {
dbt.Fatal(err)
}

if id != 1 && data != 0 {
dbt.Fatal("invalid result")
}
})
}

func TestTLS(t *testing.T) {
tlsTestReq := func(dbt *DBTest) {
if err := dbt.db.Ping(); err != nil {
Expand Down
32 changes: 16 additions & 16 deletions dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,22 @@ var (
// If a new Config is created instead of being parsed from a DSN string,
// the NewConfig function should be used, which sets default values.
type Config struct {
User string // Username
Passwd string // Password (requires User)
Net string // Network type
Addr string // Network address (requires Net)
DBName string // Database name
Params map[string]string // Connection parameters
Collation string // Connection collation
Loc *time.Location // Location for time.Time values
MaxAllowedPacket int // Max packet size allowed
ServerPubKey string // Server public key name
pubKey *rsa.PublicKey // Server public key
TLSConfig string // TLS configuration name
tls *tls.Config // TLS configuration
Timeout time.Duration // Dial timeout
ReadTimeout time.Duration // I/O read timeout
WriteTimeout time.Duration // I/O write timeout
User string // Username
Passwd string // Password (requires User)
Net string // Network type
Addr string // Network address (requires Net)
DBName string // Database name
Params map[string]string // Connection parameters
Collation string // Connection collation
Loc *time.Location // Location for time.Time values
MaxAllowedPacket int // Max packet size allowed
ServerPubKey string // Server public key name
pubKey *rsa.PublicKey // Server public key
TLSConfig string // TLS configuration name
tls *tls.Config // TLS configuration
Timeout time.Duration // Dial timeout
ReadTimeout time.Duration // I/O read timeout
WriteTimeout time.Duration // I/O write timeout

AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE
AllowCleartextPasswords bool // Allows the cleartext client side plugin
Expand Down
3 changes: 3 additions & 0 deletions dsn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ var testDSNs = []struct {
}, {
"user:password@/dbname?allowNativePasswords=false&checkConnLiveness=false&maxAllowedPacket=0",
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: 0, AllowNativePasswords: false, CheckConnLiveness: false},
}, {
"user:password@/dbname?allowNativePasswords=false&checkConnLiveness=false&maxAllowedPacket=0",
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: 0, AllowNativePasswords: false, CheckConnLiveness: false},
}, {
"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local",
&Config{User: "user", Passwd: "p@ss(word)", Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:80", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.Local, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true},
Expand Down
26 changes: 14 additions & 12 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,20 @@ import (

// Various errors the driver might return. Can change between driver versions.
var (
ErrInvalidConn = errors.New("invalid connection")
ErrMalformPkt = errors.New("malformed packet")
ErrNoTLS = errors.New("TLS requested but server does not support TLS")
ErrCleartextPassword = errors.New("this user requires clear text authentication. If you still want to use it, please add 'allowCleartextPasswords=1' to your DSN")
ErrNativePassword = errors.New("this user requires mysql native password authentication.")
ErrOldPassword = errors.New("this user requires old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords")
ErrUnknownPlugin = errors.New("this authentication plugin is not supported")
ErrOldProtocol = errors.New("MySQL server does not support required protocol 41+")
ErrPktSync = errors.New("commands out of sync. You can't run this command now")
ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?")
ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the 'max_allowed_packet' variable on the server")
ErrBusyBuffer = errors.New("busy buffer")
ErrInvalidConn = errors.New("invalid connection")
ErrMalformPkt = errors.New("malformed packet")
ErrNoTLS = errors.New("TLS requested but server does not support TLS")
ErrCleartextPassword = errors.New("this user requires clear text authentication. If you still want to use it, please add 'allowCleartextPasswords=1' to your DSN")
ErrNativePassword = errors.New("this user requires mysql native password authentication")
ErrOldPassword = errors.New("this user requires old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords")
ErrUnknownPlugin = errors.New("this authentication plugin is not supported")
ErrOldProtocol = errors.New("MySQL server does not support required protocol 41+")
ErrPktSync = errors.New("commands out of sync. You can't run this command now")
ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?")
ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the 'max_allowed_packet' variable on the server")
ErrBusyBuffer = errors.New("busy buffer")
ErrNoOptionalResultMetadataSet = errors.New("requested optional resultset metadata but server does not support")
ErrOptionalResultSetMetadataPkt = errors.New("malformed optional resultset metadata packets")

// errBadConnNoWrite is used for connection errors where nothing was sent to the database yet.
// If this happens first in a function starting a database interaction, it should be replaced by driver.ErrBadConn
Expand Down
25 changes: 24 additions & 1 deletion packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,10 +234,18 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
if len(data) > pos {
// character set [1 byte]
// status flags [2 bytes]
pos += 1 + 2
// capability flags (upper 2 bytes) [2 bytes]
upperFlags := clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
mc.flags |= upperFlags << 16
pos += 2
if mc.flags&clientOptionalResultSetMetadata == 0 && mc.optionalResultSetMetadata {
return nil, "", ErrNoOptionalResultMetadataSet
}

// length of auth-plugin-data [1 byte]
// reserved (all [00]) [10 bytes]
pos += 1 + 2 + 2 + 1 + 10
pos += 1 + 10

// second part of the password cipher [mininum 13 bytes],
// where len=MAX(13, length of auth-plugin-data - 8)
Expand Down Expand Up @@ -300,6 +308,10 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
clientFlags |= clientMultiStatements
}

if mc.optionalResultSetMetadata {
clientFlags |= clientOptionalResultSetMetadata
}

// encode length of the auth plugin data
var authRespLEIBuf [9]byte
authRespLen := len(authResp)
Expand Down Expand Up @@ -554,6 +566,17 @@ func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
return int(num), nil
}

// Sniff one extra byte for resultset metadata if we set capability
// CLIENT_OPTIONAL_RESULTSET_METADTA
// https://dev.mysql.com/worklog/task/?id=8134
if len(data) == 2 && mc.flags&clientOptionalResultSetMetadata != 0 {
// ResultSet metadata flag check
if mc.resultSetMetadata != data[1] {
return 0, ErrOptionalResultSetMetadataPkt
}
return int(num), nil
}

return 0, ErrMalformPkt
}
return 0, err
Expand Down