diff --git a/src/Veldrid/MTL/MTLPipeline.cs b/src/Veldrid/MTL/MTLPipeline.cs index eb6eadd39..0860122f6 100644 --- a/src/Veldrid/MTL/MTLPipeline.cs +++ b/src/Veldrid/MTL/MTLPipeline.cs @@ -1,7 +1,6 @@ using System; using System.Collections.Generic; using System.Diagnostics; -using System.Linq; using Veldrid.MetalBindings; namespace Veldrid.MTL @@ -29,9 +28,6 @@ internal class MtlPipeline : Pipeline public override bool IsDisposed => disposed; public override string Name { get; set; } - private static readonly Dictionary render_pipeline_states = new Dictionary(); - private static readonly Dictionary compute_pipeline_states = new Dictionary(); - private static readonly Dictionary depth_stencil_states = new Dictionary(); private bool disposed; private List specializedFunctions; @@ -55,163 +51,151 @@ public MtlPipeline(ref GraphicsPipelineDescription description, MtlGraphicsDevic FillMode = MtlFormats.VdToMtlFillMode(description.RasterizerState.FillMode); ScissorTestEnabled = description.RasterizerState.ScissorTestEnabled; - var stateLookup = new RenderPipelineStateLookup { Shaders = description.ShaderSet, BlendState = description.BlendState, Outputs = description.Outputs }; + var mtlDesc = MTLRenderPipelineDescriptor.New(); - if (!render_pipeline_states.TryGetValue(stateLookup, out var renderPipelineState)) + foreach (var shader in description.ShaderSet.Shaders) { - var mtlDesc = MTLRenderPipelineDescriptor.New(); + var mtlShader = Util.AssertSubtype(shader); + MTLFunction specializedFunction; - foreach (var shader in description.ShaderSet.Shaders) + if (mtlShader.HasFunctionConstants) { - var mtlShader = Util.AssertSubtype(shader); - MTLFunction specializedFunction; + // Need to create specialized MTLFunction. + var constantValues = createConstantValues(description.ShaderSet.Specializations); + specializedFunction = mtlShader.Library.newFunctionWithNameConstantValues(mtlShader.EntryPoint, constantValues); + addSpecializedFunction(specializedFunction); + ObjectiveCRuntime.release(constantValues.NativePtr); - if (mtlShader.HasFunctionConstants) - { - // Need to create specialized MTLFunction. - var constantValues = createConstantValues(description.ShaderSet.Specializations); - specializedFunction = mtlShader.Library.newFunctionWithNameConstantValues(mtlShader.EntryPoint, constantValues); - addSpecializedFunction(specializedFunction); - ObjectiveCRuntime.release(constantValues.NativePtr); + Debug.Assert(specializedFunction.NativePtr != IntPtr.Zero, "Failed to create specialized MTLFunction"); + } + else + specializedFunction = mtlShader.Function; - Debug.Assert(specializedFunction.NativePtr != IntPtr.Zero, "Failed to create specialized MTLFunction"); - } - else - specializedFunction = mtlShader.Function; + if (shader.Stage == ShaderStages.Vertex) + mtlDesc.vertexFunction = specializedFunction; + else if (shader.Stage == ShaderStages.Fragment) mtlDesc.fragmentFunction = specializedFunction; + } - if (shader.Stage == ShaderStages.Vertex) - mtlDesc.vertexFunction = specializedFunction; - else if (shader.Stage == ShaderStages.Fragment) mtlDesc.fragmentFunction = specializedFunction; - } + // Vertex layouts + var vdVertexLayouts = description.ShaderSet.VertexLayouts; + var vertexDescriptor = mtlDesc.vertexDescriptor; - // Vertex layouts - var vdVertexLayouts = description.ShaderSet.VertexLayouts; - var vertexDescriptor = mtlDesc.vertexDescriptor; + for (uint i = 0; i < vdVertexLayouts.Length; i++) + { + uint layoutIndex = ResourceBindingModel == ResourceBindingModel.Improved + ? NonVertexBufferCount + i + : i; + var mtlLayout = vertexDescriptor.layouts[layoutIndex]; + mtlLayout.stride = vdVertexLayouts[i].Stride; + uint stepRate = vdVertexLayouts[i].InstanceStepRate; + mtlLayout.stepFunction = stepRate == 0 ? MTLVertexStepFunction.PerVertex : MTLVertexStepFunction.PerInstance; + mtlLayout.stepRate = Math.Max(1, stepRate); + } - for (uint i = 0; i < vdVertexLayouts.Length; i++) - { - uint layoutIndex = ResourceBindingModel == ResourceBindingModel.Improved - ? NonVertexBufferCount + i - : i; - var mtlLayout = vertexDescriptor.layouts[layoutIndex]; - mtlLayout.stride = vdVertexLayouts[i].Stride; - uint stepRate = vdVertexLayouts[i].InstanceStepRate; - mtlLayout.stepFunction = stepRate == 0 ? MTLVertexStepFunction.PerVertex : MTLVertexStepFunction.PerInstance; - mtlLayout.stepRate = Math.Max(1, stepRate); - } + uint element = 0; - uint element = 0; + for (uint i = 0; i < vdVertexLayouts.Length; i++) + { + uint offset = 0; + var vdDesc = vdVertexLayouts[i]; - for (uint i = 0; i < vdVertexLayouts.Length; i++) + for (uint j = 0; j < vdDesc.Elements.Length; j++) { - uint offset = 0; - var vdDesc = vdVertexLayouts[i]; - - for (uint j = 0; j < vdDesc.Elements.Length; j++) - { - var elementDesc = vdDesc.Elements[j]; - var mtlAttribute = vertexDescriptor.attributes[element]; - mtlAttribute.bufferIndex = ResourceBindingModel == ResourceBindingModel.Improved - ? NonVertexBufferCount + i - : i; - mtlAttribute.format = MtlFormats.VdToMtlVertexFormat(elementDesc.Format); - mtlAttribute.offset = elementDesc.Offset != 0 ? elementDesc.Offset : (UIntPtr)offset; - offset += FormatSizeHelpers.GetSizeInBytes(elementDesc.Format); - element += 1; - } + var elementDesc = vdDesc.Elements[j]; + var mtlAttribute = vertexDescriptor.attributes[element]; + mtlAttribute.bufferIndex = ResourceBindingModel == ResourceBindingModel.Improved + ? NonVertexBufferCount + i + : i; + mtlAttribute.format = MtlFormats.VdToMtlVertexFormat(elementDesc.Format); + mtlAttribute.offset = elementDesc.Offset != 0 ? elementDesc.Offset : (UIntPtr)offset; + offset += FormatSizeHelpers.GetSizeInBytes(elementDesc.Format); + element += 1; } + } - VertexBufferCount = (uint)vdVertexLayouts.Length; - - // Outputs - var outputs = description.Outputs; - var blendStateDesc = description.BlendState; - BlendColor = blendStateDesc.BlendFactor; + VertexBufferCount = (uint)vdVertexLayouts.Length; - if (outputs.SampleCount != TextureSampleCount.Count1) mtlDesc.sampleCount = FormatHelpers.GetSampleCountUInt32(outputs.SampleCount); + // Outputs + var outputs = description.Outputs; + var blendStateDesc = description.BlendState; + BlendColor = blendStateDesc.BlendFactor; - if (outputs.DepthAttachment != null) - { - var depthFormat = outputs.DepthAttachment.Value.Format; - var mtlDepthFormat = MtlFormats.VdToMtlPixelFormat(depthFormat, true); - mtlDesc.depthAttachmentPixelFormat = mtlDepthFormat; + if (outputs.SampleCount != TextureSampleCount.Count1) mtlDesc.sampleCount = FormatHelpers.GetSampleCountUInt32(outputs.SampleCount); - if (FormatHelpers.IsStencilFormat(depthFormat)) - { - HasStencil = true; - mtlDesc.stencilAttachmentPixelFormat = mtlDepthFormat; - } - } + if (outputs.DepthAttachment != null) + { + var depthFormat = outputs.DepthAttachment.Value.Format; + var mtlDepthFormat = MtlFormats.VdToMtlPixelFormat(depthFormat, true); + mtlDesc.depthAttachmentPixelFormat = mtlDepthFormat; - for (uint i = 0; i < outputs.ColorAttachments.Length; i++) + if (FormatHelpers.IsStencilFormat(depthFormat)) { - var attachmentBlendDesc = blendStateDesc.AttachmentStates[i]; - var colorDesc = mtlDesc.colorAttachments[i]; - colorDesc.pixelFormat = MtlFormats.VdToMtlPixelFormat(outputs.ColorAttachments[i].Format, false); - colorDesc.blendingEnabled = attachmentBlendDesc.BlendEnabled; - colorDesc.writeMask = MtlFormats.VdToMtlColorWriteMask(attachmentBlendDesc.ColorWriteMask.GetOrDefault()); - colorDesc.alphaBlendOperation = MtlFormats.VdToMtlBlendOp(attachmentBlendDesc.AlphaFunction); - colorDesc.sourceAlphaBlendFactor = MtlFormats.VdToMtlBlendFactor(attachmentBlendDesc.SourceAlphaFactor); - colorDesc.destinationAlphaBlendFactor = MtlFormats.VdToMtlBlendFactor(attachmentBlendDesc.DestinationAlphaFactor); - - colorDesc.rgbBlendOperation = MtlFormats.VdToMtlBlendOp(attachmentBlendDesc.ColorFunction); - colorDesc.sourceRGBBlendFactor = MtlFormats.VdToMtlBlendFactor(attachmentBlendDesc.SourceColorFactor); - colorDesc.destinationRGBBlendFactor = MtlFormats.VdToMtlBlendFactor(attachmentBlendDesc.DestinationColorFactor); + HasStencil = true; + mtlDesc.stencilAttachmentPixelFormat = mtlDepthFormat; } + } - mtlDesc.alphaToCoverageEnabled = blendStateDesc.AlphaToCoverageEnabled; - - render_pipeline_states[stateLookup] = renderPipelineState = gd.Device.newRenderPipelineStateWithDescriptor(mtlDesc); - ObjectiveCRuntime.release(mtlDesc.NativePtr); + for (uint i = 0; i < outputs.ColorAttachments.Length; i++) + { + var attachmentBlendDesc = blendStateDesc.AttachmentStates[i]; + var colorDesc = mtlDesc.colorAttachments[i]; + colorDesc.pixelFormat = MtlFormats.VdToMtlPixelFormat(outputs.ColorAttachments[i].Format, false); + colorDesc.blendingEnabled = attachmentBlendDesc.BlendEnabled; + colorDesc.writeMask = MtlFormats.VdToMtlColorWriteMask(attachmentBlendDesc.ColorWriteMask.GetOrDefault()); + colorDesc.alphaBlendOperation = MtlFormats.VdToMtlBlendOp(attachmentBlendDesc.AlphaFunction); + colorDesc.sourceAlphaBlendFactor = MtlFormats.VdToMtlBlendFactor(attachmentBlendDesc.SourceAlphaFactor); + colorDesc.destinationAlphaBlendFactor = MtlFormats.VdToMtlBlendFactor(attachmentBlendDesc.DestinationAlphaFactor); + + colorDesc.rgbBlendOperation = MtlFormats.VdToMtlBlendOp(attachmentBlendDesc.ColorFunction); + colorDesc.sourceRGBBlendFactor = MtlFormats.VdToMtlBlendFactor(attachmentBlendDesc.SourceColorFactor); + colorDesc.destinationRGBBlendFactor = MtlFormats.VdToMtlBlendFactor(attachmentBlendDesc.DestinationColorFactor); } - RenderPipelineState = renderPipelineState; + mtlDesc.alphaToCoverageEnabled = blendStateDesc.AlphaToCoverageEnabled; + + RenderPipelineState = gd.Device.newRenderPipelineStateWithDescriptor(mtlDesc); + ObjectiveCRuntime.release(mtlDesc.NativePtr); if (description.Outputs.DepthAttachment != null) { - if (!depth_stencil_states.TryGetValue(description.DepthStencilState, out var depthStencilState)) - { - var depthDescriptor = MTLUtil.AllocInit( - nameof(MTLDepthStencilDescriptor)); - depthDescriptor.depthCompareFunction = MtlFormats.VdToMtlCompareFunction( - description.DepthStencilState.DepthComparison); - depthDescriptor.depthWriteEnabled = description.DepthStencilState.DepthWriteEnabled; + var depthDescriptor = MTLUtil.AllocInit( + nameof(MTLDepthStencilDescriptor)); + depthDescriptor.depthCompareFunction = MtlFormats.VdToMtlCompareFunction( + description.DepthStencilState.DepthComparison); + depthDescriptor.depthWriteEnabled = description.DepthStencilState.DepthWriteEnabled; - bool stencilEnabled = description.DepthStencilState.StencilTestEnabled; + bool stencilEnabled = description.DepthStencilState.StencilTestEnabled; - if (stencilEnabled) - { - StencilReference = description.DepthStencilState.StencilReference; - - var vdFrontDesc = description.DepthStencilState.StencilFront; - var front = MTLUtil.AllocInit(nameof(MTLStencilDescriptor)); - front.readMask = description.DepthStencilState.StencilReadMask; - front.writeMask = description.DepthStencilState.StencilWriteMask; - front.depthFailureOperation = MtlFormats.VdToMtlStencilOperation(vdFrontDesc.DepthFail); - front.stencilFailureOperation = MtlFormats.VdToMtlStencilOperation(vdFrontDesc.Fail); - front.depthStencilPassOperation = MtlFormats.VdToMtlStencilOperation(vdFrontDesc.Pass); - front.stencilCompareFunction = MtlFormats.VdToMtlCompareFunction(vdFrontDesc.Comparison); - depthDescriptor.frontFaceStencil = front; - - var vdBackDesc = description.DepthStencilState.StencilBack; - var back = MTLUtil.AllocInit(nameof(MTLStencilDescriptor)); - back.readMask = description.DepthStencilState.StencilReadMask; - back.writeMask = description.DepthStencilState.StencilWriteMask; - back.depthFailureOperation = MtlFormats.VdToMtlStencilOperation(vdBackDesc.DepthFail); - back.stencilFailureOperation = MtlFormats.VdToMtlStencilOperation(vdBackDesc.Fail); - back.depthStencilPassOperation = MtlFormats.VdToMtlStencilOperation(vdBackDesc.Pass); - back.stencilCompareFunction = MtlFormats.VdToMtlCompareFunction(vdBackDesc.Comparison); - depthDescriptor.backFaceStencil = back; - - ObjectiveCRuntime.release(front.NativePtr); - ObjectiveCRuntime.release(back.NativePtr); - } - - depth_stencil_states[description.DepthStencilState] = depthStencilState = gd.Device.newDepthStencilStateWithDescriptor(depthDescriptor); - ObjectiveCRuntime.release(depthDescriptor.NativePtr); + if (stencilEnabled) + { + StencilReference = description.DepthStencilState.StencilReference; + + var vdFrontDesc = description.DepthStencilState.StencilFront; + var front = MTLUtil.AllocInit(nameof(MTLStencilDescriptor)); + front.readMask = description.DepthStencilState.StencilReadMask; + front.writeMask = description.DepthStencilState.StencilWriteMask; + front.depthFailureOperation = MtlFormats.VdToMtlStencilOperation(vdFrontDesc.DepthFail); + front.stencilFailureOperation = MtlFormats.VdToMtlStencilOperation(vdFrontDesc.Fail); + front.depthStencilPassOperation = MtlFormats.VdToMtlStencilOperation(vdFrontDesc.Pass); + front.stencilCompareFunction = MtlFormats.VdToMtlCompareFunction(vdFrontDesc.Comparison); + depthDescriptor.frontFaceStencil = front; + + var vdBackDesc = description.DepthStencilState.StencilBack; + var back = MTLUtil.AllocInit(nameof(MTLStencilDescriptor)); + back.readMask = description.DepthStencilState.StencilReadMask; + back.writeMask = description.DepthStencilState.StencilWriteMask; + back.depthFailureOperation = MtlFormats.VdToMtlStencilOperation(vdBackDesc.DepthFail); + back.stencilFailureOperation = MtlFormats.VdToMtlStencilOperation(vdBackDesc.Fail); + back.depthStencilPassOperation = MtlFormats.VdToMtlStencilOperation(vdBackDesc.Pass); + back.stencilCompareFunction = MtlFormats.VdToMtlCompareFunction(vdBackDesc.Comparison); + depthDescriptor.backFaceStencil = back; + + ObjectiveCRuntime.release(front.NativePtr); + ObjectiveCRuntime.release(back.NativePtr); } - DepthStencilState = depthStencilState; + DepthStencilState = gd.Device.newDepthStencilStateWithDescriptor(depthDescriptor); + ObjectiveCRuntime.release(depthDescriptor.NativePtr); } DepthClipMode = description.DepthStencilState.DepthTestEnabled ? MTLDepthClipMode.Clip : MTLDepthClipMode.Clamp; @@ -230,60 +214,52 @@ public MtlPipeline(ref ComputePipelineDescription description, MtlGraphicsDevice description.ThreadGroupSizeY, description.ThreadGroupSizeZ); - var stateLookup = new ComputePipelineStateLookup - { ComputeShader = description.ComputeShader, ResourceLayouts = description.ResourceLayouts, Specializations = description.Specializations }; + var mtlDesc = MTLUtil.AllocInit( + nameof(MTLComputePipelineDescriptor)); + var mtlShader = Util.AssertSubtype(description.ComputeShader); + MTLFunction specializedFunction; - if (!compute_pipeline_states.TryGetValue(stateLookup, out var computePipelineState)) + if (mtlShader.HasFunctionConstants) { - var mtlDesc = MTLUtil.AllocInit( - nameof(MTLComputePipelineDescriptor)); - var mtlShader = Util.AssertSubtype(description.ComputeShader); - MTLFunction specializedFunction; - - if (mtlShader.HasFunctionConstants) - { - // Need to create specialized MTLFunction. - var constantValues = createConstantValues(description.Specializations); - specializedFunction = mtlShader.Library.newFunctionWithNameConstantValues(mtlShader.EntryPoint, constantValues); - addSpecializedFunction(specializedFunction); - ObjectiveCRuntime.release(constantValues.NativePtr); + // Need to create specialized MTLFunction. + var constantValues = createConstantValues(description.Specializations); + specializedFunction = mtlShader.Library.newFunctionWithNameConstantValues(mtlShader.EntryPoint, constantValues); + addSpecializedFunction(specializedFunction); + ObjectiveCRuntime.release(constantValues.NativePtr); - Debug.Assert(specializedFunction.NativePtr != IntPtr.Zero, "Failed to create specialized MTLFunction"); - } - else - specializedFunction = mtlShader.Function; + Debug.Assert(specializedFunction.NativePtr != IntPtr.Zero, "Failed to create specialized MTLFunction"); + } + else + specializedFunction = mtlShader.Function; - mtlDesc.computeFunction = specializedFunction; - var buffers = mtlDesc.buffers; - uint bufferIndex = 0; + mtlDesc.computeFunction = specializedFunction; + var buffers = mtlDesc.buffers; + uint bufferIndex = 0; - foreach (var layout in ResourceLayouts) + foreach (var layout in ResourceLayouts) + { + foreach (var rle in layout.Description.Elements) { - foreach (var rle in layout.Description.Elements) + var kind = rle.Kind; + + if (kind == ResourceKind.UniformBuffer + || kind == ResourceKind.StructuredBufferReadOnly) { - var kind = rle.Kind; - - if (kind == ResourceKind.UniformBuffer - || kind == ResourceKind.StructuredBufferReadOnly) - { - var bufferDesc = buffers[bufferIndex]; - bufferDesc.mutability = MTLMutability.Immutable; - bufferIndex += 1; - } - else if (kind == ResourceKind.StructuredBufferReadWrite) - { - var bufferDesc = buffers[bufferIndex]; - bufferDesc.mutability = MTLMutability.Mutable; - bufferIndex += 1; - } + var bufferDesc = buffers[bufferIndex]; + bufferDesc.mutability = MTLMutability.Immutable; + bufferIndex += 1; + } + else if (kind == ResourceKind.StructuredBufferReadWrite) + { + var bufferDesc = buffers[bufferIndex]; + bufferDesc.mutability = MTLMutability.Mutable; + bufferIndex += 1; } } - - compute_pipeline_states[stateLookup] = computePipelineState = gd.Device.newComputePipelineStateWithDescriptor(mtlDesc); - ObjectiveCRuntime.release(mtlDesc.NativePtr); } - ComputePipelineState = computePipelineState; + ComputePipelineState = gd.Device.newComputePipelineStateWithDescriptor(mtlDesc); + ObjectiveCRuntime.release(mtlDesc.NativePtr); } #region Disposal @@ -293,22 +269,13 @@ public override void Dispose() if (!disposed) { if (RenderPipelineState.NativePtr != IntPtr.Zero) - { - render_pipeline_states.Remove(render_pipeline_states.Single(kvp => kvp.Value.NativePtr == RenderPipelineState.NativePtr).Key); ObjectiveCRuntime.release(RenderPipelineState.NativePtr); - } if (DepthStencilState.NativePtr != IntPtr.Zero) - { - depth_stencil_states.Remove(depth_stencil_states.Single(kvp => kvp.Value.NativePtr == DepthStencilState.NativePtr).Key); ObjectiveCRuntime.release(DepthStencilState.NativePtr); - } if (ComputePipelineState.NativePtr != IntPtr.Zero) - { - compute_pipeline_states.Remove(compute_pipeline_states.Single(kvp => kvp.Value.NativePtr == ComputePipelineState.NativePtr).Key); ObjectiveCRuntime.release(ComputePipelineState.NativePtr); - } if (specializedFunctions != null) { @@ -344,53 +311,5 @@ private void addSpecializedFunction(MTLFunction function) specializedFunctions ??= new List(); specializedFunctions.Add(function); } - - private struct RenderPipelineStateLookup : IEquatable - { - public ShaderSetDescription Shaders; - public OutputDescription Outputs; - public BlendStateDescription BlendState; - - public bool Equals(RenderPipelineStateLookup other) - { - return Shaders.Equals(other.Shaders) && - Outputs.Equals(other.Outputs) && - BlendState.Equals(other.BlendState); - } - - public override bool Equals(object obj) - { - return obj is RenderPipelineStateLookup other && Equals(other); - } - - public override int GetHashCode() - { - return HashCode.Combine(Shaders, Outputs, BlendState); - } - } - - private struct ComputePipelineStateLookup : IEquatable - { - public Shader ComputeShader; - public ResourceLayout[] ResourceLayouts; - public SpecializationConstant[] Specializations; - - public bool Equals(ComputePipelineStateLookup other) - { - return ComputeShader == other.ComputeShader && - Util.ArrayEquals(ResourceLayouts, other.ResourceLayouts) && - Util.ArrayEqualsEquatable(Specializations, other.Specializations); - } - - public override bool Equals(object obj) - { - return obj is ComputePipelineStateLookup other && Equals(other); - } - - public override int GetHashCode() - { - return HashCode.Combine(ComputeShader, HashHelper.Array(ResourceLayouts), HashHelper.Array(Specializations)); - } - } } }