diff --git a/src/WinRT.Runtime/Interop/IContextCallback.cs b/src/WinRT.Runtime/Interop/IContextCallback.cs index 9a87c29fa..3ca4c8de5 100644 --- a/src/WinRT.Runtime/Interop/IContextCallback.cs +++ b/src/WinRT.Runtime/Interop/IContextCallback.cs @@ -2,7 +2,9 @@ // Licensed under the MIT License. using System; +using System.Runtime.CompilerServices; using System.Runtime.InteropServices; +using System.Threading; using WinRT; using WinRT.Interop; @@ -20,8 +22,11 @@ internal struct ComCallData #if NET && CsWinRT_LANG_11_FEATURES internal unsafe struct CallbackData { + [ThreadStatic] + public static object PerThreadObject; + public delegate* Callback; - public object State; + public object* StatePtr; } #endif @@ -35,21 +40,11 @@ internal unsafe struct IContextCallbackVftbl public static void ContextCallback(IntPtr contextCallbackPtr, delegate* callback, delegate* onFailCallback, object state) { - ComCallData comCallData; - comCallData.dwDispid = 0; - comCallData.dwReserved = 0; - - CallbackData callbackData; - callbackData.Callback = callback; - callbackData.State = state; - - // We can just store a pointer to the callback to invoke in the context, - // so we don't need to allocate another closure or anything. The callback - // will be kept alive automatically, because 'comCallData' is address exposed. - // We only do this if we can use C# 11, and if we're on modern .NET, to be safe. - // In the callback below, we can then just retrieve the Action again to invoke it. - comCallData.pUserDefined = (IntPtr)(void*)&callbackData; - + // Native method that invokes the callback on the target context. The state object + // is guaranteed to be pinned, so we can access it from a pointer. Note that the + // object will be stored in a static field, and it will not be on the stack of the + // original thread, so it's safe with respect to cross-thread access of managed objects. + // See: https://github.com/dotnet/runtime/blob/main/docs/design/specs/Memory-model.md#cross-thread-access-to-local-variables. [UnmanagedCallersOnly] static int InvokeCallback(ComCallData* comCallData) { @@ -57,7 +52,7 @@ static int InvokeCallback(ComCallData* comCallData) { CallbackData* callbackData = (CallbackData*)comCallData->pUserDefined; - callbackData->Callback(callbackData->State); + callbackData->Callback(*callbackData->StatePtr); return 0; // S_OK } @@ -65,17 +60,47 @@ static int InvokeCallback(ComCallData* comCallData) { return e.HResult; } - } + } + + // Store the state object in the thread static to pass to the callback. + // We don't need a volatile write here, we have a memory barrier below. + CallbackData.PerThreadObject = state; + + int hresult; + + // We use a thread local static field to efficiently store the state that's used by the callback. Note that this + // is safe with respect to reentrancy, as the target callback will never try to switch back on the original thread. + // We're only ever switching once on the original context, only to release the object reference that is passed as + // state. There is no way for that to possibly switch back on the starting thread. As such, using a thread static + // field to pass the state to the target context (we need to store it somewhere on the managed heap) is fine. + fixed (object* statePtr = &CallbackData.PerThreadObject) + { + CallbackData callbackData; + callbackData.Callback = callback; + callbackData.StatePtr = statePtr; + + ComCallData comCallData; + comCallData.dwDispid = 0; + comCallData.dwReserved = 0; + comCallData.pUserDefined = (IntPtr)(void*)&callbackData; - Guid iid = IID.IID_ICallbackWithNoReentrancyToApplicationSTA; + Guid iid = IID.IID_ICallbackWithNoReentrancyToApplicationSTA; + + // Add a memory barrier to be extra safe that the target thread will be able to see + // the write we just did on 'PerThreadObject' with the state to pass to the callback. + Thread.MemoryBarrier(); + + hresult = (*(IContextCallbackVftbl**)contextCallbackPtr)->ContextCallback_4( + contextCallbackPtr, + (IntPtr)(delegate* unmanaged)&InvokeCallback, + &comCallData, + &iid, + /* iMethod */ 5, + IntPtr.Zero); + } - int hresult = (*(IContextCallbackVftbl**)contextCallbackPtr)->ContextCallback_4( - contextCallbackPtr, - (IntPtr)(delegate* unmanaged)&InvokeCallback, - &comCallData, - &iid, - /* iMethod */ 5, - IntPtr.Zero); + // Reset the static field to avoid keeping the state alive for longer + Volatile.Write(ref CallbackData.PerThreadObject, null); if (hresult < 0) {