Skip to content

Commit adce95a

Browse files
committed
Added read-buffers to udp packet reader
1 parent 869c0d8 commit adce95a

File tree

2 files changed

+98
-35
lines changed

2 files changed

+98
-35
lines changed

network-api/network-api.go

Lines changed: 51 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ func Register(router *msgpackrouter.Router) {
4848

4949
_ = router.RegisterMethod("udp/connect", udpConnect)
5050
_ = router.RegisterMethod("udp/write", udpWrite)
51+
_ = router.RegisterMethod("udp/awaitRead", udpAwaitRead)
5152
_ = router.RegisterMethod("udp/read", udpRead)
5253
_ = router.RegisterMethod("udp/close", udpClose)
5354
}
@@ -56,6 +57,7 @@ var lock sync.RWMutex
5657
var liveConnections = make(map[uint]net.Conn)
5758
var liveListeners = make(map[uint]net.Listener)
5859
var liveUdpConnections = make(map[uint]net.PacketConn)
60+
var udpReadBuffers = make(map[uint][]byte)
5961
var nextConnectionID atomic.Uint32
6062

6163
// takeLockAndGenerateNextID generates a new unique ID for a connection or listener.
@@ -420,39 +422,35 @@ func udpWrite(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_r
420422
}
421423
}
422424

423-
func udpRead(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
424-
if len(params) != 2 && len(params) != 3 {
425-
return nil, []any{1, "Invalid number of parameters, expected (UDP connection ID, max bytes to read[, optional timeout in ms])"}
425+
func udpAwaitRead(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
426+
if len(params) != 1 && len(params) != 2 {
427+
return nil, []any{1, "Invalid number of parameters, expected (UDP connection ID[, optional timeout in ms])"}
426428
}
427429
id, ok := msgpackrpc.ToUint(params[0])
428430
if !ok {
429431
return nil, []any{1, "Invalid parameter type, expected uint for UDP connection ID"}
430432
}
431-
lock.RLock()
432-
udpConn, ok := liveUdpConnections[id]
433-
lock.RUnlock()
434-
if !ok {
435-
return nil, []any{2, fmt.Sprintf("UDP connection not found for ID: %d", id)}
436-
}
437-
maxBytes, ok := msgpackrpc.ToUint(params[1])
438-
if !ok {
439-
return nil, []any{1, "Invalid parameter type, expected uint for max bytes to read"}
440-
}
441433
var deadline time.Time // default value == no timeout
442-
if len(params) == 2 {
434+
if len(params) == 1 {
443435
// No timeout
444-
} else if ms, ok := msgpackrpc.ToInt(params[2]); !ok {
436+
} else if ms, ok := msgpackrpc.ToInt(params[1]); !ok {
445437
return nil, []any{1, "Invalid parameter type, expected int for timeout in ms"}
446438
} else if ms > 0 {
447439
deadline = time.Now().Add(time.Duration(ms) * time.Millisecond)
448440
} else if ms == 0 {
449441
// No timeout
450442
}
451443

444+
lock.RLock()
445+
udpConn, ok := liveUdpConnections[id]
446+
lock.RUnlock()
447+
if !ok {
448+
return nil, []any{2, fmt.Sprintf("UDP connection not found for ID: %d", id)}
449+
}
452450
if err := udpConn.SetReadDeadline(deadline); err != nil {
453451
return nil, []any{3, "Failed to set read deadline: " + err.Error()}
454452
}
455-
buffer := make([]byte, maxBytes)
453+
buffer := make([]byte, 64*1024) // 64 KB buffer
456454
n, addr, err := udpConn.ReadFrom(buffer)
457455
if errors.Is(err, os.ErrDeadlineExceeded) {
458456
// timeout
@@ -471,7 +469,41 @@ func udpRead(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_re
471469
// Should never fail, but...
472470
return nil, []any{4, "Failed to parse source address: " + err.Error()}
473471
}
474-
return []any{buffer[:n], host, port}, nil
472+
473+
lock.Lock()
474+
udpReadBuffers[id] = buffer[:n]
475+
lock.Unlock()
476+
return []any{n, host, port}, nil
477+
}
478+
479+
func udpRead(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
480+
if len(params) != 2 && len(params) != 3 {
481+
return nil, []any{1, "Invalid number of parameters, expected (UDP connection ID, max bytes to read)"}
482+
}
483+
id, ok := msgpackrpc.ToUint(params[0])
484+
if !ok {
485+
return nil, []any{1, "Invalid parameter type, expected uint for UDP connection ID"}
486+
}
487+
maxBytes, ok := msgpackrpc.ToUint(params[1])
488+
if !ok {
489+
return nil, []any{1, "Invalid parameter type, expected uint for max bytes to read"}
490+
}
491+
492+
lock.Lock()
493+
buffer, exists := udpReadBuffers[id]
494+
n := uint(len(buffer))
495+
if exists {
496+
// keep the remainder of the buffer for the next read
497+
if n > maxBytes {
498+
udpReadBuffers[id] = buffer[maxBytes:]
499+
n = maxBytes
500+
} else {
501+
udpReadBuffers[id] = nil
502+
}
503+
}
504+
lock.Unlock()
505+
506+
return buffer[:n], nil
475507
}
476508

477509
func udpClose(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
@@ -485,9 +517,8 @@ func udpClose(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_r
485517

486518
lock.Lock()
487519
udpConn, existsConn := liveUdpConnections[id]
488-
if existsConn {
489-
delete(liveUdpConnections, id)
490-
}
520+
delete(liveUdpConnections, id)
521+
delete(udpReadBuffers, id)
491522
lock.Unlock()
492523

493524
if !existsConn {

network-api/network-api_test.go

Lines changed: 47 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -253,9 +253,13 @@ func TestUDPNetworkAPI(t *testing.T) {
253253
require.Equal(t, 5, res)
254254
}
255255
{
256-
res, err := udpRead(ctx, nil, []any{conn2, 100})
256+
res, err := udpAwaitRead(ctx, nil, []any{conn2})
257257
require.Nil(t, err)
258-
require.Equal(t, []any{[]uint8("Hello"), "127.0.0.1", 9800}, res)
258+
require.Equal(t, []any{5, "127.0.0.1", 9800}, res)
259+
260+
res2, err := udpRead(ctx, nil, []any{conn2, 100})
261+
require.Nil(t, err)
262+
require.Equal(t, []uint8("Hello"), res2)
259263
}
260264
{
261265
res, err := udpWrite(ctx, nil, []any{conn1, "127.0.0.1", 9900, []byte("One")})
@@ -268,14 +272,22 @@ func TestUDPNetworkAPI(t *testing.T) {
268272
require.Equal(t, 3, res)
269273
}
270274
{
271-
res, err := udpRead(ctx, nil, []any{conn2, 100})
275+
res, err := udpAwaitRead(ctx, nil, []any{conn2})
276+
require.Nil(t, err)
277+
require.Equal(t, []any{3, "127.0.0.1", 9800}, res)
278+
279+
res2, err := udpRead(ctx, nil, []any{conn2, 100})
272280
require.Nil(t, err)
273-
require.Equal(t, []any{[]uint8("One"), "127.0.0.1", 9800}, res)
281+
require.Equal(t, []uint8("One"), res2)
274282
}
275283
{
276-
res, err := udpRead(ctx, nil, []any{conn2, 100})
284+
res, err := udpAwaitRead(ctx, nil, []any{conn2})
277285
require.Nil(t, err)
278-
require.Equal(t, []any{[]uint8("Two"), "127.0.0.1", 9800}, res)
286+
require.Equal(t, []any{3, "127.0.0.1", 9800}, res)
287+
288+
res2, err := udpRead(ctx, nil, []any{conn2, 100})
289+
require.Nil(t, err)
290+
require.Equal(t, []uint8("Two"), res2)
279291
}
280292
{
281293
res, err := udpClose(ctx, nil, []any{conn1})
@@ -304,9 +316,17 @@ func TestUDPNetworkUnboundClientAPI(t *testing.T) {
304316
require.Equal(t, 5, res)
305317
}
306318
{
307-
res, err := udpRead(ctx, nil, []any{conn2, 100})
319+
res, err := udpAwaitRead(ctx, nil, []any{conn2})
320+
require.Nil(t, err)
321+
require.Equal(t, 5, res.([]any)[0])
322+
323+
res2, err := udpRead(ctx, nil, []any{conn2, 2})
308324
require.Nil(t, err)
309-
require.Equal(t, []uint8("Hello"), res.([]any)[0])
325+
require.Equal(t, []uint8("He"), res2)
326+
327+
res2, err = udpRead(ctx, nil, []any{conn2, 20})
328+
require.Nil(t, err)
329+
require.Equal(t, []uint8("llo"), res2)
310330
}
311331
{
312332
res, err := udpWrite(ctx, nil, []any{conn1, "127.0.0.1", 9900, []byte("One")})
@@ -319,14 +339,22 @@ func TestUDPNetworkUnboundClientAPI(t *testing.T) {
319339
require.Equal(t, 3, res)
320340
}
321341
{
322-
res, err := udpRead(ctx, nil, []any{conn2, 100})
342+
res, err := udpAwaitRead(ctx, nil, []any{conn2})
343+
require.Nil(t, err)
344+
require.Equal(t, 3, res.([]any)[0])
345+
346+
res2, err := udpRead(ctx, nil, []any{conn2, 100})
323347
require.Nil(t, err)
324-
require.Equal(t, []uint8("One"), res.([]any)[0])
348+
require.Equal(t, []uint8("One"), res2)
325349
}
326350
{
327-
res, err := udpRead(ctx, nil, []any{conn2, 100})
351+
res, err := udpAwaitRead(ctx, nil, []any{conn2})
328352
require.Nil(t, err)
329-
require.Equal(t, []uint8("Two"), res.([]any)[0])
353+
require.Equal(t, 3, res.([]any)[0])
354+
355+
res2, err := udpRead(ctx, nil, []any{conn2, 100})
356+
require.Nil(t, err)
357+
require.Equal(t, []uint8("Two"), res2)
330358
}
331359

332360
// Check timeouts
@@ -338,15 +366,19 @@ func TestUDPNetworkUnboundClientAPI(t *testing.T) {
338366
}()
339367
{
340368
start := time.Now()
341-
res, err := udpRead(ctx, nil, []any{conn2, 100, 10})
369+
res, err := udpAwaitRead(ctx, nil, []any{conn2, 10})
342370
require.Less(t, time.Since(start), 20*time.Millisecond)
343371
require.Equal(t, []any{5, "Timeout"}, err)
344372
require.Nil(t, res)
345373
}
346374
{
347-
res, err := udpRead(ctx, nil, []any{conn2, 100, 0})
375+
res, err := udpAwaitRead(ctx, nil, []any{conn2, 0})
376+
require.Nil(t, err)
377+
require.Equal(t, 5, res.([]any)[0])
378+
379+
res2, err := udpRead(ctx, nil, []any{conn2, 100, 0})
348380
require.Nil(t, err)
349-
require.Equal(t, []uint8("Three"), res.([]any)[0])
381+
require.Equal(t, []uint8("Three"), res2)
350382
}
351383

352384
{

0 commit comments

Comments
 (0)