Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support ok packet represents eof #1153

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,15 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
columnCount, err := stmt.readPrepareResultPacket()
if err == nil {
if stmt.paramCount > 0 {
if err = mc.readUntilEOF(); err != nil {
// FIXME - seems like a bug in MySQL (or it's intended).
// There's no EOF return after parameters.
// However, this behavior isn't consistent to Maria DB.
if mc.flags&clientDeprecateEOF == 0 {
if err = mc.readUntilEOF(); err != nil {
return nil, err
}
}
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems MySQL behavior isn't consistent to the documentation here: https://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html

However, Maria DB follows this specification.

IMO this feature should be held until there's a clear solution to be compatible with supported vendors.

Copy link
Contributor

@shiyuhang0 shiyuhang0 Nov 26, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can I continue your excellent work? Does this problem be solved?

if err = mc.readExactPackets(stmt.paramCount); err != nil {
return nil, err
}
}
Expand Down
98 changes: 74 additions & 24 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,10 +235,15 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
if len(data) > pos {
// character set [1 byte]
// status flags [2 bytes]
pos += 1 + 2

// capability flags (upper 2 bytes) [2 bytes]
mc.flags |= clientFlag(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16
pos += 2

// length of auth-plugin-data [1 byte]
// reserved (all [00]) [10 bytes]
pos += 1 + 2 + 2 + 1 + 10
pos += 1 + 10

// second part of the password cipher [mininum 13 bytes],
// where len=MAX(13, length of auth-plugin-data - 8)
Expand Down Expand Up @@ -286,6 +291,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
clientLocalFiles |
clientPluginAuth |
clientMultiResults |
mc.flags&clientDeprecateEOF |
mc.flags&clientLongFlag

if mc.cfg.ClientFoundRows {
Expand Down Expand Up @@ -608,20 +614,21 @@ func readStatus(b []byte) statusFlag {
}

// Ok Packet
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
// https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html
func (mc *mysqlConn) handleOkPacket(data []byte) error {
var n, m int

// 0x00 [1 byte]

// 0x00 or 0xFE [1 byte]
n := 1
var l int
// Affected rows [Length Coded Binary]
mc.affectedRows, _, n = readLengthEncodedInteger(data[1:])
mc.affectedRows, _, l = readLengthEncodedInteger(data[n:])
n += l

// Insert id [Length Coded Binary]
mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])
mc.insertId, _, l = readLengthEncodedInteger(data[n:])
n += l

// server_status [2 bytes]
mc.status = readStatus(data[1+n+m : 1+n+m+2])
mc.status = readStatus(data[n : n+2])
if mc.status&statusMoreResultsExists != 0 {
return nil
}
Expand All @@ -631,19 +638,36 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error {
return nil
}

// isEOFPacket will return true if the data is either a EOF-Packet or OK-Packet
// acting as an EOF.
func (mc *mysqlConn) isEOFPacket(data []byte) bool {
// Legacy EOF packet
if data[0] == iEOF && (len(data) == 5 || len(data) == 1) && mc.flags&clientDeprecateEOF == 0 {
return true
}
return data[0] == iEOF && len(data) < 9 && mc.flags&clientDeprecateEOF != 0
}

// Read Packets as Field Packets until EOF-Packet or an Error appears
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41
func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
columns := make([]mysqlField, count)

for i := 0; ; i++ {
// If we set clientDeprecateEOF capability flag,
// the EOF will be no longer sent after all columns.
packets := count
if mc.flags&clientDeprecateEOF == 0 {
// Legacy way, read one more EOF packet.
packets += 1
}

for i := 0; i < packets; i++ {
data, err := mc.readPacket()
if err != nil {
return nil, err
}

// EOF Packet
if data[0] == iEOF && (len(data) == 5 || len(data) == 1) {
if mc.isEOFPacket(data) {
if i == count {
return columns, nil
}
Expand Down Expand Up @@ -729,9 +753,10 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
// defaultVal, _, err = bytesToLengthCodedBinary(data[pos:])
//}
}
return columns, nil
}

// Read Packets as Field Packets until EOF-Packet or an Error appears
// Read Packets as Field Packets until EOF/OK-Packet or an Error appears
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow
func (rows *textRows) readRow(dest []driver.Value) error {
mc := rows.mc
Expand All @@ -746,9 +771,16 @@ func (rows *textRows) readRow(dest []driver.Value) error {
}

// EOF Packet
if data[0] == iEOF && len(data) == 5 {
// server_status [2 bytes]
rows.mc.status = readStatus(data[3:])
if mc.isEOFPacket(data) {
if mc.flags&clientDeprecateEOF == 0 {
// server_status [2 bytes]
rows.mc.status = readStatus(data[3:])
} else {
if err := mc.handleOkPacket(data); err != nil {
rows.mc = nil
return err
}
}
rows.rs.done = true
if !rows.HasNextResultSet() {
rows.mc = nil
Expand Down Expand Up @@ -808,16 +840,27 @@ func (mc *mysqlConn) readUntilEOF() error {
return err
}

switch data[0] {
case iERR:
switch {
case data[0] == iERR:
return mc.handleErrorPacket(data)
case iEOF:
if len(data) == 5 {
case mc.isEOFPacket(data):
if mc.flags&clientDeprecateEOF == 0 {
mc.status = readStatus(data[3:])
return nil
}
return nil
return mc.handleOkPacket(data)
}
}
}

func (mc *mysqlConn) readExactPackets(num int) error {
for i := 0; i < num; i++ {
_, err := mc.readPacket()
if err != nil {
return err
}
}
return nil
}

/******************************************************************************
Expand Down Expand Up @@ -1178,15 +1221,22 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {

// packet indicator [1 byte]
if data[0] != iOK {
// EOF Packet
if data[0] == iEOF && len(data) == 5 {
rows.mc.status = readStatus(data[3:])
if rows.mc.isEOFPacket(data) {
if rows.mc.flags&clientDeprecateEOF == 0 {
rows.mc.status = readStatus(data[3:])
} else {
if err := rows.mc.handleOkPacket(data); err != nil {
rows.mc = nil
return err
}
}
rows.rs.done = true
if !rows.HasNextResultSet() {
rows.mc = nil
}
return io.EOF
}

mc := rows.mc
rows.mc = nil

Expand Down
1 change: 0 additions & 1 deletion rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,6 @@ func (rows *textRows) Next(dest []driver.Value) error {
if err := mc.error(); err != nil {
return err
}

// Fetch next row from stream
return rows.readRow(dest)
}
Expand Down
3 changes: 1 addition & 2 deletions statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,9 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {

if resLen > 0 {
// Columns
if err = mc.readUntilEOF(); err != nil {
if err = mc.readExactPackets(resLen); err != nil {
return nil, err
}

// Rows
if err := mc.readUntilEOF(); err != nil {
return nil, err
Expand Down