From a87cb70da934e373032c8c59a92a8b412d505d09 Mon Sep 17 00:00:00 2001 From: yowl Date: Sun, 5 Apr 2026 19:55:12 -0500 Subject: [PATCH 1/3] add future/stream cancel support Clear up some debug use same package location for codegen and runtime tests. --- crates/csharp/src/AsyncSupport.cs | 446 +++++++++++++----- crates/csharp/src/csproj.rs | 14 +- crates/csharp/src/function.rs | 4 +- crates/csharp/src/interface.rs | 30 +- .../async/future-cancel-read/runner.cs | 35 ++ .../async/future-cancel-read/test.cs | 36 ++ .../async/pending-import/runner.cs | 3 - 7 files changed, 423 insertions(+), 145 deletions(-) create mode 100644 tests/runtime-async/async/future-cancel-read/runner.cs create mode 100644 tests/runtime-async/async/future-cancel-read/test.cs diff --git a/crates/csharp/src/AsyncSupport.cs b/crates/csharp/src/AsyncSupport.cs index 0c4e47b3b..560581e17 100644 --- a/crates/csharp/src/AsyncSupport.cs +++ b/crates/csharp/src/AsyncSupport.cs @@ -21,6 +21,13 @@ public enum CallbackCode : uint //#define TEST_CALLBACK_CODE_WAIT(set) (2 | (set << 4)) } +public enum CancelCode : uint +{ + Completed = 0, + Dropped = 1, + Cancelled = 2, +} + // The context that we will create in unmanaged memory and pass to context_set. // TODO: C has world specific types for these pointers, perhaps C# would benefit from those also. [StructLayout(LayoutKind.Sequential)] @@ -65,9 +72,7 @@ private static class Interop public static int WaitableSetNew() { - var waitableSet = Interop.WaitableSetNew(); - Console.WriteLine($"WaitableSet created with number {waitableSet}"); - return waitableSet; + return Interop.WaitableSetNew(); } // unsafe because we are using pointers. @@ -80,12 +85,6 @@ public static unsafe void WaitableSetPoll(int waitableHandle) } } - internal static void Join(SubtaskStatus subtask, int waitableSetHandle, WaitableInfoState waitableInfoState) - { - AddTaskToWaitables(waitableSetHandle, subtask.Handle, waitableInfoState); - Interop.WaitableJoin(subtask.Handle, waitableSetHandle); - } - internal static void Join(int readerWriterHandle, int waitableHandle, WaitableInfoState waitableInfoState) { AddTaskToWaitables(waitableHandle, readerWriterHandle, waitableInfoState); @@ -101,7 +100,6 @@ public static void Join(int handle) private static void AddTaskToWaitables(int waitableSetHandle, int waitableHandle, WaitableInfoState waitableInfoState) { - Console.WriteLine($"Adding waitable {waitableHandle} to set {waitableSetHandle}"); var waitableSetOfTasks = pendingTasks.GetOrAdd(waitableSetHandle, _ => new ConcurrentDictionary()); waitableSetOfTasks[waitableHandle] = waitableInfoState; } @@ -132,9 +130,8 @@ public static unsafe void ContextSet(ContextTask* contextTask) } // unsafe because we are using pointers. - public static unsafe uint Callback(EventWaitable e, ContextTask* contextPtr, Action taskReturn) + public static unsafe uint Callback(EventWaitable e, ContextTask* contextPtr) { - Console.WriteLine($"Callback Event code {e.EventCode} Code {e.Code} Waitable {e.Waitable} Waitable Status {e.WaitableStatus.State}, Count {e.WaitableCount}"); ContextTask* contextTaskPtr = ContextGet(); var waitables = pendingTasks[contextTaskPtr->WaitableSetHandle]; @@ -142,14 +139,12 @@ public static unsafe uint Callback(EventWaitable e, ContextTask* contextPtr, Act if (e.IsDropped) { - Console.WriteLine("Dropped."); waitableInfoState.FutureStream.OtherSideDropped(); } if (e.IsCompleted || e.IsDropped) { // The operation is complete so we can free the buffer and remove the waitable from our dicitonary - Console.WriteLine("Setting the result"); waitables.Remove(e.Waitable, out _); if (e.IsSubtask) { @@ -173,14 +168,11 @@ public static unsafe uint Callback(EventWaitable e, ContextTask* contextPtr, Act if (waitables.Count == 0) { - Console.WriteLine($"No more waitables for waitable {e.Waitable} in set {contextTaskPtr->WaitableSetHandle}"); - taskReturn(); ContextSet(null); Marshal.FreeHGlobal((IntPtr)contextTaskPtr); return (uint)CallbackCode.Exit; } - Console.WriteLine("More waitables in the set."); return (uint)CallbackCode.Wait | (uint)(contextTaskPtr->WaitableSetHandle << 4); } @@ -193,23 +185,25 @@ internal static unsafe Task TaskFromStatus(uint status) var subtaskStatus = new SubtaskStatus(status); status = status & 0xF; + var tcs = new TaskCompletionSource(); if (subtaskStatus.IsSubtaskStarting || subtaskStatus.IsSubtaskStarted) { ContextTask* contextTaskPtr = ContextGet(); - if (contextTaskPtr == null) { + if (contextTaskPtr == null) + { contextTaskPtr = AllocateAndSetNewContext(); - Console.WriteLine($"TaskFromStatus creating WaitableSet {contextTaskPtr->WaitableSetHandle}"); } - TaskCompletionSource tcs = new TaskCompletionSource(); - Join(subtaskStatus, contextTaskPtr->WaitableSetHandle, new WaitableInfoState(tcs)); + Join(subtaskStatus.Handle, contextTaskPtr->WaitableSetHandle, new WaitableInfoState(tcs)); + return tcs.Task; } else if (subtaskStatus.IsSubtaskReturned) { + tcs.SetResult(0); return Task.CompletedTask; } - else + else { throw new Exception($"unexpected subtask status: {status}"); } @@ -219,24 +213,23 @@ internal static unsafe Task TaskFromStatus(uint status) public static unsafe Task TaskFromStatus(uint status, Func liftFunc) { var subtaskStatus = new SubtaskStatus(status); - status = status & 0xF; - // TODO join and complete the task somwhere. - var tcs = new TaskCompletionSource(); if (subtaskStatus.IsSubtaskStarting || subtaskStatus.IsSubtaskStarted) { ContextTask* contextTaskPtr = ContextGet(); if (contextTaskPtr == null) { - contextTaskPtr = (ContextTask *)Marshal.AllocHGlobal(Marshal.SizeOf()); - Console.WriteLine("TaskFromStatus creating WaitableSet"); - contextTaskPtr->WaitableSetHandle = WaitableSetNew(); - ContextSet(contextTaskPtr); + contextTaskPtr = AllocateAndSetNewContext(); } + var intTaskCompletionSource = new TaskCompletionSource(); + var tcs = new LiftingTaskCompletionSource(intTaskCompletionSource, liftFunc); + Join(subtaskStatus.Handle, contextTaskPtr->WaitableSetHandle, new WaitableInfoState(intTaskCompletionSource)); + return tcs.Task; } else if (subtaskStatus.IsSubtaskReturned) { + var tcs = new TaskCompletionSource(); tcs.SetResult(liftFunc()); return tcs.Task; } @@ -246,12 +239,23 @@ public static unsafe Task TaskFromStatus(uint status, Func liftFunc) } } + // Placeholder, TODO: Needs implementing for async functions that return values. + internal class LiftingTaskCompletionSource : TaskCompletionSource + { + internal LiftingTaskCompletionSource(TaskCompletionSource innerTaskCompletionSource, Func _liftFunc) + { + innerTaskCompletionSource.Task.ContinueWith(t => { + throw new NotImplementedException("lifting results from async functions not implemented yet"); + }); + } + } + // unsafe because we are working with native memory. internal static unsafe ContextTask* AllocateAndSetNewContext() { var contextTaskPtr = (ContextTask *)Marshal.AllocHGlobal(Marshal.SizeOf()); - contextTaskPtr->WaitableSetHandle = AsyncSupport.WaitableSetNew(); - AsyncSupport.ContextSet(contextTaskPtr); + contextTaskPtr->WaitableSetHandle = WaitableSetNew(); + ContextSet(contextTaskPtr); return contextTaskPtr; } } @@ -269,24 +273,82 @@ public static unsafe Task TaskFromStatus(uint status, Func liftFunc) public delegate uint StreamWrite(int handle, IntPtr buffer, uint length); public delegate uint StreamRead(int handle, IntPtr buffer, uint length); public delegate void Lower(object payload, uint size); +public delegate uint CancelRead(int handle); +public delegate uint CancelWrite(int handle); + +public interface ICancelableRead +{ + uint CancelRead(int handle); +} -public struct FutureVTable +public interface ICancelableWrite { - public New New; - public FutureRead Read; - public FutureWrite Write; - public DropReader DropReader; - public DropWriter DropWriter; + uint CancelWrite(int handle); } -public struct StreamVTable +public interface ICancelable { - public New New; - public StreamRead Read; - public StreamWrite Write; - public DropReader DropReader; - public DropWriter DropWriter; - public Lower? Lower; + uint Cancel(); +} + +public class CancelableRead(ICancelableRead cancelableVTable, int handle) : ICancelable +{ + public uint Cancel() + { + return cancelableVTable.CancelRead(handle); + } +} + +public class CancelableWrite(ICancelableWrite cancelableVTable, int handle) : ICancelable +{ + public uint Cancel() + { + return cancelableVTable.CancelWrite(handle); + } +} + +public struct FutureVTable : ICancelableRead, ICancelableWrite +{ + internal New New; + internal FutureRead Read; + internal FutureWrite Write; + internal DropReader DropReader; + internal DropWriter DropWriter; + internal Lower? Lower; + internal CancelWrite CancelWriteDelegate; + internal CancelRead CancelReadDelegate; + + public uint CancelRead(int handle) + { + return CancelReadDelegate(handle); + } + + public uint CancelWrite(int handle) + { + return CancelWriteDelegate(handle); + } +} + +public struct StreamVTable : ICancelableRead, ICancelableWrite +{ + internal New New; + internal StreamRead Read; + internal StreamWrite Write; + internal DropReader DropReader; + internal DropWriter DropWriter; + internal Lower? Lower; + internal CancelWrite CancelWriteDelegate; + internal CancelRead CancelReadDelegate; + + public uint CancelRead(int handle) + { + return CancelReadDelegate(handle); + } + + public uint CancelWrite(int handle) + { + return CancelWriteDelegate(handle); + } } internal interface IFutureStream : IDisposable @@ -335,8 +397,6 @@ internal static (StreamReader, StreamWriter) RawStreamNew(StreamVTable var readerHandle = (int)(packed & 0xFFFFFFFF); var writerHandle = (int)(packed >> 32); - Console.WriteLine($"Creating reader with handle {readerHandle}"); - Console.WriteLine($"Creating writer with handle {writerHandle}"); return (new StreamReader(readerHandle, vtable), new StreamWriter(writerHandle, vtable)); } } @@ -364,10 +424,24 @@ internal int TakeHandle() return handle; } + protected GCHandle LiftBuffer(T[] buffer) + { + // For primitive, blittable types + if (typeof(T).IsPrimitive || typeof(T).IsValueType) + { + return GCHandle.Alloc(buffer, GCHandleType.Pinned); + } + else + { + // TODO: crete buffers for lowered stream types and then lift + throw new NotImplementedException("reading from futures types that require lifting"); + } + } + internal abstract uint VTableRead(IntPtr bufferPtr, int length); // unsafe as we are working with pointers. - internal unsafe Task ReadInternal(Func liftBuffer, int length) + internal unsafe ComponentTask ReadInternal(Func liftBuffer, int length, ICancelableRead cancelableRead) { if (Handle == 0) { @@ -383,23 +457,19 @@ internal unsafe Task ReadInternal(Func liftBuffer, int length) var status = new WaitableStatus(VTableRead(bufferHandle == null ? IntPtr.Zero : bufferHandle.Value.AddrOfPinnedObject(), length)); if (status.IsBlocked) { - Console.WriteLine("Read Blocked"); - var tcs = new TaskCompletionSource(); + var task = new ComponentTask(new CancelableRead(cancelableRead, Handle)); ContextTask* contextTaskPtr = AsyncSupport.ContextGet(); if(contextTaskPtr == null) { - Console.WriteLine("FutureReader Read Blocked creating WaitableSet"); contextTaskPtr = AsyncSupport.AllocateAndSetNewContext(); } - Console.WriteLine("blocked read before join"); - AsyncSupport.Join(Handle, contextTaskPtr->WaitableSetHandle, new WaitableInfoState(tcs, this)); - Console.WriteLine("blocked read after join"); - return tcs.Task; + AsyncSupport.Join(Handle, contextTaskPtr->WaitableSetHandle, new WaitableInfoState(task, this)); + return task; } if (status.IsCompleted) { - return Task.FromResult((int)status.Count); + return ComponentTask.FromResult((int)status.Count); } throw new NotImplementedException(status.State.ToString()); @@ -447,9 +517,9 @@ internal FutureReader(int handle, FutureVTable vTable) : base(handle) internal FutureVTable VTable { get; private set; } - public Task Read() + public ComponentTask Read() { - return ReadInternal(() => null, 0); + return ReadInternal(() => null, 0, VTable); } internal override uint VTableRead(IntPtr ptr, int length) @@ -467,22 +537,47 @@ public class FutureReader(int handle, FutureVTable vTable) : ReaderBase(handl { public FutureVTable VTable { get; private set; } = vTable; - private GCHandle LiftBuffer(T buffer) + public ComponentTask Read() { - if(typeof(T) == typeof(byte)) - { - return GCHandle.Alloc(buffer, GCHandleType.Pinned); - } - else + T[] buf = new T[1]; + ComponentTask internalTask = ReadInternal(() => LiftBuffer(buf), 1, VTable); + + // Wrap the task so we can return a T and not the number of Ts read + ComponentTask readTask = new(new DelegatingCancelable(internalTask)); + + internalTask.ContinueWith(it => { - // TODO: crete buffers for lowered stream types and then lift - throw new NotImplementedException("reading from futures types that require lifting"); - } + if (it.IsCompletedSuccessfully) + { + readTask.SetResult(buf[0]); + } + else if (it.IsCanceled) + { + // readTask.SetCanceled(); + } + else + { + //TODO + throw new NotImplementedException("faulted future read not implemented"); + } + }); + return readTask; } - public Task Read(T buffer) + class DelegatingCancelable : ICancelable { - return ReadInternal(() => LiftBuffer(buffer), 1); + private ComponentTask innerTask; + + internal DelegatingCancelable(ComponentTask innerTask) + { + this.innerTask = innerTask; + } + + uint ICancelable.Cancel() + { + var cancelVal = innerTask.Cancel(); + return (uint)cancelVal; + } } internal override uint VTableRead(IntPtr ptr, int length) @@ -505,9 +600,9 @@ public StreamReader(int handle, StreamVTable vTable) : base(handle) public StreamVTable VTable { get; private set; } - public Task Read(int length) + public ComponentTask Read(int length) { - return ReadInternal(() => null, length); + return ReadInternal(() => null, length, VTable); } internal override uint VTableRead(IntPtr ptr, int length) @@ -525,22 +620,9 @@ public class StreamReader(int handle, StreamVTable vTable) : ReaderBase(hand { public StreamVTable VTable { get; private set; } = vTable; - private GCHandle LiftBuffer(T[] buffer) - { - if(typeof(T) == typeof(byte)) - { - return GCHandle.Alloc(buffer, GCHandleType.Pinned); - } - else - { - // TODO: crete buffers for lowered stream types and then lift - throw new NotImplementedException("reading from stream types that require lifting"); - } - } - - public Task Read(T[] buffer) + public ComponentTask Read(T[] buffer) { - return ReadInternal(() => LiftBuffer(buffer), buffer.Length); + return ReadInternal(() => LiftBuffer(buffer), buffer.Length, VTable); } internal override uint VTableRead(IntPtr ptr, int length) @@ -558,6 +640,7 @@ public abstract class WriterBase : IFutureStream { private GCHandle? bufferHandle; private bool readerDropped; + private bool canDrop; internal WriterBase(int handle) { @@ -580,7 +663,7 @@ internal int TakeHandle() internal abstract uint VTableWrite(IntPtr bufferPtr, int length); // unsafe as we are working with pointers. - internal unsafe Task WriteInternal(Func lowerPayload, int length) + internal unsafe ComponentTask WriteInternal(Func lowerPayload, int length, ICancelableWrite cancelable) { if (Handle == 0) { @@ -594,25 +677,23 @@ internal unsafe Task WriteInternal(Func lowerPayload, int length bufferHandle = lowerPayload(); var status = new WaitableStatus(VTableWrite(bufferHandle == null ? IntPtr.Zero : bufferHandle.Value.AddrOfPinnedObject(), length)); + canDrop = true; // We can only call drop once something has been written. if (status.IsBlocked) { - Console.WriteLine("blocked write"); - var tcs = new TaskCompletionSource(); + var tcs = new ComponentTask(new CancelableWrite(cancelable, Handle)); ContextTask* contextTaskPtr = AsyncSupport.ContextGet(); if(contextTaskPtr == null) { contextTaskPtr = AsyncSupport.AllocateAndSetNewContext(); } - Console.WriteLine("blocked write before join"); AsyncSupport.Join(Handle, contextTaskPtr->WaitableSetHandle, new WaitableInfoState(tcs, this)); - Console.WriteLine("blocked write after join"); - return tcs.Task; + return tcs; } if (status.IsCompleted) { bufferHandle?.Free(); - return Task.FromResult((int)status.Count); + return ComponentTask.FromResult((int)status.Count); } throw new NotImplementedException($"Unsupported write status {status.State}"); @@ -633,7 +714,7 @@ void IFutureStream.OtherSideDropped() void Dispose(bool _disposing) { // Free unmanaged resources if any. - if (Handle != 0) + if (Handle != 0 && canDrop) { VTableDrop(); } @@ -655,9 +736,9 @@ public class FutureWriter(int handle, FutureVTable vTable) : WriterBase(handle) { public FutureVTable VTable { get; private set; } = vTable; - public Task Write() + public ComponentTask Write() { - return WriteInternal(() => null, 0); + return WriteInternal(() => null, 0, VTable); } internal override uint VTableWrite(IntPtr bufferPtr, int length) @@ -675,10 +756,23 @@ public class FutureWriter(int handle, FutureVTable vTable) : WriterBase(handl { public FutureVTable VTable { get; private set; } = vTable; - // TODO: Generate per type for this instrinsic. - public Task Write() + private GCHandle LowerPayload(T[] payload) + { + if (VTable.Lower == null) + { + return GCHandle.Alloc(payload, GCHandleType.Pinned); + } + else + { + // Lower the payload + throw new NotSupportedException("StreamWriter Write where the payload must be lowered."); + // var loweredPayload = VTable.Lower(payload); + } + } + + public ComponentTask Write(T payload) { - return WriteInternal(() => null, 1); + return WriteInternal(() => LowerPayload([payload]), 1, VTable); } internal override uint VTableWrite(IntPtr bufferPtr, int length) @@ -696,9 +790,9 @@ public class StreamWriter(int handle, StreamVTable vTable) : WriterBase(handle) { public StreamVTable VTable { get; private set; } = vTable; - public Task Write() + public ComponentTask Write() { - return WriteInternal(() => null, 0); + return WriteInternal(() => null, 0, VTable); } internal override uint VTableWrite(IntPtr bufferPtr, int length) @@ -731,9 +825,9 @@ private GCHandle LowerPayload(T[] payload) } } - public Task Write(T[] payload) + public ComponentTask Write(T[] payload) { - return WriteInternal(() => LowerPayload(payload), payload.Length); + return WriteInternal(() => LowerPayload(payload), payload.Length, VTable); } internal override uint VTableWrite(IntPtr bufferPtr, int length) @@ -749,19 +843,13 @@ internal override void VTableDrop() internal struct WaitableInfoState { - internal WaitableInfoState(TaskCompletionSource taskCompletionSource, IFutureStream futureStream) + internal WaitableInfoState(ComponentTask componentTaskInt, IFutureStream futureStream) { - taskCompletionSourceInt = taskCompletionSource; + this.componentTaskInt = componentTaskInt; FutureStream = futureStream; } - internal WaitableInfoState(TaskCompletionSource taskCompletionSource, IFutureStream futureStream) - { - this.taskCompletionSource = taskCompletionSource; - FutureStream = futureStream; - } - - internal WaitableInfoState(TaskCompletionSource taskCompletionSource) + internal WaitableInfoState(TaskCompletionSource taskCompletionSource) { this.taskCompletionSource = taskCompletionSource; } @@ -770,31 +858,39 @@ internal void SetResult(int count) { if (taskCompletionSource != null) { - Console.WriteLine("Setting result for void waitable completion source"); - taskCompletionSource.SetResult(); + taskCompletionSource.SetResult(count); + } + else if (componentTask != null) + { + componentTask.SetResult(); + } + else if (componentTaskInt != null) + { + componentTaskInt.SetResult(count); } else { - taskCompletionSourceInt.SetResult(count); + throw new InvalidOperationException("No component task associated with this WaitableInfoState."); } } internal void SetException(Exception e) { - if (taskCompletionSource != null) + if (componentTask != null) { - Console.WriteLine("Setting exception waitable completion source"); - taskCompletionSource.SetException(e); + componentTask.SetException(e); } else { - taskCompletionSourceInt.SetException(e); + componentTaskInt.SetException(e); } } - private TaskCompletionSource taskCompletionSource; - private TaskCompletionSource taskCompletionSourceInt; - internal IFutureStream FutureStream; + // We have a taskCompletionSource for an async function, a ComponentTask for a future or stream. + private TaskCompletionSource? taskCompletionSource; + private ComponentTask? componentTask; + private ComponentTask? componentTaskInt; + internal IFutureStream? FutureStream; } public class StreamDroppedException : Exception @@ -807,3 +903,111 @@ public StreamDroppedException(string message) : base(message) { } } + +public abstract class ComponentTask +{ + protected readonly ICancelable cancelableVTable; + private bool canCancel = true; + + internal ComponentTask(ICancelable? cancelableVTable = null) + { + this.cancelableVTable = cancelableVTable; + } + + public abstract Task Task { get; } + + public abstract bool IsCompleted { get; } + + public CancelCode Cancel() + { + if(!canCancel) + { + return CancelCode.Completed; + } + + if(cancelableVTable == null) + { + throw new InvalidOperationException("Cannot cancel a task that was created as completed with a result."); + } + + uint cancelReturn = cancelableVTable.Cancel(); + SetCanceled(); + return (CancelCode)cancelReturn; + } + + public abstract void SetCanceled(); + + public virtual void SetResult() + { + canCancel = false; + } + + public abstract void SetException(Exception e); + + public static ComponentTask FromResult() + { + var task = new ComponentTask(); + task.SetResult(0); + return task; + } + + /// + /// Makes the class directly awaitable. + /// + public TaskAwaiter GetAwaiter() + { + return Task.GetAwaiter(); + } +} + +public class ComponentTask : ComponentTask +{ + private readonly TaskCompletionSource tcs; + + internal ComponentTask(ICancelable? cancelableVTable = null) : base(cancelableVTable) + { + tcs = new TaskCompletionSource(); + } + + public override Task Task => tcs.Task; + + public override bool IsCompleted => tcs.Task.IsCompleted; + + public Task ContinueWith(Action> continuationAction) + { + return tcs.Task.ContinueWith(continuationAction, TaskContinuationOptions.ExecuteSynchronously); + } + + public void SetResult(T result) + { + SetResult(); + tcs.SetResult(result); + } + + public static ComponentTask FromResult(T result) + { + var task = new ComponentTask(); + task.tcs.SetResult(result); + return task; + } + + public override void SetCanceled() + { + tcs.SetCanceled(); + } + + public override void SetException(Exception e) + { + tcs.SetException(e); + } + + /// + /// Makes the class directly awaitable. + /// + public new TaskAwaiter GetAwaiter() + { + return tcs.Task.GetAwaiter(); + } + + public T Result => tcs.Task.Result; +} \ No newline at end of file diff --git a/crates/csharp/src/csproj.rs b/crates/csharp/src/csproj.rs index 1d9d86aaa..76b8c5d57 100644 --- a/crates/csharp/src/csproj.rs +++ b/crates/csharp/src/csproj.rs @@ -106,6 +106,12 @@ impl CSProjectLLVMBuilder { other => todo!("OS {} not supported", other), }; + // Share nuget packages between codegen and runtime tests. + let packages_path = if self.dir.to_str().unwrap().contains("codegen") { + "../../../../.packages" + } else { + "../../../.packages" + }; csproj.push_str( &format!( r#" @@ -117,11 +123,12 @@ impl CSProjectLLVMBuilder { fs::write( self.dir.join("nuget.config"), - r#" + format!( + r#" - + @@ -130,7 +137,8 @@ impl CSProjectLLVMBuilder { - "#, + "# + ), )?; } diff --git a/crates/csharp/src/function.rs b/crates/csharp/src/function.rs index 8d7ab4a1e..693f6cf6d 100644 --- a/crates/csharp/src/function.rs +++ b/crates/csharp/src/function.rs @@ -1155,12 +1155,12 @@ impl Bindgen for FunctionBindgen<'_, '_> { if (t.IsFaulted) {{ // TODO - Console.Error.WriteLine("Async function {name} IsFaulted. This scenario is not yet implemented."); throw new NotImplementedException("Async function {name} IsFaulted. This scenario is not yet implemented."); }} {name}TaskReturn({ret_param}); - }}); + + }}, TaskContinuationOptions.ExecuteSynchronously); // TODO: Defer dropping borrowed resources until a result is returned. ContextTask* contextTaskPtr = AsyncSupport.ContextGet(); diff --git a/crates/csharp/src/interface.rs b/crates/csharp/src/interface.rs index 49d55cd8c..ca09f01b9 100644 --- a/crates/csharp/src/interface.rs +++ b/crates/csharp/src/interface.rs @@ -230,12 +230,12 @@ impl InterfaceGenerator<'_> { let mut generated_future_types: HashSet> = HashSet::new(); let (_namespace, interface_name) = &CSharp::get_class_name_from_qualified_name(self.name); let interop_name = format!("{}Interop", interface_name.strip_prefix("I").unwrap()); - let (futures_or_streams, stream_length_param) = if is_future { (&self.futures, "") } else { (&self.streams, ", uint length") }; + let mut index = 0; for future in futures_or_streams { // This code originally copied from Rust codegen generate_payload. // See the rust codegen for the comment - essentially we canonicalize to one per type. @@ -268,6 +268,8 @@ impl InterfaceGenerator<'_> { Write = {future_stream_name}Write{upper_camel_future_type}, DropReader = {future_stream_name}DropReader{upper_camel_future_type}, DropWriter = {future_stream_name}DropWriter{upper_camel_future_type}, + CancelReadDelegate = {future_stream_name}CancelRead{upper_camel_future_type}, + CancelWriteDelegate = {future_stream_name}CancelWrite{upper_camel_future_type}, }}; "# ); @@ -275,16 +277,16 @@ impl InterfaceGenerator<'_> { uwrite!( self.csharp_interop_src, r#" - [global::System.Runtime.InteropServices.DllImportAttribute("{import_module_name}", EntryPoint = "[async-lower][{future_stream_name_lower}-read-0]{future_name}"), global::System.Runtime.InteropServices.WasmImportLinkageAttribute] + [global::System.Runtime.InteropServices.DllImportAttribute("{import_module_name}", EntryPoint = "[async-lower][{future_stream_name_lower}-read-{index}]{future_name}"), global::System.Runtime.InteropServices.WasmImportLinkageAttribute] internal static unsafe extern uint {future_stream_name}Read{upper_camel_future_type}(int readable, IntPtr ptr{stream_length_param}); - [global::System.Runtime.InteropServices.DllImportAttribute("{import_module_name}", EntryPoint = "[async-lower][{future_stream_name_lower}-write-0]{future_name}"), global::System.Runtime.InteropServices.WasmImportLinkageAttribute] + [global::System.Runtime.InteropServices.DllImportAttribute("{import_module_name}", EntryPoint = "[async-lower][{future_stream_name_lower}-write-{index}]{future_name}"), global::System.Runtime.InteropServices.WasmImportLinkageAttribute] internal static unsafe extern uint {future_stream_name}Write{upper_camel_future_type}(int writeable, IntPtr buffer{stream_length_param}); - [global::System.Runtime.InteropServices.DllImportAttribute("{import_module_name}", EntryPoint = "[{future_stream_name_lower}-drop-readable-0]{future_name}"), global::System.Runtime.InteropServices.WasmImportLinkageAttribute] + [global::System.Runtime.InteropServices.DllImportAttribute("{import_module_name}", EntryPoint = "[{future_stream_name_lower}-drop-readable-{index}]{future_name}"), global::System.Runtime.InteropServices.WasmImportLinkageAttribute] internal static extern void {future_stream_name}DropReader{upper_camel_future_type}(int readable); - [global::System.Runtime.InteropServices.DllImportAttribute("{import_module_name}", EntryPoint = "[{future_stream_name_lower}-drop-writable-0]{future_name}"), global::System.Runtime.InteropServices.WasmImportLinkageAttribute] + [global::System.Runtime.InteropServices.DllImportAttribute("{import_module_name}", EntryPoint = "[{future_stream_name_lower}-drop-writable-{index}]{future_name}"), global::System.Runtime.InteropServices.WasmImportLinkageAttribute] internal static extern void {future_stream_name}DropWriter{upper_camel_future_type}(int readable); "# ); @@ -292,7 +294,7 @@ impl InterfaceGenerator<'_> { uwrite!( self.csharp_interop_src, r#" - [global::System.Runtime.InteropServices.DllImportAttribute("{import_module_name}", EntryPoint = "[{future_stream_name_lower}-new-0]{future_name}"), global::System.Runtime.InteropServices.WasmImportLinkageAttribute] + [global::System.Runtime.InteropServices.DllImportAttribute("{import_module_name}", EntryPoint = "[{future_stream_name_lower}-new-{index}]{future_name}"), global::System.Runtime.InteropServices.WasmImportLinkageAttribute] internal static extern ulong {future_stream_name}New{upper_camel_future_type}(); "# ); @@ -301,7 +303,7 @@ impl InterfaceGenerator<'_> { uwrite!( self.csharp_interop_src, r#" - [global::System.Runtime.InteropServices.DllImportAttribute("{import_module_name}", EntryPoint = "[{future_stream_name_lower}-cancel-read-0]{future_name}"), global::System.Runtime.InteropServices.WasmImportLinkageAttribute] + [global::System.Runtime.InteropServices.DllImportAttribute("{import_module_name}", EntryPoint = "[{future_stream_name_lower}-cancel-read-{index}]{future_name}"), global::System.Runtime.InteropServices.WasmImportLinkageAttribute] internal static extern uint {future_stream_name}CancelRead{upper_camel_future_type}(int readable); "# ); @@ -309,7 +311,7 @@ impl InterfaceGenerator<'_> { uwrite!( self.csharp_interop_src, r#" - [global::System.Runtime.InteropServices.DllImportAttribute("{import_module_name}", EntryPoint = "[{future_stream_name_lower}-cancel-write-0]{future_name}"), global::System.Runtime.InteropServices.WasmImportLinkageAttribute] + [global::System.Runtime.InteropServices.DllImportAttribute("{import_module_name}", EntryPoint = "[{future_stream_name_lower}-cancel-write-{index}]{future_name}"), global::System.Runtime.InteropServices.WasmImportLinkageAttribute] internal static extern uint {future_stream_name}CancelWrite{upper_camel_future_type}(int writeable); "# ); @@ -317,7 +319,7 @@ impl InterfaceGenerator<'_> { uwrite!( self.csharp_interop_src, r#" - [global::System.Runtime.InteropServices.DllImportAttribute("{import_module_name}", EntryPoint = "[{future_stream_name_lower}-drop-writeable-0]{future_name}"), global::System.Runtime.InteropServices.WasmImportLinkageAttribute] + [global::System.Runtime.InteropServices.DllImportAttribute("{import_module_name}", EntryPoint = "[{future_stream_name_lower}-drop-writeable-{index}]{future_name}"), global::System.Runtime.InteropServices.WasmImportLinkageAttribute] internal static extern void {future_stream_name}DropWriteable{upper_camel_future_type}(int writeable); "# ); @@ -385,6 +387,7 @@ impl InterfaceGenerator<'_> { } generated_future_types.insert(canonical_payload); + index = index + 1; } self.csharp_gen.needs_async_support = true; @@ -699,9 +702,7 @@ var {async_status_var} = {raw_name}({wasm_params}); uwriteln!( src, " - Console.WriteLine(\"calling TaskFromStatus from func {}\"); var task = AsyncSupport.TaskFromStatus<{return_type}>({async_status_var}, {});", - func.name, lift_func ); uwriteln!( @@ -714,9 +715,7 @@ var {async_status_var} = {raw_name}({wasm_params}); uwriteln!( src, " - Console.WriteLine(\"calling TaskFromStatus from func {}\"); - return AsyncSupport.TaskFromStatus({async_status_var});", - func.name + return AsyncSupport.TaskFromStatus({async_status_var});" ); } } else { @@ -877,7 +876,6 @@ var {async_status_var} = {raw_name}({wasm_params}); [global::System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute(EntryPoint = "[callback]{export_name}")] public static unsafe uint {camel_name}Callback(int eventRaw, uint waitable, uint code) {{ - Console.WriteLine($"Callback with code {{code}}"); EventWaitable e = new EventWaitable((EventCode)eventRaw, waitable, code); "# @@ -896,7 +894,7 @@ var {async_status_var} = {raw_name}({wasm_params}); uwriteln!( self.csharp_interop_src, r#" - return (uint)AsyncSupport.Callback(e, (ContextTask *)IntPtr.Zero, () => {camel_name}TaskReturn()); + return (uint)AsyncSupport.Callback(e, (ContextTask *)IntPtr.Zero); }} "# ); diff --git a/tests/runtime-async/async/future-cancel-read/runner.cs b/tests/runtime-async/async/future-cancel-read/runner.cs new file mode 100644 index 000000000..b14b778b8 --- /dev/null +++ b/tests/runtime-async/async/future-cancel-read/runner.cs @@ -0,0 +1,35 @@ +using System.Diagnostics; +using RunnerWorld.wit.Imports.my.test; +using RunnerWorld; + +public class RunnerWorldExportsImpl +{ + public static async Task Run() + { + { + var (reader, writer) = IIImports.FutureNewUint(); + await IIImports.CancelBeforeRead(reader); + writer.Dispose(); + } + + + { + var (reader, writer) = IIImports.FutureNewUint(); + await IIImports.CancelAfterRead(reader); + writer.Dispose(); + } + + { + var (dataReader, dataWriter) = IIImports.FutureNewUint(); + var (signalReader, signalWriter) = IIImports.FutureNew(); + var testTask = IIImports.StartReadThenCancel(dataReader, signalReader); + async Task WriterAsync() + { + await signalWriter.Write(); + await dataWriter.Write(4); + } + + await WriterAsync(); + } + } +} diff --git a/tests/runtime-async/async/future-cancel-read/test.cs b/tests/runtime-async/async/future-cancel-read/test.cs new file mode 100644 index 000000000..60485cc40 --- /dev/null +++ b/tests/runtime-async/async/future-cancel-read/test.cs @@ -0,0 +1,36 @@ +using System.Diagnostics; +using System.Runtime.InteropServices; +using System.Threading.Tasks; + +namespace TestWorld.wit.Exports.my.test +{ + public class IExportsImpl : IIExports + { + public static Task CancelBeforeRead(FutureReader future) + { + Debug.Assert(future.Read().Cancel() == CancelCode.Cancelled); + future.Dispose(); + return Task.CompletedTask; + } + + public static Task CancelAfterRead(FutureReader future) + { + var task = future.Read(); + Debug.Assert(!task.IsCompleted); + + // If the cancel occurs before the read is complete (or the writer ignores the cancel) we return Cancelled. + Debug.Assert(task.Cancel() == CancelCode.Cancelled); + return Task.CompletedTask; + } + + public static async Task StartReadThenCancel(FutureReader future, FutureReader signal) + { + var task = future.Read(); + Debug.Assert(!task.IsCompleted); + + await signal.Read(); + + Debug.Assert(task.Cancel() == CancelCode.Completed); + } + } +} \ No newline at end of file diff --git a/tests/runtime-async/async/pending-import/runner.cs b/tests/runtime-async/async/pending-import/runner.cs index 197f84130..1e2451082 100644 --- a/tests/runtime-async/async/pending-import/runner.cs +++ b/tests/runtime-async/async/pending-import/runner.cs @@ -11,11 +11,8 @@ public static async Task Run() var task = IIImports.PendingImport(reader); Debug.Assert(!task.IsCompleted); - Console.WriteLine("Writing to future to complete pending import..."); var writeTask = writer.Write(); - Console.WriteLine("WriteTask IsCompleted: " + writeTask.IsCompleted); await task; - Console.WriteLine("RunnerWorld PendingImport task is completed"); Debug.Assert(!task.IsFaulted && task.IsCompleted); writer.Dispose(); reader.Dispose(); From 984ca740bd0151932d92055453569f2c4c1d38c0 Mon Sep 17 00:00:00 2001 From: yowl Date: Mon, 6 Apr 2026 19:46:44 -0500 Subject: [PATCH 2/3] feedback - remove dead code Make sure writables are not dropped after cancellation. --- crates/csharp/src/AsyncSupport.cs | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/crates/csharp/src/AsyncSupport.cs b/crates/csharp/src/AsyncSupport.cs index 560581e17..34e5cd872 100644 --- a/crates/csharp/src/AsyncSupport.cs +++ b/crates/csharp/src/AsyncSupport.cs @@ -551,11 +551,7 @@ public ComponentTask Read() { readTask.SetResult(buf[0]); } - else if (it.IsCanceled) - { - // readTask.SetCanceled(); - } - else + else if (!it.IsCanceled) { //TODO throw new NotImplementedException("faulted future read not implemented"); @@ -681,6 +677,14 @@ internal unsafe ComponentTask WriteInternal(Func lowerPayload, i if (status.IsBlocked) { var tcs = new ComponentTask(new CancelableWrite(cancelable, Handle)); + tcs.ContinueWith(t => + { + if (t.IsCanceled) + { + canDrop = false; + } + }); + ContextTask* contextTaskPtr = AsyncSupport.ContextGet(); if(contextTaskPtr == null) { From 0b078708944a6ec45e8faef9ab9fa7a4ff32e3de Mon Sep 17 00:00:00 2001 From: yowl Date: Mon, 6 Apr 2026 19:47:17 -0500 Subject: [PATCH 3/3] Apply suggestions from code review Co-authored-by: Joel Dice --- crates/csharp/src/AsyncSupport.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/csharp/src/AsyncSupport.cs b/crates/csharp/src/AsyncSupport.cs index 34e5cd872..46cdbc1bb 100644 --- a/crates/csharp/src/AsyncSupport.cs +++ b/crates/csharp/src/AsyncSupport.cs @@ -433,7 +433,7 @@ protected GCHandle LiftBuffer(T[] buffer) } else { - // TODO: crete buffers for lowered stream types and then lift + // TODO: create buffers for lowered stream types and then lift throw new NotImplementedException("reading from futures types that require lifting"); } }