Skip to content

Commit

Permalink
Add vtable entry for covariant IEnumerable<object> for WinRT types to…
Browse files Browse the repository at this point in the history
…o. (#601)

* Add vtable entry for IEnumerable<object> for winrt types too.

* Adding more complete covariance support.

* PR feedback.

* PR feedback.
  • Loading branch information
manodasanW authored Nov 24, 2020
1 parent cca8f40 commit d9f1d21
Show file tree
Hide file tree
Showing 6 changed files with 199 additions and 7 deletions.
28 changes: 28 additions & 0 deletions TestComponentCSharp/Class.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,34 @@ namespace winrt::TestComponentCSharp::implementation
{
_objectIterableChanged.remove(token);
}
IIterable<IIterable<WF::Point>> Class::IterableOfPointIterablesProperty()
{
return _pointIterableIterable;
}
void Class::IterableOfPointIterablesProperty(IIterable<IIterable<WF::Point>> const& value)
{
for (auto points : value)
{
for (auto point : points)
{
}
}
_pointIterableIterable = value;
}
IIterable<IIterable<WF::IInspectable>> Class::IterableOfObjectIterablesProperty()
{
return _objectIterableIterable;
}
void Class::IterableOfObjectIterablesProperty(IIterable<IIterable<WF::IInspectable>> const& value)
{
for (auto objects : value)
{
for (auto object : objects)
{
}
}
_objectIterableIterable = value;
}
Uri Class::UriProperty()
{
return _uri;
Expand Down
6 changes: 6 additions & 0 deletions TestComponentCSharp/Class.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ namespace winrt::TestComponentCSharp::implementation
Windows::Foundation::IInspectable _object;
winrt::event<Windows::Foundation::EventHandler<Windows::Foundation::IInspectable>> _objectChanged;
Windows::Foundation::Collections::IIterable<Windows::Foundation::IInspectable> _objectIterable;
Windows::Foundation::Collections::IIterable<Windows::Foundation::Collections::IIterable<Windows::Foundation::Point>> _pointIterableIterable;
Windows::Foundation::Collections::IIterable<Windows::Foundation::Collections::IIterable<Windows::Foundation::IInspectable>> _objectIterableIterable;
winrt::event<Windows::Foundation::EventHandler<Windows::Foundation::Collections::IIterable<Windows::Foundation::IInspectable>>> _objectIterableChanged;
Windows::Foundation::Uri _uri;
winrt::event<Windows::Foundation::EventHandler<Windows::Foundation::Uri>> _uriChanged;
Expand Down Expand Up @@ -153,6 +155,10 @@ namespace winrt::TestComponentCSharp::implementation
void CallForObjectIterable(TestComponentCSharp::ProvideObjectIterable const& provideObjectIterable);
winrt::event_token ObjectIterablePropertyChanged(Windows::Foundation::EventHandler<Windows::Foundation::Collections::IIterable<Windows::Foundation::IInspectable>> const& handler);
void ObjectIterablePropertyChanged(winrt::event_token const& token) noexcept;
Windows::Foundation::Collections::IIterable<Windows::Foundation::Collections::IIterable<Windows::Foundation::Point>> IterableOfPointIterablesProperty();
void IterableOfPointIterablesProperty(Windows::Foundation::Collections::IIterable<Windows::Foundation::Collections::IIterable<Windows::Foundation::Point>> const& value);
Windows::Foundation::Collections::IIterable<Windows::Foundation::Collections::IIterable<Windows::Foundation::IInspectable>> IterableOfObjectIterablesProperty();
void IterableOfObjectIterablesProperty(Windows::Foundation::Collections::IIterable<Windows::Foundation::Collections::IIterable<Windows::Foundation::IInspectable>> const& value);
Windows::Foundation::Uri UriProperty();
void UriProperty(Windows::Foundation::Uri const& value);
void RaiseUriChanged();
Expand Down
2 changes: 2 additions & 0 deletions TestComponentCSharp/TestComponentCSharp.idl
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ namespace TestComponentCSharp
void RaiseObjectIterableChanged();
void CallForObjectIterable(ProvideObjectIterable provideObjectIterable);
event Windows.Foundation.EventHandler<Windows.Foundation.Collections.IIterable<Object> > ObjectIterablePropertyChanged;
Windows.Foundation.Collections.IIterable<Windows.Foundation.Collections.IIterable<Windows.Foundation.Point> > IterableOfPointIterablesProperty;
Windows.Foundation.Collections.IIterable<Windows.Foundation.Collections.IIterable<Object> > IterableOfObjectIterablesProperty;

Windows.Foundation.Uri UriProperty;
void RaiseUriChanged();
Expand Down
41 changes: 40 additions & 1 deletion UnitTest/TestComponentCSharp_Tests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,26 @@ public void TestObjectCasting()

var objects = new List<ManagedType>() { new ManagedType(), new ManagedType() };
var query = from item in objects select item;
TestObject.ObjectIterableProperty = query;
TestObject.ObjectIterableProperty = query;

TestObject.ObjectProperty = "test";
Assert.Equal("test", TestObject.ObjectProperty);

var objectArray = new ManagedType[] { new ManagedType(), new ManagedType() };
TestObject.ObjectIterableProperty = objectArray;
Assert.True(TestObject.ObjectIterableProperty.SequenceEqual(objectArray));

var strArray = new string[] { "str1", "str2", "str3" };
TestObject.ObjectIterableProperty = strArray;
Assert.True(TestObject.ObjectIterableProperty.SequenceEqual(strArray));

var uriArray = new Uri[] { new Uri("http://aka.ms/cswinrt"), new Uri("http://github.com") };
TestObject.ObjectIterableProperty = uriArray;
Assert.True(TestObject.ObjectIterableProperty.SequenceEqual(uriArray));

var objectUriArray = new object[] { new Uri("http://github.com") };
TestObject.ObjectIterableProperty = objectUriArray;
Assert.True(TestObject.ObjectIterableProperty.SequenceEqual(objectUriArray));
}

[Fact]
Expand Down Expand Up @@ -2250,6 +2269,26 @@ public void TestIBindableVector()
{
CustomBindableVectorTest vector = new CustomBindableVectorTest();
Assert.NotNull(vector);
}

[Fact]
public void TestCovariance()
{
var listOfListOfPoints = new List<List<Point>>() {
new List<Point>{ new Point(1, 1), new Point(1, 2), new Point(1, 3) },
new List<Point>{ new Point(2, 1), new Point(2, 2), new Point(2, 3) },
new List<Point>{ new Point(3, 1), new Point(3, 2), new Point(3, 3) }
};
TestObject.IterableOfPointIterablesProperty = listOfListOfPoints;
Assert.True(TestObject.IterableOfPointIterablesProperty.SequenceEqual(listOfListOfPoints));

var listOfListOfUris = new List<List<Uri>>() {
new List<Uri>{ new Uri("http://aka.ms/cswinrt"), new Uri("http://github.com") },
new List<Uri>{ new Uri("http://aka.ms/cswinrt") },
new List<Uri>{ new Uri("http://aka.ms/cswinrt"), new Uri("http://microsoft.com") }
};
TestObject.IterableOfObjectIterablesProperty = listOfListOfUris;
Assert.True(TestObject.IterableOfObjectIterablesProperty.SequenceEqual(listOfListOfUris));
}
}
}
15 changes: 9 additions & 6 deletions WinRT.Runtime/ComWrappersSupport.cs
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,17 @@ internal static List<ComInterfaceEntry> GetInterfaceTableEntries(object obj)
}

if (iface.IsConstructedGenericType
&& Projections.TryGetCompatibleWindowsRuntimeTypeForVariantType(iface, out var compatibleIface))
&& Projections.TryGetCompatibleWindowsRuntimeTypesForVariantType(iface, out var compatibleIfaces))
{
var compatibleIfaceAbiType = compatibleIface.FindHelperType();
entries.Add(new ComInterfaceEntry
foreach (var compatibleIface in compatibleIfaces)
{
IID = GuidGenerator.GetIID(compatibleIfaceAbiType),
Vtable = (IntPtr)compatibleIfaceAbiType.GetAbiToProjectionVftblPtr()
});
var compatibleIfaceAbiType = compatibleIface.FindHelperType();
entries.Add(new ComInterfaceEntry
{
IID = GuidGenerator.GetIID(compatibleIfaceAbiType),
Vtable = (IntPtr)compatibleIfaceAbiType.GetAbiToProjectionVftblPtr()
});
}
}
}

Expand Down
114 changes: 114 additions & 0 deletions WinRT.Runtime/Projections.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System.Collections.Generic;
using System.Collections.Specialized;
using System.ComponentModel;
using System.Linq;
using System.Numerics;
using System.Reflection;
using System.Threading;
Expand Down Expand Up @@ -206,6 +207,7 @@ private static bool IsTypeWindowsRuntimeTypeNoArray(Type type)
|| type.GetCustomAttribute<WindowsRuntimeTypeAttribute>() is object;
}

// Use TryGetCompatibleWindowsRuntimeTypesForVariantType instead.
public static bool TryGetCompatibleWindowsRuntimeTypeForVariantType(Type type, out Type compatibleType)
{
compatibleType = null;
Expand Down Expand Up @@ -247,6 +249,118 @@ public static bool TryGetCompatibleWindowsRuntimeTypeForVariantType(Type type, o
return true;
}

private static HashSet<Type> GetCompatibleTypes(Type type)
{
HashSet<Type> compatibleTypes = new HashSet<Type>();

foreach (var iface in type.GetInterfaces())
{
if (IsTypeWindowsRuntimeTypeNoArray(iface))
{
compatibleTypes.Add(iface);
}

if (iface.IsConstructedGenericType
&& TryGetCompatibleWindowsRuntimeTypesForVariantType(iface, out var compatibleIfaces))
{
compatibleTypes.UnionWith(compatibleIfaces);
}
}

Type baseType = type.BaseType;
while (baseType != null)
{
if (IsTypeWindowsRuntimeTypeNoArray(baseType))
{
compatibleTypes.Add(baseType);
}
baseType = baseType.BaseType;
}

return compatibleTypes;
}

internal static IEnumerable<Type> GetAllPossibleTypeCombinations(IEnumerable<IEnumerable<Type>> compatibleTypesPerGeneric, Type definition)
{
// Implementation adapted from https://stackoverflow.com/a/4424005
var accum = new List<Type>();
var compatibleTypesPerGenericArray = compatibleTypesPerGeneric.ToArray();
if (compatibleTypesPerGenericArray.Length > 0)
{
GetAllPossibleTypeCombinationsCore(
accum,
new Stack<Type>(),
compatibleTypesPerGenericArray,
compatibleTypesPerGenericArray.Length - 1);
}
return accum;

void GetAllPossibleTypeCombinationsCore(List<Type> accum, Stack<Type> stack, IEnumerable<Type>[] compatibleTypes, int index)
{
foreach (var type in compatibleTypes[index])
{
stack.Push(type);
if (index == 0)
{
// IEnumerable on a System.Collections.Generic.Stack
// enumerates in order of removal (last to first).
// As a result, we get the correct ordering here.
accum.Add(definition.MakeGenericType(stack.ToArray()));
}
else
{
GetAllPossibleTypeCombinationsCore(accum, stack, compatibleTypes, index - 1);
}
stack.Pop();
}
}
}

internal static bool TryGetCompatibleWindowsRuntimeTypesForVariantType(Type type, out IEnumerable<Type> compatibleTypes)
{
compatibleTypes = null;
if (!type.IsConstructedGenericType)
{
throw new ArgumentException(nameof(type));
}

var definition = type.GetGenericTypeDefinition();

if (!IsTypeWindowsRuntimeTypeNoArray(definition))
{
return false;
}

var genericConstraints = definition.GetGenericArguments();
var genericArguments = type.GetGenericArguments();
List<List<Type>> compatibleTypesPerGeneric = new List<List<Type>>();
for (int i = 0; i < genericArguments.Length; i++)
{
List<Type> compatibleTypesForGeneric = new List<Type>();
bool argumentCovariantObject = (genericConstraints[i].GenericParameterAttributes & GenericParameterAttributes.VarianceMask) == GenericParameterAttributes.Covariant
&& !genericArguments[i].IsValueType;

if (IsTypeWindowsRuntimeTypeNoArray(genericArguments[i]))
{
compatibleTypesForGeneric.Add(genericArguments[i]);
}
else if (!argumentCovariantObject)
{
return false;
}

if (argumentCovariantObject)
{
compatibleTypesForGeneric.AddRange(GetCompatibleTypes(genericArguments[i]));
}

compatibleTypesPerGeneric.Add(compatibleTypesForGeneric);
}

compatibleTypes = GetAllPossibleTypeCombinations(compatibleTypesPerGeneric, definition);
return true;
}

internal static bool TryGetDefaultInterfaceTypeForRuntimeClassType(Type runtimeClass, out Type defaultInterface)
{
runtimeClass = runtimeClass.GetRuntimeClassCCWType() ?? runtimeClass;
Expand Down

0 comments on commit d9f1d21

Please sign in to comment.