Skip to content

Commit 9383c57

Browse files
committed
use helper method to terminate the receive loop on forceful websocket shutdown
1 parent bf2088b commit 9383c57

File tree

4 files changed

+31
-34
lines changed

4 files changed

+31
-34
lines changed

Common/Websocket/WebsockBaseController.cs

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ public abstract class WebsocketBaseController<T> : OpenShockControllerBase, IAsy
3434
private CancellationTokenSource? _linkedSource;
3535

3636
protected CancellationToken LinkedToken;
37-
37+
3838
/// <summary>
3939
/// Channel for multithreading thread safety of the websocket, MessageLoop is the only reader for this channel
4040
/// </summary>
@@ -53,7 +53,6 @@ protected WebsocketBaseController(ILogger<WebsocketBaseController<T>> logger)
5353
Logger = logger;
5454
}
5555

56-
5756
/// <inheritdoc />
5857
[NonAction]
5958
public ValueTask QueueMessage(T data) => Channel.Writer.WriteAsync(data, LinkedToken);
@@ -200,14 +199,18 @@ protected virtual Task SendWebSocketMessage(T message, WebSocket websocket, Canc
200199

201200
#endregion
202201

202+
private CancellationTokenSource _receiveCancellationTokenSource = new();
203+
203204
/// <summary>
204205
/// Main receiver logic for the websocket
205206
/// </summary>
206207
/// <returns></returns>
207208
[NonAction]
208209
private async Task Logic()
209210
{
210-
while (!LinkedToken.IsCancellationRequested)
211+
using var linkedReceiverToken = CancellationTokenSource.CreateLinkedTokenSource(LinkedToken, _receiveCancellationTokenSource.Token);
212+
213+
while (!linkedReceiverToken.IsCancellationRequested)
211214
{
212215
try
213216
{
@@ -231,7 +234,7 @@ private async Task Logic()
231234
return;
232235
}
233236

234-
if (!await HandleReceive())
237+
if (!await HandleReceive(linkedReceiverToken.Token))
235238
{
236239
// HandleReceive returned false, we will close the connection after this
237240
Logger.LogDebug("HandleReceive returned false, closing connection");
@@ -263,7 +266,18 @@ private async Task Logic()
263266
/// </summary>
264267
/// <returns>True if you want to continue the receiver loop, false if you want to terminate</returns>
265268
[NonAction]
266-
protected abstract Task<bool> HandleReceive();
269+
protected abstract Task<bool> HandleReceive(CancellationToken cancellationToken);
270+
271+
[NonAction]
272+
protected async Task ForceClose(WebSocketCloseStatus closeStatus, string? statusDescription)
273+
{
274+
await _receiveCancellationTokenSource.CancelAsync();
275+
276+
if (WebSocket is { State: WebSocketState.CloseReceived or WebSocketState.Open })
277+
{
278+
await WebSocket.CloseOutputAsync(closeStatus, statusDescription, LinkedToken);
279+
}
280+
}
267281

268282
/// <summary>
269283
/// Send initial data to the client

LiveControlGateway/Controllers/HubControllerBase.cs

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -102,13 +102,7 @@ ILogger<FlatbuffersWebsocketBaseController<TIn, TOut>> logger
102102
try
103103
{
104104
Logger.LogInformation("Keep alive timeout reached, closing websocket connection");
105-
106-
if (WebSocket is { State: WebSocketState.Open })
107-
{
108-
await WebSocket!.CloseOutputAsync(WebSocketCloseStatus.ProtocolError, "Keep alive timeout reached",
109-
LinkedToken);
110-
}
111-
105+
await ForceClose(WebSocketCloseStatus.ProtocolError, "Keep alive timeout reached");
112106
WebSocket?.Abort();
113107
}
114108
catch (Exception ex)
@@ -187,14 +181,7 @@ public async Task DisconnectOld()
187181
if (WebSocket == null)
188182
return;
189183

190-
if (WebSocket is { State: WebSocketState.Open })
191-
{
192-
await WebSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure,
193-
"Hub is connecting from a different location",
194-
LinkedToken);
195-
}
196-
197-
WebSocket?.Abort();
184+
await ForceClose(WebSocketCloseStatus.NormalClosure, "Hub is connecting from a different location");
198185
}
199186

200187
private static DateTimeOffset? GetBootedAtFromUptimeMs(ulong uptimeMs)

LiveControlGateway/Controllers/LiveControlController.cs

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ await QueueMessage(new LiveControlResponse<LiveResponseType>
255255
}
256256

257257
/// <inheritdoc />
258-
protected override async Task<bool> HandleReceive()
258+
protected override async Task<bool> HandleReceive(CancellationToken cancellationToken)
259259
{
260260
var message =
261261
await JsonWebSocketUtils.ReceiveFullMessageAsyncNonAlloc<BaseRequest<LiveRequestType>>(WebSocket!,
@@ -266,21 +266,18 @@ await JsonWebSocketUtils.ReceiveFullMessageAsyncNonAlloc<BaseRequest<LiveRequest
266266
if (request?.Data == null)
267267
{
268268
Logger.LogWarning("Received null data from client");
269-
await WebSocket!.CloseOutputAsync(WebSocketCloseStatus.InvalidPayloadData,
270-
"Invalid json message received", LinkedToken);
269+
await ForceClose(WebSocketCloseStatus.InvalidPayloadData, "Invalid json message received");
271270
return false;
272271
}
273272

274-
#pragma warning disable CS4014
275-
OsTask.Run(() => ProcessResult(request));
276-
#pragma warning restore CS4014
273+
await ProcessResult(request);
274+
277275
return true;
278276
},
279277
async failed =>
280278
{
281279
Logger.LogWarning(failed.Exception, "Deserialization failed for websocket message");
282-
await WebSocket!.CloseOutputAsync(WebSocketCloseStatus.InvalidPayloadData,
283-
"Invalid json message received", LinkedToken);
280+
await ForceClose(WebSocketCloseStatus.InvalidPayloadData, "Invalid json message received");
284281
return false;
285282
}, closure =>
286283
{
@@ -551,8 +548,8 @@ await SendWebSocketMessage(new LiveControlResponse<LiveResponseType>
551548
ResponseType = LiveResponseType.DeviceNotConnected,
552549
}, WebSocket!, LinkedToken);
553550

554-
await WebSocket!.CloseOutputAsync(WebSocketCloseStatus.NormalClosure,
555-
"Hub is disconnected", LinkedToken);
551+
552+
await ForceClose(WebSocketCloseStatus.NormalClosure, "Hub is disconnected");
556553
}
557554
catch (Exception e)
558555
{

LiveControlGateway/Websocket/FlatbuffersWebsocketBaseController.cs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,18 +48,17 @@ protected override Task
4848
protected abstract Task<bool> Handle(TIn data);
4949

5050
/// <inheritdoc />
51-
protected override async Task<bool> HandleReceive()
51+
protected override async Task<bool> HandleReceive(CancellationToken cancellationToken)
5252
{
5353
var message =
5454
await FlatbufferWebSocketUtils.ReceiveFullMessageAsyncNonAlloc(WebSocket!,
55-
_incomingSerializer, LinkedToken);
55+
_incomingSerializer, cancellationToken);
5656

5757
var continueLoop = await message.Match(
5858
Handle,
5959
async _ =>
6060
{
61-
await WebSocket!.CloseAsync(WebSocketCloseStatus.InvalidPayloadData, "Invalid flatbuffers message",
62-
LinkedToken);
61+
await ForceClose(WebSocketCloseStatus.InvalidPayloadData, "Invalid flatbuffers message");
6362
return false;
6463
},
6564
_ =>

0 commit comments

Comments
 (0)