Skip to content

Commit

Permalink
Fix a few issues with ShellStream (#1322)
Browse files Browse the repository at this point in the history
* Fix a few issues with ShellStream

The main change is to replace the Queue<byte> with a byte[] and a couple of
variables which index the start and end of the data. The remainder is mainly
slightly more careful locking semantics.

It also implements Expect(string) separately so that it can work on the bytes
and skip a lot of encoding work (this is where I wish ShellStream derived from
StreamReader).

One possibly contentious point: in fixing the Write behaviour I chose to
remove the "outgoing" buffer and immediately send the data across the channel.
Write(string) and WriteLine(string) were already doing this, and I felt it
was better to change Write(byte[]) to match rather than changing the string
methods.

* Integrate expectSize (as "windowSize" parameter)

* Rename "windowSize" to "lookback"

---------

Co-authored-by: Wojciech Nagórski <[email protected]>
  • Loading branch information
Rob-Hague and WojciechNagorski authored Feb 16, 2024
1 parent 2b19eec commit 06af2ec
Show file tree
Hide file tree
Showing 29 changed files with 881 additions and 1,941 deletions.
38 changes: 38 additions & 0 deletions src/Renci.SshNet/Common/Extensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,44 @@ public static byte[] TrimLeadingZeros(this byte[] value)
return value;
}

#if NETFRAMEWORK || NETSTANDARD2_0
public static int IndexOf(this byte[] array, byte[] value, int startIndex, int count)
{
if (value.Length > count)
{
return -1;
}

if (value.Length == 0)
{
return 0;
}

for (var i = startIndex; i < startIndex + count - value.Length + 1; i++)
{
if (MatchesAtIndex(i))
{
return i - startIndex;
}
}

return -1;

bool MatchesAtIndex(int i)
{
for (var j = 0; j < value.Length; j++)
{
if (array[i + j] != value[j])
{
return false;
}
}

return true;
}
}
#endif

/// <summary>
/// Pads with leading zeros if needed.
/// </summary>
Expand Down
178 changes: 178 additions & 0 deletions src/Renci.SshNet/Common/TaskToAsyncResult.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
#pragma warning disable
#if !NET8_0_OR_GREATER
// Copied verbatim from https://github.com/dotnet/runtime/blob/78bd7debe6d8b454294c673c9cb969c6b8a14692/src/libraries/Common/src/System/Threading/Tasks/TaskToAsyncResult.cs

// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;

namespace System.Threading.Tasks
{
/// <summary>
/// Provides methods for using <see cref="Task"/> to implement the Asynchronous Programming Model
/// pattern based on "Begin" and "End" methods.
/// </summary>
#if SYSTEM_PRIVATE_CORELIB
public
#else
internal
#endif
static class TaskToAsyncResult
{
/// <summary>Creates a new <see cref="IAsyncResult"/> from the specified <see cref="Task"/>, optionally invoking <paramref name="callback"/> when the task has completed.</summary>
/// <param name="task">The <see cref="Task"/> to be wrapped in an <see cref="IAsyncResult"/>.</param>
/// <param name="callback">The callback to be invoked upon <paramref name="task"/>'s completion. If <see langword="null"/>, no callback will be invoked.</param>
/// <param name="state">The state to be stored in the <see cref="IAsyncResult"/>.</param>
/// <returns>An <see cref="IAsyncResult"/> to represent the task's asynchronous operation. This instance will also be passed to <paramref name="callback"/> when it's invoked.</returns>
/// <exception cref="ArgumentNullException"><paramref name="task"/> is null.</exception>
/// <remarks>
/// In conjunction with the <see cref="End(IAsyncResult)"/> or <see cref="End{TResult}(IAsyncResult)"/> methods, this method may be used
/// to implement the Begin/End pattern (also known as the Asynchronous Programming Model pattern, or APM). It is recommended to not expose this pattern
/// in new code; the methods on <see cref="TaskToAsyncResult"/> are intended only to help implement such Begin/End methods when they must be exposed, for example
/// because a base class provides virtual methods for the pattern, or when they've already been exposed and must remain for compatibility. These methods enable
/// implementing all of the core asynchronous logic via <see cref="Task"/>s and then easily implementing Begin/End methods around that functionality.
/// </remarks>
public static IAsyncResult Begin(Task task, AsyncCallback? callback, object? state)
{
#if NET6_0_OR_GREATER
ArgumentNullException.ThrowIfNull(task);
#else
if (task is null)
{
throw new ArgumentNullException(nameof(task));
}
#endif

return new TaskAsyncResult(task, state, callback);
}

/// <summary>Waits for the <see cref="Task"/> wrapped by the <see cref="IAsyncResult"/> returned by <see cref="Begin"/> to complete.</summary>
/// <param name="asyncResult">The <see cref="IAsyncResult"/> for which to wait.</param>
/// <exception cref="ArgumentNullException"><paramref name="asyncResult"/> is null.</exception>
/// <exception cref="ArgumentException"><paramref name="asyncResult"/> was not produced by a call to <see cref="Begin"/>.</exception>
/// <remarks>This will propagate any exception stored in the wrapped <see cref="Task"/>.</remarks>
public static void End(IAsyncResult asyncResult) =>
Unwrap(asyncResult).GetAwaiter().GetResult();

/// <summary>Waits for the <see cref="Task{TResult}"/> wrapped by the <see cref="IAsyncResult"/> returned by <see cref="Begin"/> to complete.</summary>
/// <typeparam name="TResult">The type of the result produced.</typeparam>
/// <param name="asyncResult">The <see cref="IAsyncResult"/> for which to wait.</param>
/// <returns>The result of the <see cref="Task{TResult}"/> wrapped by the <see cref="IAsyncResult"/>.</returns>
/// <exception cref="ArgumentNullException"><paramref name="asyncResult"/> is null.</exception>
/// <exception cref="ArgumentException"><paramref name="asyncResult"/> was not produced by a call to <see cref="Begin"/>.</exception>
/// <remarks>This will propagate any exception stored in the wrapped <see cref="Task{TResult}"/>.</remarks>
public static TResult End<TResult>(IAsyncResult asyncResult) =>
Unwrap<TResult>(asyncResult).GetAwaiter().GetResult();

/// <summary>Extracts the underlying <see cref="Task"/> from an <see cref="IAsyncResult"/> created by <see cref="Begin"/>.</summary>
/// <param name="asyncResult">The <see cref="IAsyncResult"/> created by <see cref="Begin"/>.</param>
/// <returns>The <see cref="Task"/> wrapped by the <see cref="IAsyncResult"/>.</returns>
/// <exception cref="ArgumentNullException"><paramref name="asyncResult"/> is null.</exception>
/// <exception cref="ArgumentException"><paramref name="asyncResult"/> was not produced by a call to <see cref="Begin"/>.</exception>
public static Task Unwrap(IAsyncResult asyncResult)
{
#if NET6_0_OR_GREATER
ArgumentNullException.ThrowIfNull(asyncResult);
#else
if (asyncResult is null)
{
throw new ArgumentNullException(nameof(asyncResult));
}
#endif

if ((asyncResult as TaskAsyncResult)?._task is not Task task)
{
throw new ArgumentException(null, nameof(asyncResult));
}

return task;
}

/// <summary>Extracts the underlying <see cref="Task{TResult}"/> from an <see cref="IAsyncResult"/> created by <see cref="Begin"/>.</summary>
/// <typeparam name="TResult">The type of the result produced by the returned task.</typeparam>
/// <param name="asyncResult">The <see cref="IAsyncResult"/> created by <see cref="Begin"/>.</param>
/// <returns>The <see cref="Task{TResult}"/> wrapped by the <see cref="IAsyncResult"/>.</returns>
/// <exception cref="ArgumentNullException"><paramref name="asyncResult"/> is null.</exception>
/// <exception cref="ArgumentException">
/// <paramref name="asyncResult"/> was not produced by a call to <see cref="Begin"/>,
/// or the <see cref="Task{TResult}"/> provided to <see cref="Begin"/> was used a generic type parameter
/// that's different from the <typeparamref name="TResult"/> supplied to this call.
/// </exception>
public static Task<TResult> Unwrap<TResult>(IAsyncResult asyncResult)
{
#if NET6_0_OR_GREATER
ArgumentNullException.ThrowIfNull(asyncResult);
#else
if (asyncResult is null)
{
throw new ArgumentNullException(nameof(asyncResult));
}
#endif

if ((asyncResult as TaskAsyncResult)?._task is not Task<TResult> task)
{
throw new ArgumentException(null, nameof(asyncResult));
}

return task;
}

/// <summary>Provides a simple <see cref="IAsyncResult"/> that wraps a <see cref="Task"/>.</summary>
/// <remarks>
/// We could use the Task as the IAsyncResult if the Task's AsyncState is the same as the object state,
/// but that's very rare, in particular in a situation where someone cares about allocation, and always
/// using TaskAsyncResult simplifies things and enables additional optimizations.
/// </remarks>
private sealed class TaskAsyncResult : IAsyncResult
{
/// <summary>The wrapped Task.</summary>
internal readonly Task _task;
/// <summary>Callback to invoke when the wrapped task completes.</summary>
private readonly AsyncCallback? _callback;

/// <summary>Initializes the IAsyncResult with the Task to wrap and the associated object state.</summary>
/// <param name="task">The Task to wrap.</param>
/// <param name="state">The new AsyncState value.</param>
/// <param name="callback">Callback to invoke when the wrapped task completes.</param>
internal TaskAsyncResult(Task task, object? state, AsyncCallback? callback)
{
Debug.Assert(task is not null);

_task = task;
AsyncState = state;

if (task.IsCompleted)
{
// The task has already completed. Treat this as synchronous completion.
// Invoke the callback; no need to store it.
CompletedSynchronously = true;
callback?.Invoke(this);
}
else if (callback is not null)
{
// Asynchronous completion, and we have a callback; schedule it. We use OnCompleted rather than ContinueWith in
// order to avoid running synchronously if the task has already completed by the time we get here but still run
// synchronously as part of the task's completion if the task completes after (the more common case).
_callback = callback;
_task.ConfigureAwait(continueOnCapturedContext: false)
.GetAwaiter()
.OnCompleted(() => _callback.Invoke(this));
}
}

/// <inheritdoc/>
public object? AsyncState { get; }

/// <inheritdoc/>
public bool CompletedSynchronously { get; }

/// <inheritdoc/>
public bool IsCompleted => _task.IsCompleted;

/// <inheritdoc/>
public WaitHandle AsyncWaitHandle => ((IAsyncResult) _task).AsyncWaitHandle;
}
}
}
#endif
22 changes: 0 additions & 22 deletions src/Renci.SshNet/ExpectAsyncResult.cs

This file was deleted.

4 changes: 1 addition & 3 deletions src/Renci.SshNet/IServiceFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ internal partial interface IServiceFactory
/// <param name="height">The terminal height in pixels.</param>
/// <param name="terminalModeValues">The terminal mode values.</param>
/// <param name="bufferSize">Size of the buffer.</param>
/// <param name="expectSize">Size of the expect buffer.</param>
/// <returns>
/// The created <see cref="ShellStream"/> instance.
/// </returns>
Expand All @@ -136,8 +135,7 @@ ShellStream CreateShellStream(ISession session,
uint width,
uint height,
IDictionary<TerminalModes, uint> terminalModeValues,
int bufferSize,
int expectSize);
int bufferSize);

/// <summary>
/// Creates an <see cref="IRemotePathTransformation"/> that encloses a path in double quotes, and escapes
Expand Down
5 changes: 2 additions & 3 deletions src/Renci.SshNet/ServiceFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,6 @@ public ISftpResponseFactory CreateSftpResponseFactory()
/// <param name="height">The terminal height in pixels.</param>
/// <param name="terminalModeValues">The terminal mode values.</param>
/// <param name="bufferSize">The size of the buffer.</param>
/// <param name="expectSize">The size of the expect buffer.</param>
/// <returns>
/// The created <see cref="ShellStream"/> instance.
/// </returns>
Expand All @@ -202,9 +201,9 @@ public ISftpResponseFactory CreateSftpResponseFactory()
/// to the drawable area of the window.
/// </para>
/// </remarks>
public ShellStream CreateShellStream(ISession session, string terminalName, uint columns, uint rows, uint width, uint height, IDictionary<TerminalModes, uint> terminalModeValues, int bufferSize, int expectSize)
public ShellStream CreateShellStream(ISession session, string terminalName, uint columns, uint rows, uint width, uint height, IDictionary<TerminalModes, uint> terminalModeValues, int bufferSize)
{
return new ShellStream(session, terminalName, columns, rows, width, height, terminalModeValues, bufferSize, expectSize);
return new ShellStream(session, terminalName, columns, rows, width, height, terminalModeValues, bufferSize);
}

/// <summary>
Expand Down
Loading

0 comments on commit 06af2ec

Please sign in to comment.