Skip to content

Commit

Permalink
Restructuring cdp class
Browse files Browse the repository at this point in the history
  • Loading branch information
ShortDevelopment committed Jan 1, 2025
1 parent c29187b commit 3390897
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 127 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
using ShortDev.Microsoft.ConnectedDevices.Messages;
using ShortDev.Microsoft.ConnectedDevices.Transports;
using System.Buffers;

namespace ShortDev.Microsoft.ConnectedDevices;

partial class ConnectedDevicesPlatform
{
static readonly ArrayPool<byte> _messagePool = ArrayPool<byte>.Create();
private void ReceiveLoop(CdpSocket socket)
{
RegisterKnownSocket(socket);
Task.Run(() =>
{
EndianReader streamReader = new(Endianness.BigEndian, socket.InputStream);
using (socket)
{
ReceiveLoop(socket, ref streamReader);
}
});
}

void ReceiveLoop(CdpSocket socket, ref EndianReader streamReader)
{
do
{
CdpSession? session = null;
try
{
var header = CommonHeader.Parse(ref streamReader);

if (socket.IsClosed)
return;

session = CdpSession.GetOrCreate(
this,
socket.Endpoint,
header
);

using var payload = _messagePool.RentToken(header.PayloadSize);
streamReader.ReadBytes(payload.Span);

if (socket.IsClosed)
return;

EndianReader reader = new(Endianness.BigEndian, payload.Span);
session.HandleMessage(socket, header, ref reader);
}
catch (IOException)
{
break;
}
catch (Exception ex)
{
if (socket.IsClosed)
return;

if (session != null)
_logger.ExceptionInSession(ex, session.SessionId.AsNumber());
else
_logger.ExceptionInReceiveLoop(ex, socket.TransportType);

break;
}
} while (!socket.IsClosed);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
using ShortDev.Microsoft.ConnectedDevices.Transports;
using System.Collections.Concurrent;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;

namespace ShortDev.Microsoft.ConnectedDevices;

partial class ConnectedDevicesPlatform
{
readonly ConcurrentDictionary<EndpointInfo, CdpSocket> _knownSockets = new();

void RegisterKnownSocket(CdpSocket socket)
{
Debug.Assert(!socket.IsClosed);

socket.Disposed += OnSocketClosed;
void OnSocketClosed()
{
socket.Disposed -= OnSocketClosed;

var couldRemove = _knownSockets.TryRemove(KeyValuePair.Create(socket.Endpoint, socket));
Debug.Assert(couldRemove);
}

_knownSockets.AddOrUpdate(
socket.Endpoint,
static (key, newSocket) => newSocket,
static (key, newSocket, currentSocket) => newSocket,
socket
);
}

bool TryGetKnownSocket(EndpointInfo endpoint, [MaybeNullWhen(false)] out CdpSocket socket)
{
if (!_knownSockets.TryGetValue(endpoint, out socket))
return false;

// ToDo: Alive check!!
if (socket.IsClosed)
return false;

return true;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
using ShortDev.Microsoft.ConnectedDevices.Transports;
using System.Collections.Concurrent;

namespace ShortDev.Microsoft.ConnectedDevices;

partial class ConnectedDevicesPlatform
{
readonly ConcurrentDictionary<CdpTransportType, ICdpTransport> _transportMap = new();
public void AddTransport<T>(T transport) where T : ICdpTransport
{
_transportMap.AddOrUpdate(
transport.TransportType,
static (key, newTansport) => newTansport,
(key, newTansport, oldTransport) =>
{
oldTransport.Dispose();
return newTansport;
},
transport
);
}

[Obsolete("Use overload instead")]
public T? TryGetTransport<T>() where T : ICdpTransport
=> (T?)_transportMap.Values.SingleOrDefault(x => x is T);

public ICdpTransport? TryGetTransport(CdpTransportType transportType)
=> _transportMap.GetValueOrDefault(transportType);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
using Microsoft.Extensions.Logging;
using ShortDev.Microsoft.ConnectedDevices.Encryption;
using System.Diagnostics.CodeAnalysis;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;

namespace ShortDev.Microsoft.ConnectedDevices;

partial class ConnectedDevicesPlatform
{
public static X509Certificate2 CreateDeviceCertificate([NotNull] CdpEncryptionParams encryptionParams)
{
using var key = ECDsa.Create(encryptionParams.Curve);
CertificateRequest certRequest = new("CN=Ms-Cdp", key, HashAlgorithmName.SHA256);
return certRequest.CreateSelfSigned(DateTimeOffset.Now, DateTimeOffset.Now.AddYears(5));
}

public static ILoggerFactory CreateLoggerFactory(string filePattern, LogLevel logLevel = LogLevel.Debug)
=> LoggerFactory.Create(builder =>
{
builder.ClearProviders();

builder.SetMinimumLevel(logLevel);

builder.AddFile(filePattern, logLevel);
});
}
129 changes: 2 additions & 127 deletions lib/ShortDev.Microsoft.ConnectedDevices/ConnectedDevicesPlatform.cs
Original file line number Diff line number Diff line change
@@ -1,40 +1,15 @@
using Microsoft.Extensions.Logging;
using ShortDev.Microsoft.ConnectedDevices.Encryption;
using ShortDev.Microsoft.ConnectedDevices.Messages;
using ShortDev.Microsoft.ConnectedDevices.Transports;
using System.Buffers;
using System.Collections.Concurrent;
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;

namespace ShortDev.Microsoft.ConnectedDevices;

public sealed class ConnectedDevicesPlatform(LocalDeviceInfo deviceInfo, ILoggerFactory loggerFactory) : IDisposable
public sealed partial class ConnectedDevicesPlatform(LocalDeviceInfo deviceInfo, ILoggerFactory loggerFactory) : IDisposable
{
public LocalDeviceInfo DeviceInfo { get; } = deviceInfo;

readonly ILogger<ConnectedDevicesPlatform> _logger = loggerFactory.CreateLogger<ConnectedDevicesPlatform>();

#region Transport
readonly ConcurrentDictionary<Type, ICdpTransport> _transportMap = new();
public void AddTransport<T>(T transport) where T : ICdpTransport
{
_transportMap.AddOrUpdate(typeof(T), transport, (_, old) =>
{
old.Dispose();
return transport;
});
}

public T? TryGetTransport<T>() where T : ICdpTransport
=> (T?)_transportMap.GetValueOrDefault(typeof(T));

public ICdpTransport? TryGetTransport(CdpTransportType transportType)
=> _transportMap.Values.SingleOrDefault(transport => transport.TransportType == transportType);
#endregion

#region Host
#region Advertise
public GuardFlag IsAdvertising { get; } = new();
Expand Down Expand Up @@ -151,7 +126,7 @@ internal async Task<CdpSocket> CreateSocketAsync(EndpointInfo endpoint, Cancella
if (TryGetKnownSocket(endpoint, out var knownSocket))
return knownSocket;

var transport = TryGetTransport(endpoint.TransportType) ?? throw new InvalidOperationException($"No single transport found for type {endpoint.TransportType}");
var transport = TryGetTransport(endpoint.TransportType) ?? throw new InvalidOperationException($"No transport found for type {endpoint.TransportType}");
var socket = await transport.ConnectAsync(endpoint, cancellationToken).ConfigureAwait(false);
ReceiveLoop(socket);
return socket;
Expand All @@ -175,89 +150,6 @@ internal async Task<CdpSocket> CreateSocketAsync(EndpointInfo endpoint, Cancella
}
#endregion

static readonly ArrayPool<byte> _messagePool = ArrayPool<byte>.Create();
private void ReceiveLoop(CdpSocket socket)
{
RegisterKnownSocket(socket);
Task.Run(() =>
{
EndianReader streamReader = new(Endianness.BigEndian, socket.InputStream);
using (socket)
{
do
{
CdpSession? session = null;
try
{
var header = CommonHeader.Parse(ref streamReader);

if (socket.IsClosed)
return;

session = CdpSession.GetOrCreate(
this,
socket.Endpoint,
header
);

using var payload = _messagePool.RentToken(header.PayloadSize);
streamReader.ReadBytes(payload.Span);

if (socket.IsClosed)
return;

EndianReader reader = new(Endianness.BigEndian, payload.Span);
session.HandleMessage(socket, header, ref reader);
}
catch (Exception ex)
{
if (socket.IsClosed)
return;

if (session != null)
_logger.ExceptionInSession(ex, session.SessionId.AsNumber());
else
_logger.ExceptionInReceiveLoop(ex, socket.TransportType);

break;
}
} while (!socket.IsClosed);
}
});
}

#region Socket Management
readonly ConcurrentDictionary<EndpointInfo, CdpSocket> _knownSockets = new();

void RegisterKnownSocket(CdpSocket socket)
{
socket.Disposed += OnSocketClosed;
void OnSocketClosed()
{
_knownSockets.TryRemove(socket.Endpoint, out _); // ToDo: We might remove a newer socket here!!
socket.Disposed -= OnSocketClosed;
}

_knownSockets.AddOrUpdate(socket.Endpoint, socket, (key, current) =>
{
// ToDo: Alive check
return socket;
});
}

bool TryGetKnownSocket(EndpointInfo endpoint, [MaybeNullWhen(false)] out CdpSocket socket)
{
if (!_knownSockets.TryGetValue(endpoint, out socket))
return false;

// ToDo: Alive check!!
if (socket.IsClosed)
return false;

return true;
}
#endregion

public CdpDeviceInfo GetCdpDeviceInfo()
{
List<EndpointInfo> endpoints = [];
Expand Down Expand Up @@ -286,21 +178,4 @@ public void Dispose()
_transportMap.Clear();
_knownSockets.Clear();
}

public static X509Certificate2 CreateDeviceCertificate([NotNull] CdpEncryptionParams encryptionParams)
{
using var key = ECDsa.Create(encryptionParams.Curve);
CertificateRequest certRequest = new("CN=Ms-Cdp", key, HashAlgorithmName.SHA256);
return certRequest.CreateSelfSigned(DateTimeOffset.Now, DateTimeOffset.Now.AddYears(5));
}

public static ILoggerFactory CreateLoggerFactory(string filePattern, LogLevel logLevel = LogLevel.Debug)
=> LoggerFactory.Create(builder =>
{
builder.ClearProviders();

builder.SetMinimumLevel(logLevel);

builder.AddFile(filePattern, logLevel);
});
}

0 comments on commit 3390897

Please sign in to comment.