Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions src/coreclr/System.Private.CoreLib/src/System/StubHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1891,23 +1891,24 @@ internal static void SafeHandleRelease(SafeHandle pHandle)
private static extern IntPtr GetCOMIPFromRCW(object objSrc, IntPtr pCPCMD, out IntPtr ppTarget);

[LibraryImport(RuntimeHelpers.QCall, EntryPoint = "StubHelpers_GetCOMIPFromRCWSlow")]
private static partial IntPtr GetCOMIPFromRCWSlow(ObjectHandleOnStack objSrc, IntPtr pCPCMD, out IntPtr ppTarget);
private static partial IntPtr GetCOMIPFromRCWSlow(ObjectHandleOnStack objSrc, IntPtr pCPCMD, out IntPtr ppTarget, [MarshalAs(UnmanagedType.Bool)] out bool pfNeedsRelease);

internal static IntPtr GetCOMIPFromRCW(object objSrc, IntPtr pCPCMD, out IntPtr ppTarget, out bool pfNeedsRelease)
{
IntPtr rcw = GetCOMIPFromRCW(objSrc, pCPCMD, out ppTarget);
if (rcw == IntPtr.Zero)
if (rcw != IntPtr.Zero)
{
// If we didn't find the COM interface pointer in the cache we need to release the pointer.
pfNeedsRelease = true;
return GetCOMIPFromRCWWorker(objSrc, pCPCMD, out ppTarget);
pfNeedsRelease = false;
return rcw;
}
pfNeedsRelease = false;
return rcw;

// The slow path may create OLE TLS and then still resolve the interface via the RCW cache.
// Let the slow path tell us whether it returned an owned pointer that requires cleanup.
return GetCOMIPFromRCWWorker(objSrc, pCPCMD, out ppTarget, out pfNeedsRelease);

[MethodImpl(MethodImplOptions.NoInlining)]
static IntPtr GetCOMIPFromRCWWorker(object objSrc, IntPtr pCPCMD, out IntPtr ppTarget)
=> GetCOMIPFromRCWSlow(ObjectHandleOnStack.Create(ref objSrc), pCPCMD, out ppTarget);
static IntPtr GetCOMIPFromRCWWorker(object objSrc, IntPtr pCPCMD, out IntPtr ppTarget, out bool pfNeedsRelease)
=> GetCOMIPFromRCWSlow(ObjectHandleOnStack.Create(ref objSrc), pCPCMD, out ppTarget, out pfNeedsRelease);
}
#endif // FEATURE_COMINTEROP

Expand Down
23 changes: 14 additions & 9 deletions src/coreclr/vm/stubhelpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -312,11 +312,12 @@ FCIMPLEND

#include <optdefault.h>

extern "C" IUnknown* QCALLTYPE StubHelpers_GetCOMIPFromRCWSlow(QCall::ObjectHandleOnStack pSrc, MethodDesc* pMD, void** ppTarget)
extern "C" IUnknown* QCALLTYPE StubHelpers_GetCOMIPFromRCWSlow(QCall::ObjectHandleOnStack pSrc, MethodDesc* pMD, void** ppTarget, BOOL* pfNeedsRelease)
{
QCALL_CONTRACT;
_ASSERTE(pMD != NULL);
_ASSERTE(ppTarget != NULL);
_ASSERTE(pfNeedsRelease != NULL);

IUnknown *pIntf = NULL;
BEGIN_QCALL;
Expand All @@ -326,6 +327,8 @@ extern "C" IUnknown* QCALLTYPE StubHelpers_GetCOMIPFromRCWSlow(QCall::ObjectHand
OBJECTREF objRef = pSrc.Get();
GCPROTECT_BEGIN(objRef);

*pfNeedsRelease = FALSE;

// This snippet exists to enable OLE TLS data creation that isn't possible on the fast path.
// It is practically identical to the StubHelpers::GetCOMIPFromRCW FCALL, but in the event the OLE TLS
// data on this thread hasn't occurred yet, we will create it. Since this is the slow path, trying the
Expand All @@ -335,17 +338,19 @@ extern "C" IUnknown* QCALLTYPE StubHelpers_GetCOMIPFromRCWSlow(QCall::ObjectHand
RCW* pRCW = objRef->PassiveGetSyncBlock()->GetInteropInfoNoCreate()->GetRawRCW();
if (pRCW != NULL)
{
IUnknown* pUnk = GetCOMIPFromRCW_GetTargetFromRCWCache(pOleTlsData, pRCW, pComInfo, ppTarget);
if (pUnk != NULL)
return pUnk;
pIntf = GetCOMIPFromRCW_GetTargetFromRCWCache(pOleTlsData, pRCW, pComInfo, ppTarget);
}

// Still not in the cache and we've ensured the OLE TLS data was created.
SafeComHolder<IUnknown> pRetUnk = ComObject::GetComIPFromRCWThrowing(&objRef, pComInfo->m_pInterfaceMT);
*ppTarget = GetCOMIPFromRCW_GetTarget(pRetUnk, pComInfo);
_ASSERTE(*ppTarget != NULL);
if (pIntf == NULL)
{
// Still not in the cache and we've ensured the OLE TLS data was created.
SafeComHolder<IUnknown> pRetUnk = ComObject::GetComIPFromRCWThrowing(&objRef, pComInfo->m_pInterfaceMT);
*ppTarget = GetCOMIPFromRCW_GetTarget(pRetUnk, pComInfo);
_ASSERTE(*ppTarget != NULL);

pIntf = pRetUnk.Extract();
pIntf = pRetUnk.Extract();
*pfNeedsRelease = TRUE;
}

GCPROTECT_END();

Expand Down
2 changes: 1 addition & 1 deletion src/coreclr/vm/stubhelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ extern "C" void QCALLTYPE StubHelpers_ProfilerEndTransitionCallback(MethodDesc*
#endif

#ifdef FEATURE_COMINTEROP
extern "C" IUnknown* QCALLTYPE StubHelpers_GetCOMIPFromRCWSlow(QCall::ObjectHandleOnStack pSrc, MethodDesc* pMD, void** ppTarget);
extern "C" IUnknown* QCALLTYPE StubHelpers_GetCOMIPFromRCWSlow(QCall::ObjectHandleOnStack pSrc, MethodDesc* pMD, void** ppTarget, BOOL* pfNeedsRelease);

extern "C" void QCALLTYPE ObjectMarshaler_ConvertToNative(QCall::ObjectHandleOnStack pSrcUNSAFE, VARIANT* pDest);
extern "C" void QCALLTYPE ObjectMarshaler_ConvertToManaged(VARIANT* pSrc, QCall::ObjectHandleOnStack retObject);
Expand Down
112 changes: 96 additions & 16 deletions src/tests/Interop/COM/NETClients/Lifetime/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,14 @@ namespace NetClient
public unsafe class Program
{
static delegate* unmanaged<int> GetAllocationCount;
static ITrackMyLifetimeTesting? s_agileInstance;
static Exception? s_callbackException;

[DllImport("COMNativeServer", EntryPoint = "InvokeCallbackOnNativeThread")]
private static extern int InvokeCallbackOnNativeThread(delegate* unmanaged<void> callback);

// Initialize for all tests
[MethodImpl(MethodImplOptions.NoInlining)]
static void Initialize()
{
var inst = new TrackMyLifetimeTesting();
Expand All @@ -45,6 +51,19 @@ static void ForceGC()
}
}

[UnmanagedCallersOnly]
static void InvokeObjectFromNativeThread()
{
try
{
s_agileInstance!.Method();
}
catch (Exception e)
{
s_callbackException = e;
}
}

static void Validate_COMServer_CleanUp()
{
Console.WriteLine($"Calling {nameof(Validate_COMServer_CleanUp)}...");
Expand Down Expand Up @@ -85,36 +104,57 @@ static void Validate_COMServer_DisableEagerCleanUp()
Assert.False(Marshal.AreComObjectsAvailableForCleanup());
}

[Fact]
public static int TestEntryPoint()
static void Validate_COMServer_CallOnNativeThread()
{
// RegFree COM and STA apartments are not supported on Windows Nano
if (Utilities.IsWindowsNanoServer)
Console.WriteLine($"Calling {nameof(Validate_COMServer_CallOnNativeThread)}...");

// Need agile instance since the object will be used on a different thread
// than the creating thread and we're on an STA thread.
s_agileInstance = CreateAgileInstance();
try
{
s_agileInstance.Method();

// Create a fresh native thread for each callback so the COM call runs before that thread
// has initialized the CLR's OLE TLS state.
for (int i = 0; i < 10; i++)
{
s_callbackException = null;

Marshal.ThrowExceptionForHR(InvokeCallbackOnNativeThread(&InvokeObjectFromNativeThread));

Assert.True(s_callbackException is null, s_callbackException?.ToString());
}
}
finally
{
return 100;
s_agileInstance = null;
}

int result = 101;
[MethodImpl(MethodImplOptions.NoInlining)]
static ITrackMyLifetimeTesting CreateAgileInstance()
=> new TrackMyLifetimeTesting().CreateAgileInstance();
}

const int TestFailed = 101;
const int TestPassed = 100;

static int RunOnSTAThread(Action action)
{
int result = TestFailed;

// Run the test on a new STA thread since Nano Server doesn't support the STA
// and as a result, the main application thread can't be made STA with the STAThread attribute
Thread staThread = new Thread(() =>
{
try
{
// Initialization for all future tests
Initialize();
Assert.True(GetAllocationCount != null);

Validate_COMServer_CleanUp();
Validate_COMServer_DisableEagerCleanUp();
action();
}
catch (Exception e)
{
Console.WriteLine($"Test Failure: {e}");
result = 101;
result = TestFailed;
}
result = 100;
result = TestPassed;
});

staThread.SetApartmentState(ApartmentState.STA);
Expand All @@ -123,5 +163,45 @@ public static int TestEntryPoint()

return result;
}

[Fact]
public static int TestEntryPoint()
{
// RegFree COM and STA apartments are not supported on Windows Nano
if (Utilities.IsWindowsNanoServer)
{
return TestPassed;
}

// Run the test on a new STA thread since Nano Server doesn't support the STA
// and as a result, the main application thread can't be made STA with the STAThread attribute
int result = RunOnSTAThread(() =>
{
// Initialization for all future tests
Initialize();
ForceGC();
Assert.True(GetAllocationCount != null);

Validate_COMServer_CleanUp();
Validate_COMServer_CallOnNativeThread();
});
if (result != TestPassed)
{
return result;
}

return RunOnSTAThread(() =>
{
// Initialization for all future tests
Initialize();
ForceGC();
Assert.True(GetAllocationCount != null);

// Manipulating the eager cleanup state cannot be changed once set,
// so we need to run this test on a separate thread after the first
// test validates that cleanup is working as expected with eager cleanup enabled.
Validate_COMServer_DisableEagerCleanUp();
});
}
}
}
3 changes: 2 additions & 1 deletion src/tests/Interop/COM/NativeServer/Exports.def
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
EXPORTS
DllGetClassObject PRIVATE
DllRegisterServer PRIVATE
DllUnregisterServer PRIVATE
DllUnregisterServer PRIVATE
InvokeCallbackOnNativeThread
15 changes: 15 additions & 0 deletions src/tests/Interop/COM/NativeServer/Servers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "stdafx.h"
#include "Servers.h"
#include <thread>

namespace
{
Expand Down Expand Up @@ -155,6 +156,20 @@ namespace
}
}

extern "C" HRESULT STDMETHODCALLTYPE InvokeCallbackOnNativeThread(void (STDMETHODCALLTYPE* callback)())
{
if (callback == nullptr)
return E_POINTER;

std::thread worker([callback]()
{
callback();
});

worker.join();
return S_OK;
}

STDAPI DllRegisterServer(void)
{
HRESULT hr;
Expand Down
32 changes: 31 additions & 1 deletion src/tests/Interop/COM/NativeServer/TrackMyLifetimeTesting.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

#include "Servers.h"

class TrackMyLifetimeTesting : public UnknownImpl, public ITrackMyLifetimeTesting
class TrackMyLifetimeTesting : public UnknownImpl, public ITrackMyLifetimeTesting, public IAgileObject
{
static std::atomic<uint32_t> _instanceCount;

Expand All @@ -14,8 +14,15 @@ class TrackMyLifetimeTesting : public UnknownImpl, public ITrackMyLifetimeTestin
return _instanceCount;
}

private:
const bool _isAgileInstance = false;

public:
TrackMyLifetimeTesting()
: TrackMyLifetimeTesting(false)
{ }
TrackMyLifetimeTesting(bool isAgileInstance)
: _isAgileInstance(isAgileInstance)
{
_instanceCount++;
}
Expand All @@ -34,11 +41,34 @@ class TrackMyLifetimeTesting : public UnknownImpl, public ITrackMyLifetimeTestin
return S_OK;
}

DEF_FUNC(CreateAgileInstance)(ITrackMyLifetimeTesting** agileInstance)
{
if (agileInstance == nullptr)
return E_POINTER;

*agileInstance = new TrackMyLifetimeTesting(/*isAgileInstance*/ true);
return S_OK;
}

DEF_FUNC(Method)()
{
return S_OK;
}

public: // IUnknown
STDMETHOD(QueryInterface)(
/* [in] */ REFIID riid,
/* [iid_is][out] */ _COM_Outptr_ void __RPC_FAR *__RPC_FAR *ppvObject)
{
if (_isAgileInstance)
{
if (riid == __uuidof(IAgileObject))
{
*ppvObject = static_cast<IAgileObject*>(this);
AddRef();
return S_OK;
}
}
return DoQueryInterface(riid, ppvObject, static_cast<ITrackMyLifetimeTesting *>(this));
}

Expand Down
2 changes: 2 additions & 0 deletions src/tests/Interop/COM/ServerContracts/Server.Contracts.cs
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,8 @@ internal interface IInspectableTesting2
internal interface ITrackMyLifetimeTesting
{
IntPtr GetAllocationCountCallback();
ITrackMyLifetimeTesting CreateAgileInstance();
void Method();
}
}

Expand Down
2 changes: 2 additions & 0 deletions src/tests/Interop/COM/ServerContracts/Server.Contracts.h
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,8 @@ struct __declspec(uuid("57f396a1-58a0-425f-8807-9f938a534984"))
ITrackMyLifetimeTesting : IUnknown
{
virtual HRESULT STDMETHODCALLTYPE GetAllocationCountCallback(_Outptr_ void** fptr) = 0;
virtual HRESULT STDMETHODCALLTYPE CreateAgileInstance(ITrackMyLifetimeTesting** agileInstance) = 0;
virtual HRESULT STDMETHODCALLTYPE Method() = 0;
};

// IIDs for the below types are generated by the runtime.
Expand Down
Loading