diff --git a/src/NSubstitute/Core/IProxyFactory.cs b/src/NSubstitute/Core/IProxyFactory.cs index 541aad90..31cd3ed8 100644 --- a/src/NSubstitute/Core/IProxyFactory.cs +++ b/src/NSubstitute/Core/IProxyFactory.cs @@ -2,5 +2,5 @@ namespace NSubstitute.Core; public interface IProxyFactory { - object GenerateProxy(ICallRouter callRouter, Type typeToProxy, Type[]? additionalInterfaces, object?[]? constructorArguments); + object GenerateProxy(ICallRouter callRouter, Type typeToProxy, Type[]? additionalInterfaces, bool isPartial, object?[]? constructorArguments); } \ No newline at end of file diff --git a/src/NSubstitute/Core/SubstituteFactory.cs b/src/NSubstitute/Core/SubstituteFactory.cs index 76a9f652..e55c2ffd 100644 --- a/src/NSubstitute/Core/SubstituteFactory.cs +++ b/src/NSubstitute/Core/SubstituteFactory.cs @@ -14,7 +14,7 @@ public class SubstituteFactory(ISubstituteStateFactory substituteStateFactory, I /// public object Create(Type[] typesToProxy, object?[] constructorArguments) { - return Create(typesToProxy, constructorArguments, callBaseByDefault: false); + return Create(typesToProxy, constructorArguments, callBaseByDefault: false, isPartial: false); } /// @@ -33,10 +33,10 @@ public object CreatePartial(Type[] typesToProxy, object?[] constructorArguments) throw new CanNotPartiallySubForInterfaceOrDelegateException(primaryProxyType); } - return Create(typesToProxy, constructorArguments, callBaseByDefault: true); + return Create(typesToProxy, constructorArguments, callBaseByDefault: true, isPartial: true); } - private object Create(Type[] typesToProxy, object?[] constructorArguments, bool callBaseByDefault) + private object Create(Type[] typesToProxy, object?[] constructorArguments, bool callBaseByDefault, bool isPartial) { var substituteState = substituteStateFactory.Create(this); substituteState.CallBaseConfiguration.CallBaseByDefault = callBaseByDefault; @@ -46,7 +46,7 @@ private object Create(Type[] typesToProxy, object?[] constructorArguments, bool var callRouter = callRouterFactory.Create(substituteState, canConfigureBaseCalls); var additionalTypes = typesToProxy.Where(x => x != primaryProxyType).ToArray(); - var proxy = proxyFactory.GenerateProxy(callRouter, primaryProxyType, additionalTypes, constructorArguments); + var proxy = proxyFactory.GenerateProxy(callRouter, primaryProxyType, additionalTypes, isPartial, constructorArguments); return proxy; } diff --git a/src/NSubstitute/Exceptions/TypeForwardingException.cs b/src/NSubstitute/Exceptions/TypeForwardingException.cs new file mode 100644 index 00000000..5ddf2703 --- /dev/null +++ b/src/NSubstitute/Exceptions/TypeForwardingException.cs @@ -0,0 +1,21 @@ +namespace NSubstitute.Exceptions; + +public abstract class TypeForwardingException(string message) : SubstituteException(message) +{ +} + +public sealed class CanNotForwardCallsToClassNotImplementingInterfaceException(Type type) : TypeForwardingException(DescribeProblem(type)) +{ + private static string DescribeProblem(Type type) + { + return string.Format("The provided class '{0}' doesn't implement all requested interfaces. ", type.Name); + } +} + +public sealed class CanNotForwardCallsToAbstractClassException(Type type) : TypeForwardingException(DescribeProblem(type)) +{ + private static string DescribeProblem(Type type) + { + return string.Format("The provided class '{0}' is abstract. ", type.Name); + } +} diff --git a/src/NSubstitute/Proxies/CastleDynamicProxy/CastleDynamicProxyFactory.cs b/src/NSubstitute/Proxies/CastleDynamicProxy/CastleDynamicProxyFactory.cs index e1c0d1ef..a445f42f 100644 --- a/src/NSubstitute/Proxies/CastleDynamicProxy/CastleDynamicProxyFactory.cs +++ b/src/NSubstitute/Proxies/CastleDynamicProxy/CastleDynamicProxyFactory.cs @@ -10,14 +10,14 @@ public class CastleDynamicProxyFactory(ICallFactory callFactory, IArgumentSpecif private readonly ProxyGenerator _proxyGenerator = new ProxyGenerator(); private readonly AllMethodsExceptCallRouterCallsHook _allMethodsExceptCallRouterCallsHook = new AllMethodsExceptCallRouterCallsHook(); - public object GenerateProxy(ICallRouter callRouter, Type typeToProxy, Type[]? additionalInterfaces, object?[]? constructorArguments) + public object GenerateProxy(ICallRouter callRouter, Type typeToProxy, Type[]? additionalInterfaces, bool isPartial, object?[]? constructorArguments) { return typeToProxy.IsDelegate() ? GenerateDelegateProxy(callRouter, typeToProxy, additionalInterfaces, constructorArguments) - : GenerateTypeProxy(callRouter, typeToProxy, additionalInterfaces, constructorArguments); + : GenerateTypeProxy(callRouter, typeToProxy, additionalInterfaces, isPartial, constructorArguments); } - private object GenerateTypeProxy(ICallRouter callRouter, Type typeToProxy, Type[]? additionalInterfaces, object?[]? constructorArguments) + private object GenerateTypeProxy(ICallRouter callRouter, Type typeToProxy, Type[]? additionalInterfaces, bool isPartial, object?[]? constructorArguments) { VerifyClassHasNotBeenPassedAsAnAdditionalInterface(additionalInterfaces); @@ -31,7 +31,8 @@ private object GenerateTypeProxy(ICallRouter callRouter, Type typeToProxy, Type[ additionalInterfaces, constructorArguments, [proxyIdInterceptor, forwardingInterceptor], - proxyGenerationOptions); + proxyGenerationOptions, + isPartial); forwardingInterceptor.SwitchToFullDispatchMode(); return proxy; @@ -54,7 +55,8 @@ private object GenerateDelegateProxy(ICallRouter callRouter, Type delegateType, additionalInterfaces: null, constructorArguments: null, interceptors: [proxyIdInterceptor, forwardingInterceptor], - proxyGenerationOptions); + proxyGenerationOptions, + isPartial: false); forwardingInterceptor.SwitchToFullDispatchMode(); @@ -75,8 +77,13 @@ private CastleForwardingInterceptor CreateForwardingInterceptor(ICallRouter call private object CreateProxyUsingCastleProxyGenerator(Type typeToProxy, Type[]? additionalInterfaces, object?[]? constructorArguments, IInterceptor[] interceptors, - ProxyGenerationOptions proxyGenerationOptions) + ProxyGenerationOptions proxyGenerationOptions, + bool isPartial) { + if (isPartial) + return CreatePartialProxy(typeToProxy, additionalInterfaces, constructorArguments, interceptors, proxyGenerationOptions, isPartial); + + if (typeToProxy.GetTypeInfo().IsInterface) { VerifyNoConstructorArgumentsGivenForInterface(constructorArguments); @@ -96,6 +103,7 @@ private object CreateProxyUsingCastleProxyGenerator(Type typeToProxy, Type[]? ad additionalInterfaces = interfaces; } + return _proxyGenerator.CreateClassProxy(typeToProxy, additionalInterfaces, proxyGenerationOptions, @@ -103,6 +111,32 @@ private object CreateProxyUsingCastleProxyGenerator(Type typeToProxy, Type[]? ad interceptors); } + private object CreatePartialProxy(Type typeToProxy, Type[]? additionalInterfaces, object?[]? constructorArguments, IInterceptor[] interceptors, ProxyGenerationOptions proxyGenerationOptions, bool isPartial) + { + if (typeToProxy.GetTypeInfo().IsClass && + additionalInterfaces != null && + additionalInterfaces.Any()) + { + VerifyClassIsNotAbstract(typeToProxy); + VerifyClassImplementsAllInterfaces(typeToProxy, additionalInterfaces); + + var targetObject = Activator.CreateInstance(typeToProxy, constructorArguments); + typeToProxy = additionalInterfaces.First(); + + return _proxyGenerator.CreateInterfaceProxyWithTarget(typeToProxy, + additionalInterfaces, + target: targetObject, + options: proxyGenerationOptions, + interceptors: interceptors); + } + + return _proxyGenerator.CreateClassProxy(typeToProxy, + additionalInterfaces, + proxyGenerationOptions, + constructorArguments, + interceptors); + } + private ProxyGenerationOptions GetOptionsToMixinCallRouterProvider(ICallRouter callRouter) { var options = new ProxyGenerationOptions(_allMethodsExceptCallRouterCallsHook); @@ -116,6 +150,22 @@ private ProxyGenerationOptions GetOptionsToMixinCallRouterProvider(ICallRouter c return options; } + private static void VerifyClassImplementsAllInterfaces(Type classType, IEnumerable additionalInterfaces) + { + if (!additionalInterfaces.All(x => x.GetTypeInfo().IsAssignableFrom(classType.GetTypeInfo()))) + { + throw new CanNotForwardCallsToClassNotImplementingInterfaceException(classType); + } + } + + private static void VerifyClassIsNotAbstract(Type classType) + { + if (classType.GetTypeInfo().IsAbstract) + { + throw new CanNotForwardCallsToAbstractClassException(classType); + } + } + private static void VerifyNoConstructorArgumentsGivenForInterface(object?[]? constructorArguments) { if (HasItems(constructorArguments)) diff --git a/src/NSubstitute/Proxies/CastleDynamicProxy/CastleInvocationMapper.cs b/src/NSubstitute/Proxies/CastleDynamicProxy/CastleInvocationMapper.cs index 7ebac3e9..ddc15405 100644 --- a/src/NSubstitute/Proxies/CastleDynamicProxy/CastleInvocationMapper.cs +++ b/src/NSubstitute/Proxies/CastleDynamicProxy/CastleInvocationMapper.cs @@ -10,8 +10,7 @@ public virtual ICall Map(IInvocation castleInvocation) Func? baseMethod = null; if (castleInvocation.InvocationTarget != null && castleInvocation.MethodInvocationTarget.IsVirtual && - !castleInvocation.MethodInvocationTarget.IsAbstract && - !castleInvocation.MethodInvocationTarget.IsFinal) + !castleInvocation.MethodInvocationTarget.IsAbstract) { baseMethod = CreateBaseResultInvocation(castleInvocation); } diff --git a/src/NSubstitute/Proxies/DelegateProxy/DelegateProxyFactory.cs b/src/NSubstitute/Proxies/DelegateProxy/DelegateProxyFactory.cs index 9e38a1c1..66ee9e4e 100644 --- a/src/NSubstitute/Proxies/DelegateProxy/DelegateProxyFactory.cs +++ b/src/NSubstitute/Proxies/DelegateProxy/DelegateProxyFactory.cs @@ -8,9 +8,9 @@ public class DelegateProxyFactory(CastleDynamicProxyFactory objectProxyFactory) { private readonly CastleDynamicProxyFactory _castleObjectProxyFactory = objectProxyFactory ?? throw new ArgumentNullException(nameof(objectProxyFactory)); - public object GenerateProxy(ICallRouter callRouter, Type typeToProxy, Type[]? additionalInterfaces, object?[]? constructorArguments) + public object GenerateProxy(ICallRouter callRouter, Type typeToProxy, Type[]? additionalInterfaces, bool isPartial, object?[]? constructorArguments) { // Castle factory can now resolve delegate proxies as well. - return _castleObjectProxyFactory.GenerateProxy(callRouter, typeToProxy, additionalInterfaces, constructorArguments); + return _castleObjectProxyFactory.GenerateProxy(callRouter, typeToProxy, additionalInterfaces, isPartial, constructorArguments); } } \ No newline at end of file diff --git a/src/NSubstitute/Proxies/ProxyFactory.cs b/src/NSubstitute/Proxies/ProxyFactory.cs index f107ae85..c93ee284 100644 --- a/src/NSubstitute/Proxies/ProxyFactory.cs +++ b/src/NSubstitute/Proxies/ProxyFactory.cs @@ -5,11 +5,11 @@ namespace NSubstitute.Proxies; [Obsolete("This class is deprecated and will be removed in future versions of the product.")] public class ProxyFactory(IProxyFactory delegateFactory, IProxyFactory dynamicProxyFactory) : IProxyFactory { - public object GenerateProxy(ICallRouter callRouter, Type typeToProxy, Type[]? additionalInterfaces, object?[]? constructorArguments) + public object GenerateProxy(ICallRouter callRouter, Type typeToProxy, Type[]? additionalInterfaces, bool isPartial, object?[]? constructorArguments) { var isDelegate = typeToProxy.IsDelegate(); return isDelegate - ? delegateFactory.GenerateProxy(callRouter, typeToProxy, additionalInterfaces, constructorArguments) - : dynamicProxyFactory.GenerateProxy(callRouter, typeToProxy, additionalInterfaces, constructorArguments); + ? delegateFactory.GenerateProxy(callRouter, typeToProxy, additionalInterfaces, isPartial, constructorArguments) + : dynamicProxyFactory.GenerateProxy(callRouter, typeToProxy, additionalInterfaces, isPartial, constructorArguments); } } \ No newline at end of file diff --git a/src/NSubstitute/Substitute.cs b/src/NSubstitute/Substitute.cs index f1d471e6..c1710106 100644 --- a/src/NSubstitute/Substitute.cs +++ b/src/NSubstitute/Substitute.cs @@ -89,4 +89,24 @@ public static T ForPartsOf(params object[] constructorArguments) var substituteFactory = SubstitutionContext.Current.SubstituteFactory; return (T)substituteFactory.CreatePartial([typeof(T)], constructorArguments); } + + /// + /// Creates a proxy for a class that implements an interface, forwarding methods and properties to an instance of the class, effectively mimicking a real instance. + /// Both the interface and the class must be provided as parameters. + /// The proxy will log calls made to the interface members and delegate them to an instance of the class. Specific members can be substituted + /// by using When(() => call).DoNotCallBase() or by + /// setting a value to return value for that member. + /// This extension supports sealed classes and non-virtual members, with some limitations. Since the substituted method is non-virtual, internal calls within the object will invoke the original implementation and will not be logged. + /// + /// The interface the substitute will implement. + /// The class type implementing the interface. Must be a class; not a delegate or interface. + /// + /// An object implementing the selected interface. Calls will be forwarded to the actuall methods, but allows parts to be selectively + /// overridden via `Returns` and `When..DoNotCallBase`. + public static TInterface ForTypeForwardingTo(params object[] constructorArguments) + where TInterface : class + { + var substituteFactory = SubstitutionContext.Current.SubstituteFactory; + return (TInterface)substituteFactory.CreatePartial([typeof(TInterface), typeof(TClass)], constructorArguments); + } } \ No newline at end of file diff --git a/tests/NSubstitute.Acceptance.Specs/SubbingForConcreteTypesAndMultipleInterfaces.cs b/tests/NSubstitute.Acceptance.Specs/SubbingForConcreteTypesAndMultipleInterfaces.cs index 26d4d778..d9985987 100644 --- a/tests/NSubstitute.Acceptance.Specs/SubbingForConcreteTypesAndMultipleInterfaces.cs +++ b/tests/NSubstitute.Acceptance.Specs/SubbingForConcreteTypesAndMultipleInterfaces.cs @@ -1,4 +1,5 @@ using NUnit.Framework; +using NUnit.Framework.Legacy; namespace NSubstitute.Acceptance.Specs; @@ -31,6 +32,30 @@ public void Can_sub_for_concrete_type_and_implement_other_interfaces() subAsIFirst.Received().First(); } + [Test] + public void Can_sub_for_abstract_type_and_implement_other_two_interfaces() + { + // test from docs + var substitute = Substitute.For([typeof(IFirst), typeof(ISecond), typeof(ClassWithCtorArgs)], + ["hello world", 5]); + + ClassicAssert.IsInstanceOf(substitute); + ClassicAssert.IsInstanceOf(substitute); + ClassicAssert.IsInstanceOf(substitute); + } + + [Test] + public void Can_sub_for_concrete_type_and_implement_other_two_interfaces() + { + // test from docs + var substitute = Substitute.For([typeof(IFirst), typeof(ISecond), typeof(ConcreteClassWithCtorArgs)], + ["hello world", 5]); + + ClassicAssert.IsInstanceOf(substitute); + ClassicAssert.IsInstanceOf(substitute); + ClassicAssert.IsInstanceOf(substitute); + } + [Test] public void Partial_sub() { @@ -90,8 +115,13 @@ public class Partial public virtual int Number() { return -1; } public int GetNumberPlusOne() { return Number() + 1; } } + public abstract class ClassWithCtorArgs(string s, int a) { public string StringFromCtorArg { get; set; } = s; public int IntFromCtorArg { get; set; } = a; } + + public class ConcreteClassWithCtorArgs(string s, int a) : ClassWithCtorArgs(s, a) + { + } } \ No newline at end of file diff --git a/tests/NSubstitute.Acceptance.Specs/TypeForwarding.cs b/tests/NSubstitute.Acceptance.Specs/TypeForwarding.cs new file mode 100644 index 00000000..af4a8fa9 --- /dev/null +++ b/tests/NSubstitute.Acceptance.Specs/TypeForwarding.cs @@ -0,0 +1,104 @@ +using NSubstitute.Exceptions; +using NSubstitute.Extensions; +using NUnit.Framework; + +namespace NSubstitute.Acceptance.Specs; + +public class TypeForwarding +{ + [Test] + public void UseImplementedNonVirtualMethod() + { + var testAbstractClass = Substitute.ForTypeForwardingTo(); + Assert.That(testAbstractClass.MethodReturnsSameInt(1), Is.EqualTo(1)); + Assert.That(testAbstractClass.CalledTimes, Is.EqualTo(1)); + testAbstractClass.Received().MethodReturnsSameInt(1); + Assert.That(testAbstractClass.CalledTimes, Is.EqualTo(1)); + } + + [Test] + public void UseSubstitutedNonVirtualMethod() + { + var testInterface = Substitute.ForTypeForwardingTo(); + testInterface.Configure().MethodReturnsSameInt(1).Returns(2); + Assert.That(testInterface.MethodReturnsSameInt(1), Is.EqualTo(2)); + Assert.That(testInterface.MethodReturnsSameInt(3), Is.EqualTo(3)); + testInterface.ReceivedWithAnyArgs(2).MethodReturnsSameInt(default); + Assert.That(testInterface.CalledTimes, Is.EqualTo(1)); + } + + [Test] + public void UseSubstitutedNonVirtualMethodHonorsDoNotCallBase() + { + var testInterface = Substitute.ForTypeForwardingTo(); + testInterface.Configure().MethodReturnsSameInt(1).Returns(2); + testInterface.WhenForAnyArgs(x => x.MethodReturnsSameInt(default)).DoNotCallBase(); + Assert.That(testInterface.MethodReturnsSameInt(1), Is.EqualTo(2)); + Assert.That(testInterface.MethodReturnsSameInt(3), Is.EqualTo(0)); + testInterface.ReceivedWithAnyArgs(2).MethodReturnsSameInt(default); + Assert.That(testInterface.CalledTimes, Is.EqualTo(0)); + } + + [Test] + public void PartialSubstituteCallsConstructorWithParameters() + { + var testInterface = Substitute.ForTypeForwardingTo(50); + Assert.That(testInterface.MethodReturnsSameInt(1), Is.EqualTo(1)); + Assert.That(testInterface.CalledTimes, Is.EqualTo(51)); + } + + [Test] + public void PartialSubstituteFailsIfClassDoesntImplementInterface() + { + Assert.Throws( + () => Substitute.ForTypeForwardingTo()); + } + + [Test] + public void PartialSubstituteFailsIfClassIsAbstract() + { + Assert.Throws( + () => Substitute.ForTypeForwardingTo(), "The provided class is abstract."); + } + + public interface ITestInterface + { + public int CalledTimes { get; set; } + + void VoidTestMethod(); + int TestMethodReturnsInt(); + int MethodReturnsSameInt(int i); + } + + public sealed class TestSealedNonVirtualClass : ITestInterface + { + public TestSealedNonVirtualClass(int initialCounter) => CalledTimes = initialCounter; + public TestSealedNonVirtualClass() { } + + public int CalledTimes { get; set; } + + public int TestMethodReturnsInt() => throw new NotImplementedException(); + + public void VoidTestMethod() => throw new NotImplementedException(); + public int MethodReturnsSameInt(int i) + { + CalledTimes++; + return i; + } + } + + public abstract class TestAbstractClassWithInterface : ITestInterface + { + public int CalledTimes { get; set; } + + public abstract int MethodReturnsSameInt(int i); + + public abstract int TestMethodReturnsInt(); + + public abstract void VoidTestMethod(); + } + + public class TestRandomConcreteClass { } + + public abstract class TestAbstractClass { } +} \ No newline at end of file