Skip to content

Commit ee9af81

Browse files
committed
Only reset "is timed out" flag in public API methods.
This fixes a bug where a timed-out query would throw QueryInterrupted instead of the expected CommandTimeoutExpired because MySqlCommand.ExecuteReader would return a valid MySqlDataReader object, and the first call to MySqlDataReader.Read would clear the IsTimedOut flag (before a subsequent call to Read processed the pending error packet).
1 parent d8ced8d commit ee9af81

File tree

3 files changed

+77
-12
lines changed

3 files changed

+77
-12
lines changed

src/MySqlConnector/MySqlCommand.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,7 @@ public override Task<int> ExecuteNonQueryAsync(CancellationToken cancellationTok
260260

261261
internal async Task<int> ExecuteNonQueryAsync(IOBehavior ioBehavior, CancellationToken cancellationToken)
262262
{
263+
Volatile.Write(ref m_commandTimedOut, false);
263264
this.ResetCommandTimeout();
264265
using var registration = ((ICancellableCommand) this).RegisterCancel(cancellationToken);
265266
using var reader = await ExecuteReaderNoResetTimeoutAsync(CommandBehavior.Default, ioBehavior, cancellationToken).ConfigureAwait(false);
@@ -277,6 +278,7 @@ internal async Task<int> ExecuteNonQueryAsync(IOBehavior ioBehavior, Cancellatio
277278

278279
internal async Task<object?> ExecuteScalarAsync(IOBehavior ioBehavior, CancellationToken cancellationToken)
279280
{
281+
Volatile.Write(ref m_commandTimedOut, false);
280282
this.ResetCommandTimeout();
281283
using var registration = ((ICancellableCommand) this).RegisterCancel(cancellationToken);
282284
var hasSetResult = false;
@@ -306,6 +308,7 @@ protected override async Task<DbDataReader> ExecuteDbDataReaderAsync(CommandBeha
306308

307309
internal async Task<MySqlDataReader> ExecuteReaderAsync(CommandBehavior behavior, IOBehavior ioBehavior, CancellationToken cancellationToken)
308310
{
311+
Volatile.Write(ref m_commandTimedOut, false);
309312
this.ResetCommandTimeout();
310313
using var registration = ((ICancellableCommand) this).RegisterCancel(cancellationToken);
311314
return await ExecuteReaderNoResetTimeoutAsync(behavior, ioBehavior, cancellationToken).ConfigureAwait(false);
@@ -364,8 +367,6 @@ public override ValueTask DisposeAsync()
364367

365368
void ICancellableCommand.SetTimeout(int milliseconds)
366369
{
367-
Volatile.Write(ref m_commandTimedOut, false);
368-
369370
if (m_cancelTimerId != 0)
370371
TimerQueue.Instance.Remove(m_cancelTimerId);
371372

tests/MySqlConnector.Tests/CancellationTests.cs

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ public void Execute(int step, int method)
5353

5454
[SkipCITheory]
5555
[MemberData(nameof(GetAsyncMethodSteps))]
56-
public async Task ExecuteAsyncs(int step, int method)
56+
public async Task ExecuteAsync(int step, int method)
5757
{
5858
using var connection = new MySqlConnection(m_csb.ConnectionString);
5959
connection.Open();
@@ -74,6 +74,53 @@ public async Task ExecuteAsyncs(int step, int method)
7474
}
7575
}
7676

77+
public class CancelBufferedWithCommandTimeout : CancellationTests
78+
{
79+
[SkipCITheory]
80+
[MemberData(nameof(GetSyncMethodSteps))]
81+
public void Execute(int step, int method)
82+
{
83+
using var connection = new MySqlConnection(m_csb.ConnectionString);
84+
connection.Open();
85+
using var command = connection.CreateCommand();
86+
command.CommandTimeout = 1;
87+
command.CommandText = $"SELECT 0, 4000, {step}, 2;";
88+
var stopwatch = Stopwatch.StartNew();
89+
var ex = Assert.Throws<MySqlException>(() => s_executeMethods[method](command));
90+
Assert.InRange(stopwatch.ElapsedMilliseconds, 900, 1500);
91+
Assert.Equal(MySqlErrorCode.CommandTimeoutExpired, ex.ErrorCode);
92+
var inner = Assert.IsType<MySqlException>(ex.InnerException);
93+
Assert.Equal(MySqlErrorCode.QueryInterrupted, inner.ErrorCode);
94+
95+
// connection should still be usable
96+
Assert.Equal(ConnectionState.Open, connection.State);
97+
command.CommandText = "SELECT 1;";
98+
Assert.Equal(1, command.ExecuteScalar());
99+
}
100+
101+
[SkipCITheory]
102+
[MemberData(nameof(GetAsyncMethodSteps))]
103+
public async Task ExecuteAsync(int step, int method)
104+
{
105+
using var connection = new MySqlConnection(m_csb.ConnectionString);
106+
connection.Open();
107+
using var command = connection.CreateCommand();
108+
command.CommandTimeout = 1;
109+
command.CommandText = $"SELECT 0, 4000, {step}, 2;";
110+
var stopwatch = Stopwatch.StartNew();
111+
var ex = await Assert.ThrowsAsync<MySqlException>(async () => await s_executeAsyncMethods[method](command, default));
112+
Assert.InRange(stopwatch.ElapsedMilliseconds, 900, 1500);
113+
Assert.Equal(MySqlErrorCode.CommandTimeoutExpired, ex.ErrorCode);
114+
var inner = Assert.IsType<MySqlException>(ex.InnerException);
115+
Assert.Equal(MySqlErrorCode.QueryInterrupted, inner.ErrorCode);
116+
117+
// connection should still be usable
118+
Assert.Equal(ConnectionState.Open, connection.State);
119+
command.CommandText = "SELECT 1;";
120+
Assert.Equal(1, command.ExecuteScalar());
121+
}
122+
}
123+
77124
public class CancelWithCancel : CancellationTests
78125
{
79126
[SkipCITheory]

tests/MySqlConnector.Tests/FakeMySqlServerConnection.cs

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ public async Task RunAsync(TcpClient client, CancellationToken token)
9797
var pauseStep = int.Parse(match.Groups[3].Value);
9898
var flags = int.Parse(match.Groups[4].Value);
9999
var ignoreCancellation = (flags & 1) == 1;
100+
var bufferOutput = (flags & 2) == 2;
100101

101102
var data = new byte[number.Length + 1];
102103
data[0] = (byte) number.Length;
@@ -120,18 +121,34 @@ public async Task RunAsync(TcpClient client, CancellationToken token)
120121
new byte[] { 0xFE, 0, 0, 2, 0 }, // EOF
121122
};
122123

123-
var queryInterrupted = false;
124-
for (var step = 1; step < packets.Length && !queryInterrupted; step++)
124+
if (bufferOutput)
125125
{
126-
if (pauseStep == step || pauseStep == -1)
126+
// if 'bufferOutput' is set, perform the delay immediately then send all the output afterwards, as though it were buffered on the server
127+
var queryInterrupted = false;
128+
if (ignoreCancellation)
129+
await Task.Delay(delay, token);
130+
else
131+
queryInterrupted = CancelQueryEvent.Wait(delay, token);
132+
133+
for (var step = 1; step < pauseStep; step++)
134+
await SendAsync(stream, step, x => x.Write(packets[step]));
135+
await SendAsync(stream, pauseStep, x => x.Write(packets[queryInterrupted ? 0 : pauseStep]));
136+
}
137+
else
138+
{
139+
var queryInterrupted = false;
140+
for (var step = 1; step < packets.Length && !queryInterrupted; step++)
127141
{
128-
if (ignoreCancellation)
129-
await Task.Delay(delay, token);
130-
else
131-
queryInterrupted = CancelQueryEvent.Wait(delay, token);
132-
}
142+
if (pauseStep == step || pauseStep == -1)
143+
{
144+
if (ignoreCancellation)
145+
await Task.Delay(delay, token);
146+
else
147+
queryInterrupted = CancelQueryEvent.Wait(delay, token);
148+
}
133149

134-
await SendAsync(stream, step, x => x.Write(packets[queryInterrupted ? 0 : step]));
150+
await SendAsync(stream, step, x => x.Write(packets[queryInterrupted ? 0 : step]));
151+
}
135152
}
136153
}
137154
else if ((match = Regex.Match(query, @"^KILL QUERY ([0-9]+)(;|$)", RegexOptions.IgnoreCase)).Success)

0 commit comments

Comments
 (0)