Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stop relying on UB in 'IContextCallback' dispatch logic #1865

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 51 additions & 26 deletions src/WinRT.Runtime/Interop/IContextCallback.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
// Licensed under the MIT License.

using System;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Threading;
using WinRT;
using WinRT.Interop;

Expand All @@ -20,8 +22,11 @@ internal struct ComCallData
#if NET && CsWinRT_LANG_11_FEATURES
internal unsafe struct CallbackData
{
[ThreadStatic]
public static object PerThreadObject;

public delegate*<object, void> Callback;
public object State;
public object* StatePtr;
}
#endif

Expand All @@ -35,47 +40,67 @@ internal unsafe struct IContextCallbackVftbl

public static void ContextCallback(IntPtr contextCallbackPtr, delegate*<object, void> callback, delegate*<object, void> onFailCallback, object state)
{
ComCallData comCallData;
comCallData.dwDispid = 0;
comCallData.dwReserved = 0;

CallbackData callbackData;
callbackData.Callback = callback;
callbackData.State = state;

// We can just store a pointer to the callback to invoke in the context,
// so we don't need to allocate another closure or anything. The callback
// will be kept alive automatically, because 'comCallData' is address exposed.
// We only do this if we can use C# 11, and if we're on modern .NET, to be safe.
// In the callback below, we can then just retrieve the Action again to invoke it.
comCallData.pUserDefined = (IntPtr)(void*)&callbackData;

// Native method that invokes the callback on the target context. The state object
// is guaranteed to be pinned, so we can access it from a pointer. Note that the
// object will be stored in a static field, and it will not be on the stack of the
// original thread, so it's safe with respect to cross-thread access of managed objects.
// See: https://github.com/dotnet/runtime/blob/main/docs/design/specs/Memory-model.md#cross-thread-access-to-local-variables.
[UnmanagedCallersOnly]
static int InvokeCallback(ComCallData* comCallData)
{
try
{
CallbackData* callbackData = (CallbackData*)comCallData->pUserDefined;

callbackData->Callback(callbackData->State);
callbackData->Callback(*callbackData->StatePtr);

return 0; // S_OK
}
catch (Exception e)
{
return e.HResult;
}
}
}

// Store the state object in the thread static to pass to the callback.
// We don't need a volatile write here, we have a memory barrier below.
CallbackData.PerThreadObject = state;

int hresult;

// We use a thread local static field to efficiently store the state that's used by the callback. Note that this
// is safe with respect to reentrancy, as the target callback will never try to switch back on the original thread.
// We're only ever switching once on the original context, only to release the object reference that is passed as
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Release can run arbitrary code, is that right? I do not see what guarantees that the arbitrary code cannot switch threads at will.

Also, this seems to be used for more than just releasing the object reference:

Context.CallInContext(
_contextCallbackPtr,
_contextToken,
#if NET && CsWinRT_LANG_11_FEATURES
&InitAgileReference,
#else
InitAgileReference,
#endif
null,
this);

It would be more correct by construction to implement this as usual pool - clear the cached instance when it is rented, so that there is no way for one instance to be rented multiple times at the same time.

Copy link

@hamarb123 hamarb123 Nov 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would have assumed that IContextCallback::ContextCallback with ICallbackWithNoReentrancyToApplicationSTA means it shouldn't be able to run any arbitrary code on this thread as a result of that call (I'm not familiar with these APIs though, so I don't actually know), and nothing else from when we set the field to when we clear it seems problematic. If that thing doesn't guarantee this however, then something should be done to allow multiple current things at once potentially existing (e.g., using a List<object> or similar (probably manually via array for resizing purposes) would do the job).

Copy link

@hamarb123 hamarb123 Nov 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or do you mean the callback might call back into this thread & somehow end up back here? That would make sense I suppose. Is that possible @Sergio0694? Actually, I think in this case it would still be fine, as long as the callback is called before anything else silly happens, since it immediately reads the field & thus wouldn't get the new field value (assuming appropriate barrier or whatever - for which a volatile read would certainly be enough, but probably none is "needed"); then it would just be set to null twice, but the meaningful value of the field would already be read & that version would have already been read & had the correct value.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"as long as the callback is called before anything else silly happens, since it immediately reads the field

That's actually a good point. Even if we somehow recursively ended up here (which I don't believe is possible), each callback would've already read the target state before invoking the user-provided callback anyway.

"Release can run arbitrary code"

Not really, I mean yes from the point of view of this API, but the context callback is only ever used internally to release IObjectReference objects, which are implemented in CsWinRT only. And we're only ever using these to pass the Release pointers which simply do a release on the tracker ref and the native object on the target context:

static void Release(object state)
{
ObjectReferenceWithContext<T> @this = Unsafe.As<ObjectReferenceWithContext<T>>(state);
@this.ReleaseFromBase();
}
static void ReleaseWithoutContext(object state)
{
ObjectReferenceWithContext<T> @this = Unsafe.As<ObjectReferenceWithContext<T>>(state);
@this.ReleaseWithoutContext();
}

// state. There is no way for that to possibly switch back on the starting thread. As such, using a thread static
// field to pass the state to the target context (we need to store it somewhere on the managed heap) is fine.
fixed (object* statePtr = &CallbackData.PerThreadObject)
{
CallbackData callbackData;
callbackData.Callback = callback;
callbackData.StatePtr = statePtr;

ComCallData comCallData;
comCallData.dwDispid = 0;
comCallData.dwReserved = 0;
comCallData.pUserDefined = (IntPtr)(void*)&callbackData;

Guid iid = IID.IID_ICallbackWithNoReentrancyToApplicationSTA;
Guid iid = IID.IID_ICallbackWithNoReentrancyToApplicationSTA;

// Add a memory barrier to be extra safe that the target thread will be able to see
// the write we just did on 'PerThreadObject' with the state to pass to the callback.
Thread.MemoryBarrier();

hresult = (*(IContextCallbackVftbl**)contextCallbackPtr)->ContextCallback_4(
contextCallbackPtr,
(IntPtr)(delegate* unmanaged<ComCallData*, int>)&InvokeCallback,
&comCallData,
&iid,
/* iMethod */ 5,
IntPtr.Zero);
}

int hresult = (*(IContextCallbackVftbl**)contextCallbackPtr)->ContextCallback_4(
contextCallbackPtr,
(IntPtr)(delegate* unmanaged<ComCallData*, int>)&InvokeCallback,
&comCallData,
&iid,
/* iMethod */ 5,
IntPtr.Zero);
// Reset the static field to avoid keeping the state alive for longer
Volatile.Write(ref CallbackData.PerThreadObject, null);

if (hresult < 0)
{
Expand Down