Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dynamic casting to COM implementation #3055

Merged
merged 1 commit into from
May 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions crates/libs/core/src/com_object.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::imp::Box;
use crate::{AsImpl, IUnknown, IUnknownImpl, Interface, InterfaceRef};
use core::any::Any;
use core::borrow::Borrow;
use core::ops::Deref;
use core::ptr::NonNull;
Expand Down Expand Up @@ -196,6 +197,28 @@ impl<T: ComObjectInner> ComObject<T> {
I::from_raw(raw)
}
}

/// This casts the given COM interface to [`&dyn Any`]. It returns a reference to the "outer"
/// object, e.g. `MyApp_Impl`, not the inner `MyApp` object.
///
/// `T` must be a type that has been annotated with `#[implement]`; this is checked at
/// compile-time by the generic constraints of this method. However, note that the
/// returned `&dyn Any` refers to the _outer_ implementation object that was generated by
/// `#[implement]`, i.e. the `MyApp_Impl` type, not the inner `MyApp` type.
///
/// If the given object is not a Rust object, or is a Rust object but not `T`, or is a Rust
/// object that contains non-static lifetimes, then this function will return `Err(E_NOINTERFACE)`.
///
/// The returned value is an owned (counted) reference; this function calls `AddRef` on the
/// underlying COM object. If you do not need an owned reference, then you can use the
/// [`Interface::cast_object_ref`] method instead, and avoid the cost of `AddRef` / `Release`.
pub fn cast_from<I>(interface: &I) -> crate::Result<Self>
where
I: Interface,
T::Outer: Any + 'static + IUnknownImpl<Impl = T>,
{
interface.cast_object()
}
}

impl<T: ComObjectInner + Default> Default for ComObject<T> {
Expand Down
128 changes: 126 additions & 2 deletions crates/libs/core/src/interface.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use super::*;
use core::any::Any;
use core::ffi::c_void;
use core::marker::PhantomData;
use core::mem::{forget, transmute_copy};
use core::mem::{forget, transmute_copy, MaybeUninit};
use core::ptr::NonNull;

/// Provides low-level access to an interface vtable.
Expand Down Expand Up @@ -97,7 +98,7 @@ pub unsafe trait Interface: Sized + Clone {
//
// This guards against implementations of COM interfaces which may store non-null values
// in 'result' but still return E_NOINTERFACE.
let mut result = core::mem::MaybeUninit::<Option<T>>::zeroed();
let mut result = MaybeUninit::<Option<T>>::zeroed();
self.query(&T::IID, result.as_mut_ptr() as _).ok()?;

// If we get here, then query() has succeeded, but we still need to double-check
Expand All @@ -110,6 +111,123 @@ pub unsafe trait Interface: Sized + Clone {
}
}

/// This casts the given COM interface to [`&dyn Any`].
///
/// Applications should generally _not_ call this method directly. Instead, use the
/// [`Interface::cast_object_ref`] or [`Interface::cast_object`] methods.
///
/// `T` must be a type that has been annotated with `#[implement]`; this is checked at
/// compile-time by the generic constraints of this method. However, note that the
/// returned `&dyn Any` refers to the _outer_ implementation object that was generated by
/// `#[implement]`, i.e. the `MyApp_Impl` type, not the inner `MyApp` type.
///
/// If the given object is not a Rust object, or is a Rust object but not `T`, or is a Rust
/// object that contains non-static lifetimes, then this function will return `Err(E_NOINTERFACE)`.
///
/// # Safety
///
/// **IMPORTANT!!** This uses a non-standard protocol for QueryInterface! The `DYNAMIC_CAST_IID`
/// IID identifies this protocol, but there is no `IDynamicCast` interface. Instead, objects
/// that recognize `DYNAMIC_CAST_IID` simply store their `&dyn Any` directly at the interface
/// pointer that was passed to `QueryInterface. This means that the returned value has a
/// size that is twice as large (`size_of::<&dyn Any>() == 2 * size_of::<*const c_void>()`).
///
/// This means that callers that use this protocol cannot simply pass `&mut ptr` for
/// an ordinary single-pointer-sized pointer. Only this method understands this protocol.
///
/// Another part of this protocol is that the implementation of `QueryInterface` _does not_
/// AddRef the object. The caller must guarantee the liveness of the COM object. In Rust,
/// this means tying the lifetime of the IUnknown* that we used for the QueryInterface
/// call to the lifetime of the returned `&dyn Any` value.
///
/// This method preserves type safety and relies on these invariants:
///
/// * All `QueryInterface` implementations that recognize `DYNAMIC_CAST_IID` are generated by
/// the `#[implement]` macro and respect the rules described here.
#[inline(always)]
fn cast_to_any<T>(&self) -> Result<&dyn Any>
where
T: ComObjectInner,
T::Outer: Any + 'static + IUnknownImpl<Impl = T>,
{
unsafe {
let mut any_ref_arg: MaybeUninit<&dyn Any> = MaybeUninit::zeroed();
self.query(&DYNAMIC_CAST_IID, any_ref_arg.as_mut_ptr() as *mut *mut c_void).ok()?;
Ok(any_ref_arg.assume_init())
}
}

/// Returns `true` if the given COM interface refers to an implementation of `T`.
///
/// `T` must be a type that has been annotated with `#[implement]`; this is checked at
/// compile-time by the generic constraints of this method.
///
/// If the given object is not a Rust object, or is a Rust object but not `T`, or is a Rust
/// object that contains non-static lifetimes, then this function will return `false`.
#[inline(always)]
fn is_object<T>(&self) -> bool
where
T: ComObjectInner,
T::Outer: Any + 'static + IUnknownImpl<Impl = T>,
{
if let Ok(any) = self.cast_to_any::<T>() {
any.is::<T::Outer>()
} else {
false
}
}

/// This casts the given COM interface to [`&dyn Any`]. It returns a reference to the "outer"
/// object, e.g. `&MyApp_Impl`, not the inner `&MyApp` object.
///
/// `T` must be a type that has been annotated with `#[implement]`; this is checked at
/// compile-time by the generic constraints of this method. However, note that the
/// returned `&dyn Any` refers to the _outer_ implementation object that was generated by
/// `#[implement]`, i.e. the `MyApp_Impl` type, not the inner `MyApp` type.
///
/// If the given object is not a Rust object, or is a Rust object but not `T`, or is a Rust
/// object that contains non-static lifetimes, then this function will return `Err(E_NOINTERFACE)`.
///
/// The returned value is borrowed. If you need an owned (counted) reference, then use
/// [`Interface::cast_object`].
#[inline(always)]
fn cast_object_ref<T>(&self) -> Result<&T::Outer>
where
T: ComObjectInner,
T::Outer: Any + 'static + IUnknownImpl<Impl = T>,
{
let any: &dyn Any = self.cast_to_any::<T>()?;
if let Some(outer) = any.downcast_ref::<T::Outer>() {
Ok(outer)
} else {
Err(imp::E_NOINTERFACE.into())
}
}

/// This casts the given COM interface to [`&dyn Any`]. It returns a reference to the "outer"
/// object, e.g. `MyApp_Impl`, not the inner `MyApp` object.
///
/// `T` must be a type that has been annotated with `#[implement]`; this is checked at
/// compile-time by the generic constraints of this method. However, note that the
/// returned `&dyn Any` refers to the _outer_ implementation object that was generated by
/// `#[implement]`, i.e. the `MyApp_Impl` type, not the inner `MyApp` type.
///
/// If the given object is not a Rust object, or is a Rust object but not `T`, or is a Rust
/// object that contains non-static lifetimes, then this function will return `Err(E_NOINTERFACE)`.
///
/// The returned value is an owned (counted) reference; this function calls `AddRef` on the
/// underlying COM object. If you do not need an owned reference, then you can use the
/// [`Interface::cast_object_ref`] method instead, and avoid the cost of `AddRef` / `Release`.
#[inline(always)]
fn cast_object<T>(&self) -> Result<ComObject<T>>
where
T: ComObjectInner,
T::Outer: Any + 'static + IUnknownImpl<Impl = T>,
{
let object_ref = self.cast_object_ref::<T>()?;
Ok(object_ref.to_object())
}

/// Attempts to create a [`Weak`] reference to this object.
fn downgrade(&self) -> Result<Weak<Self>> {
self.cast::<imp::IWeakReferenceSource>().and_then(|source| Weak::downgrade(&source))
Expand Down Expand Up @@ -210,3 +328,9 @@ impl<'a, I: Interface> core::ops::Deref for InterfaceRef<'a, I> {
unsafe { core::mem::transmute(self) }
}
}

/// This IID identifies a special protocol, used by [`Interface::cast_to_any`]. This is _not_
/// an ordinary COM interface; it uses special lifetime rules and a larger interface pointer.
/// See the comments on [`Interface::cast_to_any`].
#[doc(hidden)]
pub const DYNAMIC_CAST_IID: GUID = GUID::from_u128(0xae49d5cb_143f_431c_874c_2729336e4eca);
15 changes: 15 additions & 0 deletions crates/libs/core/src/unknown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,21 @@ pub trait IUnknownImpl {
{
<Self as ComObjectInterface<I>>::as_interface_ref(self).to_owned()
}

/// Creates a new owned reference to this object.
///
/// # Safety
///
/// This function can only be safely called by `<Foo>_Impl` objects that are embedded in a
/// `ComObject`. Since we only allow safe Rust code to access these objects using a `ComObject`
/// or a `&<Foo>_Impl` that points within a `ComObject`, this is safe.
fn to_object(&self) -> ComObject<Self::Impl>
where
Self::Impl: ComObjectInner<Outer = Self>;

/// The distance from the start of `<Foo>_Impl` to the `this` field within it, measured in
/// pointer-sized elements. The `this` field contains the `MyApp` instance.
const INNER_OFFSET_IN_POINTERS: usize;
}

impl IUnknown_Vtbl {
Expand Down
41 changes: 37 additions & 4 deletions crates/libs/implement/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ pub fn implement(attributes: proc_macro::TokenStream, original_type: proc_macro:
let original_type2 = original_type.clone();
let original_type2 = syn::parse_macro_input!(original_type2 as syn::ItemStruct);
let vis = &original_type2.vis;
let original_ident = original_type2.ident;
let original_ident = &original_type2.ident;
let mut constraints = quote! {};

if let Some(where_clause) = original_type2.generics.where_clause {
if let Some(where_clause) = &original_type2.generics.where_clause {
where_clause.predicates.to_tokens(&mut constraints);
}

Expand Down Expand Up @@ -83,6 +83,25 @@ pub fn implement(attributes: proc_macro::TokenStream, original_type: proc_macro:
}
});

// Dynamic casting requires that the object not contain non-static lifetimes.
let enable_dyn_casting = original_type2.generics.lifetimes().count() == 0;
sivadeilra marked this conversation as resolved.
Show resolved Hide resolved
let dynamic_cast_query = if enable_dyn_casting {
quote! {
else if *iid == ::windows_core::DYNAMIC_CAST_IID {
// DYNAMIC_CAST_IID is special. We _do not_ increase the reference count for this pseudo-interface.
// Also, instead of returning an interface pointer, we simply write the `&dyn Any` directly to the
// 'interface' pointer. Since the size of `&dyn Any` is 2 pointers, not one, the caller must be
// prepared for this. This is not a normal QueryInterface call.
//
// See the `Interface::cast_to_any` method, which is the only caller that should use DYNAMIC_CAST_ID.
(interface as *mut *const dyn core::any::Any).write(self as &dyn ::core::any::Any as *const dyn ::core::any::Any);
return ::windows_core::HRESULT(0);
}
}
} else {
quote!()
};

// The distance from the beginning of the generated type to the 'this' field, in units of pointers (not bytes).
let offset_of_this_in_pointers = 1 + attributes.implement.len();
let offset_of_this_in_pointers_token = proc_macro2::Literal::usize_unsuffixed(offset_of_this_in_pointers);
Expand Down Expand Up @@ -201,7 +220,10 @@ pub fn implement(attributes: proc_macro::TokenStream, original_type: proc_macro:
|| iid == &<::windows_core::IInspectable as ::windows_core::Interface>::IID
|| iid == &<::windows_core::imp::IAgileObject as ::windows_core::Interface>::IID {
&self.identity as *const _ as *mut _
} #(#queries)* else {
}
#(#queries)*
#dynamic_cast_query
else {
::core::ptr::null_mut()
};

Expand Down Expand Up @@ -230,7 +252,7 @@ pub fn implement(attributes: proc_macro::TokenStream, original_type: proc_macro:
unsafe fn Release(self_: *mut Self) -> u32 {
let remaining = (*self_).count.release();
if remaining == 0 {
_ = ::windows_core::imp::Box::from_raw(self_ as *const Self as *mut Self);
_ = ::windows_core::imp::Box::from_raw(self_);
}
remaining
}
Expand All @@ -247,6 +269,17 @@ pub fn implement(attributes: proc_macro::TokenStream, original_type: proc_macro:
&*((inner as *const Self::Impl as *const *const ::core::ffi::c_void)
.sub(#offset_of_this_in_pointers_token) as *const Self)
}

fn to_object(&self) -> ::windows_core::ComObject<Self::Impl> {
self.count.add_ref();
unsafe {
::windows_core::ComObject::from_raw(
::core::ptr::NonNull::new_unchecked(self as *const Self as *mut Self)
)
}
}

const INNER_OFFSET_IN_POINTERS: usize = #offset_of_this_in_pointers_token;
}

impl #generics #original_ident::#generics where #constraints {
Expand Down
45 changes: 44 additions & 1 deletion crates/tests/implement_core/src/com_object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::borrow::Borrow;
use std::sync::atomic::{AtomicBool, Ordering::SeqCst};
use std::sync::Arc;
use windows_core::{
implement, interface, ComObject, IUnknown, IUnknownImpl, IUnknown_Vtbl, InterfaceRef,
implement, interface, ComObject, IUnknown, IUnknownImpl, IUnknown_Vtbl, Interface, InterfaceRef,
};

#[interface("818f2fd1-d479-4398-b286-a93c4c7904d1")]
Expand All @@ -19,8 +19,12 @@ unsafe trait IBar: IUnknown {
fn say_hello(&self);
}

const APP_SIGNATURE: [u8; 8] = *b"cafef00d";

#[implement(IFoo, IBar)]
struct MyApp {
// We use signature to verify field offsets for dynamic casts
signature: [u8; 8],
x: u32,
tombstone: Arc<Tombstone>,
}
Expand Down Expand Up @@ -63,6 +67,7 @@ impl core::fmt::Display for MyApp {
impl Default for MyApp {
fn default() -> Self {
Self {
signature: APP_SIGNATURE,
x: 0,
tombstone: Arc::new(Tombstone::default()),
}
Expand Down Expand Up @@ -109,6 +114,7 @@ impl MyApp {
fn new(x: u32) -> ComObject<Self> {
ComObject::new(Self {
x,
signature: APP_SIGNATURE,
tombstone: Arc::new(Tombstone::default()),
})
}
Expand Down Expand Up @@ -333,6 +339,43 @@ fn from_inner_ref() {
unsafe { ibar.say_hello() };
}

#[test]
fn to_object() {
let app = MyApp::new(42);
let tombstone = app.tombstone.clone();
let app_outer: &MyApp_Impl = &app;

let second_app = app_outer.to_object();
assert!(!tombstone.is_dead());
assert_eq!(second_app.signature, APP_SIGNATURE);

println!("x = {}", unsafe { second_app.get_x() });

drop(second_app);
assert!(!tombstone.is_dead());

drop(app);
assert!(tombstone.is_dead());
}

#[test]
fn dynamic_cast() {
let app = MyApp::new(42);
let unknown = app.to_interface::<IUnknown>();

assert!(!unknown.is_object::<SendableThing>());
assert!(unknown.is_object::<MyApp>());

let dyn_app_ref: &MyApp_Impl = unknown.cast_object_ref::<MyApp>().unwrap();
assert_eq!(dyn_app_ref.signature, APP_SIGNATURE);

let dyn_app_owned: ComObject<MyApp> = unknown.cast_object().unwrap();
assert_eq!(dyn_app_owned.signature, APP_SIGNATURE);

let dyn_app_owned_2: ComObject<MyApp> = ComObject::cast_from(&unknown).unwrap();
assert_eq!(dyn_app_owned_2.signature, APP_SIGNATURE);
}

// This tests that we can place a type that is not Send in a ComObject.
// Compilation is sufficient to test.
#[implement(IBar)]
Expand Down