diff --git a/CMakeLists.txt b/CMakeLists.txt index e1225022e3..6580e26806 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -101,6 +101,8 @@ option(SLANG_ENABLE_TESTS "Enable test targets, some tests may require SLANG_ENA option(SLANG_ENABLE_EXAMPLES "Enable example targets, requires SLANG_ENABLE_GFX" ON) option(SLANG_ENABLE_REPLAYER "Enable slang-replay tool" ON) +option(SLANG_GITHUB_TOKEN "Use a given token value for accessing Github REST API" "") + enum_option( SLANG_LIB_TYPE # Default @@ -136,7 +138,7 @@ enum_option( if(SLANG_SLANG_LLVM_FLAVOR MATCHES FETCH_BINARY) # If the user didn't specify a URL, find the best one now if(NOT SLANG_SLANG_LLVM_BINARY_URL) - get_best_slang_binary_release_url(url) + get_best_slang_binary_release_url("${SLANG_GITHUB_TOKEN}" url) if(NOT DEFINED url) message(FATAL_ERROR "Unable to find binary release for slang-llvm, please set a different SLANG_SLANG_LLVM_FLAVOR or set SLANG_SLANG_LLVM_BINARY_URL manually") endif() diff --git a/cmake/GitHubRelease.cmake b/cmake/GitHubRelease.cmake index e63fb7885b..dd9dd8fe18 100644 --- a/cmake/GitHubRelease.cmake +++ b/cmake/GitHubRelease.cmake @@ -1,4 +1,4 @@ -function(check_release_and_get_latest owner repo version os arch out_var) +function(check_release_and_get_latest owner repo version os arch github_token out_var) # Construct the URL for the specified version's release API endpoint set(version_url "https://api.github.com/repos/${owner}/${repo}/releases/tags/v${version}") @@ -17,8 +17,22 @@ function(check_release_and_get_latest owner repo version os arch out_var) set(${found_var} "${found}" PARENT_SCOPE) endfunction() - # Download the specified release info from GitHub - file(DOWNLOAD "${version_url}" "${json_output_file}" STATUS download_statuses) + # Prepare download arguments + set(download_args + "${version_url}" + "${json_output_file}" + STATUS download_statuses + ) + + if(github_token) + # Add authorization header if token is provided + list(APPEND download_args HTTPHEADER "Authorization: token ${github_token}") + endif() + + # Perform the download + file(DOWNLOAD ${download_args}) + + # Check if the downloading was successful list(GET download_statuses 0 status_code) if(status_code EQUAL 0) file(READ "${json_output_file}" json_content) @@ -34,6 +48,10 @@ function(check_release_and_get_latest owner repo version os arch out_var) message(WARNING "Failed to find ${desired_zip} in release assets for ${version} from ${version_url}\nFalling back to latest version if it differs") else() message(WARNING "Failed to download release info for version ${version} from ${version_url}\nFalling back to latest version if it differs") + + if(status_code EQUAL 22) + message(WARNING "If API rate limit is exceeded, Github allows a higher limit when you use token. Try a cmake option -DSLANG_GITHUB_TOKEN=your_token_here") + endif() endif() @@ -69,7 +87,7 @@ function(check_release_and_get_latest owner repo version os arch out_var) endif() endfunction() -function(get_best_slang_binary_release_url out_var) +function(get_best_slang_binary_release_url github_token out_var) if(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64|amd64|AMD64") set(arch "x86_64") elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|ARM64|arm64") @@ -93,7 +111,7 @@ function(get_best_slang_binary_release_url out_var) set(owner "shader-slang") set(repo "slang") - check_release_and_get_latest(${owner} ${repo} ${SLANG_VERSION_NUMERIC} ${os} ${arch} release_version) + check_release_and_get_latest(${owner} ${repo} ${SLANG_VERSION_NUMERIC} ${os} ${arch} "${github_token}" release_version) if(DEFINED release_version) set(${out_var} "https://github.com/${owner}/${repo}/releases/download/v${release_version}/slang-${release_version}-${os}-${arch}.zip" PARENT_SCOPE) endif() diff --git a/docs/user-guide/06-interfaces-generics.md b/docs/user-guide/06-interfaces-generics.md index df51ac9509..3d35d2bf52 100644 --- a/docs/user-guide/06-interfaces-generics.md +++ b/docs/user-guide/06-interfaces-generics.md @@ -1036,7 +1036,8 @@ Slang supports the following builtin interfaces: - `IInteger`, represents a logical integer that supports both `IArithmetic` and `ILogical` operations. Implemented by all builtin integer scalar types. - `IDifferentiable`, represents a value that is differentiable. - `IFloat`, represents a logical float that supports both `IArithmetic`, `ILogical` and `IDifferentiable` operations. Also provides methods to convert to and from `float`. Implemented by all builtin floating-point scalar, vector and matrix types. -- `IArray`, represents a logical array that supports retrieving an element of type `T` from an index. Implemented by array types, vectors and matrices. +- `IArray`, represents a logical array that supports retrieving an element of type `T` from an index. Implemented by array types, vectors, matrices and `StructuredBuffer`. +- `IRWArray`, represents a logical array whose elements are mutable. Implemented by array types, vectors, matrices, `RWStructuredBuffer` and `RasterizerOrderedStructuredBuffer`. - `IFunc` represent a callable object (with `operator()`) that returns `TResult` and takes `TParams...` as argument. - `IMutatingFunc`, similar to `IFunc`, but the `operator()` method is `[mutating]`. - `IDifferentiableFunc`, similar to `IFunc`, but the `operator()` method is `[Differentiable]`. diff --git a/docs/user-guide/07-autodiff.md b/docs/user-guide/07-autodiff.md index b3a25358c8..9209a903c1 100644 --- a/docs/user-guide/07-autodiff.md +++ b/docs/user-guide/07-autodiff.md @@ -136,8 +136,13 @@ Where the backward propagation function $$\mathbb{B}[f_i]$$ takes as input the p The higher order operator $$\mathbb{F}$$ and $$\mathbb{B}$$ represent the operations that converts an original or primal function $$f$$ to its forward or backward derivative propagation function. Slang's automatic differentiation feature provide built-in support for these operators to automatically generate the derivative propagation functions from a user defined primal function. The remaining documentation will discuss this feature from a programming language perspective. -## Differentiable Types -Slang will only generate differentiation code for values that has a *differentiable* type. A type is differentiable if it conforms to the built-in `IDifferentiable` interface. The definition of the `IDifferentiable` interface is: +## Differentiable Value Types +Slang will only generate differentiation code for values that has a *differentiable* type. +Differentiable types are defining through conformance to one of two built-in interfaces: +1. `IDifferentiable`: For value types (e.g. `float`, structs of value types, etc..) +2. `IDifferentiablePtrType`: For buffer, pointer & reference types that represent locations rather than values. + +The `IDifferentiable` interface requires the following definitions (which can be auto-generated by the compiler for most scenarios) ```csharp interface IDifferentiable { @@ -147,23 +152,38 @@ interface IDifferentiable static Differential dzero(); static Differential dadd(Differential, Differential); - - static Differential dmul(This, Differential); } ``` As defined by the `IDifferentiable` interface, a differentiable type must have a `Differential` associated type that stores the derivative of the value. A further requirement is that the type of the second-order derivative must be the same `Differential` type. In another word, given a type `T`, `T.Differential` can be different from `T`, but `T.Differential.Differential` must equal to `T.Differential`. -In addition, a differentiable type must define the `zero` value of its derivative, and how to add and multiply derivative values. +In addition, a differentiable type must define the `zero` value of its derivative, and how to add two derivative values together. These function are used during reverse-mode auto-diff, to initialize and accumulate derivatives of the given type. -### Builtin Differentiable Types +By contrast, `IDifferentiablePtrType` only requires a `Differential` associated type which also conforms to `IDifferentiablePtrType`. +```csharp +interface IDifferentiablePtrType +{ + associatedtype Differential : IDifferentiablePtrType; + where Differential.Differential == Differential; +} +``` + +Types should not conform to both `IDifferentiablePtrType` and `IDifferentiable`. Such cases will result in a compiler error. + + +### Builtin Differentiable Value Types The following built-in types are differentiable: - Scalars: `float`, `double` and `half`. - Vector/Matrix: `vector` and `matrix` of `float`, `double` and `half` types. - Arrays: `T[n]` is differentiable if `T` is differentiable. +- Tuples: `Tuple` is differentiable if `T` is differentiable. + +### Builtin Differentiable Ptr Types +There are currently no built-in types that conform to `IDifferentiablePtrType` ### User Defined Differentiable Types -The user can make any `struct` types differentiable by implementing the `IDifferentiable` interface on the type. The requirements from `IDifferentiable` interface can be fulfilled automatically or manually. +The user can make any `struct` types differentiable by implementing either `IDifferentiable` & `IDifferentiablePtrType` interface on the type. +The requirements from `IDifferentiable` interface can be fulfilled automatically or manually, though `IDifferentiablePtrType` currently requires the user to provide the `Differential` type. #### Automatic Fulfillment of `IDifferentiable` Requirements Assume the user has defined the following type: @@ -191,7 +211,7 @@ Note that this code does not provide any explicit implementation of the `IDiffer 1. A new type is generated that stores the `Differential` of all differentiable fields. This new type itself will conform to the `IDifferentiable` interface, and it will be used to satisfy the `Differential` associated type requirement. 2. Each differential field will be associated to its corresponding field in the newly synthesized `Differential` type. 3. The `zero` value of the differential type is made from the `zero` value of each field in the differential type. -4. The `dadd` and `dmul` methods simply perform `dadd` and `dmul` operations on each field. +4. The `dadd` method invokes the `dadd` operations for each field whose type conforms to `IDifferentiable`. 5. If the synthesized `Differential` type contains exactly the same fields as the original type, and the type of each field is the same as the original field type, then the original type itself will be used as the `Differential` type instead of creating a new type to satisfy the `Differential` associated type requirement. This means that all the synthesized `Differential` type use itself to meet its own `IDifferentiable` requirements. #### Manual Fulfillment of `IDifferentiable` Requirements @@ -235,15 +255,6 @@ struct MyRay : IDifferentiable result.d_dir = v1.d_dir + v2.d_dir; return result; } - - // Define the multiply operation of a primal value and a derivative value. - static MyRayDifferential dmul(MyRay p, MyRayDifferential d) - { - MyRayDifferential result; - result.d_origin = p.origin * d.d_origin; - result.d_dir = p.dir * d.d_dir; - return result; - } } ``` @@ -284,14 +295,6 @@ struct MyRayDifferential : IDifferentiable result.d_dir = v1.d_dir + v2.d_dir; return result; } - - static MyRayDifferential dmul(MyRayDifferential p, MyRayDifferential d) - { - MyRayDifferential result; - result.d_origin = p.d_origin * d.d_origin; - result.d_dir = p.d_dir * d.d_dir; - return result; - } } ``` In this specific case, the automatically generated `IDifferentiable` implementation will be exactly the same as the manually written code listed above. @@ -303,17 +306,19 @@ Functions in Slang can be marked as forward-differentiable or backward-different A forward derivative propagation function computes the derivative of the result value with regard to a specific set of input parameters. Given an original function, the signature of its forward propagation function is determined using the following rules: -- If the return type `R` is differentiable, the forward propagation function will return `DifferentialPair` that consists of both the computed original result value and the (partial) derivative of the result value. Otherwise, the return type is kept unmodified as `R`. -- If a parameter has type `T` that is differentiable, it will be translated into a `DifferentialPair` parameter in the derivative function, where the differential component of the `DifferentialPair` holds the initial derivatives of each parameter with regard to their upstream parameters. +- If the return type `R` implements `IDifferentiable` the forward propagation function will return a corresponding `DifferentialPair` that consists of both the computed original result value and the (partial) derivative of the result value. Otherwise, the return type is kept unmodified as `R`. +- If a parameter has type `T` that implements `IDifferentiable`, it will be translated into a `DifferentialPair` parameter in the derivative function, where the differential component of the `DifferentialPair` holds the initial derivatives of each parameter with regard to their upstream parameters. +- If a parameter has type `T` that implements `IDifferentiablePtrType`, it will be translated into a `DifferentialPtrPair` parameter where the differential component references the differential location or buffer. - All parameter directions are unchanged. For example, an `out` parameter in the original function will remain an `out` parameter in the derivative function. +- Differentiable methods cannot have a type implementing `IDifferentiablePtrType` as an `out` or `inout` parameter, or a return type. Types implementing `IDifferentiablePtrType` can only be used for input parameters to a differentiable method. Marking such a method as `[Differentiable]` will result in a compile-time diagnostic error. For example, given original function: ```csharp -R original(T0 p0, inout T1 p1, T2 p2); +R original(T0 p0, inout T1 p1, T2 p2, T3 p3); ``` -Where `R`, `T0`, and `T1` is differentiable and `T2` is non-differentiable, the forward derivative function will have the following signature: +Where `R`, `T0`, `T1 : IDifferentiable`, `T2` is non-differentiable, and `T3 : IDifferentiablePtrType`, the forward derivative function will have the following signature: ```csharp -DifferentialPair derivative(DifferentialPair p0, inout DifferentialPair p1, T2 p2); +DifferentialPair derivative(DifferentialPair p0, inout DifferentialPair p1, T2 p2, DifferentialPtrPair p3); ``` This forward propagation function takes the initial primal value of `p0` in `p0.p`, and the partial derivative of `p0` with regard to some upstream parameter in `p0.d`. It takes the initial primal and derivative values of `p1` and updates `p1` to hold the newly computed value and propagated derivative. Since `p2` is not differentiable, it remains unchanged. @@ -327,10 +332,11 @@ struct DifferentialPair : IDifferentiable property T.Differential d {get;} static Differential dzero(); static Differential dadd(Differential a, Differential b); - static Differential dmul(This a, Differential b); } ``` +For ptr-types, there is a corresponding built-in `DifferentialPtrPair` that does not have the `dzero` or `dadd` methods. + ### Automatic Implementation of Forward Derivative Functions A function can be made forward-differentiable with a `[ForwardDifferentiable]` attribute. This attribute will cause the compiler to automatically implement the forward propagation function. The syntax for using `[ForwardDifferentiable]` is: @@ -392,23 +398,25 @@ Given an original function `f`, the general rule for determining the signature o More specifically, the signature of its backward propagation function is determined using the following rules: - A backward propagation function always returns `void`. -- A differentiable `in` parameter of type `T` will become an `inout DifferentialPair` parameter, where the original value part of the differential pair contains the original value of the parameter to pass into the back-prop function. The original value will not be overwritten by the backward propagation function. The propagated derivative will be written to the derivative part of the differential pair after the backward propagation function returns. The initial derivative value of the pair is ignored as input. -- A differentiable `out` parameter of type `T` will become an `in T.Differential` parameter, carrying the partial derivative of some downstream term with regard to the return value. -- A differentiable `inout` parameter of type `T` will become an `inout DifferentialPair` parameter, where the original value of the argument, along with the downstream partial derivative with regard to the argument is passed as input to the backward propagation function as the original and derivative part of the pair. The propagated derivative with regard to this input parameter will be written back and replace the derivative part of the pair. The primal value part of the parameter will *not* be updated. +- A differentiable `in` parameter of type `T : IDifferentiable` will become an `inout DifferentialPair` parameter, where the original value part of the differential pair contains the original value of the parameter to pass into the back-prop function. The original value will not be overwritten by the backward propagation function. The propagated derivative will be written to the derivative part of the differential pair after the backward propagation function returns. The initial derivative value of the pair is ignored as input. +- A differentiable `out` parameter of type `T : IDifferentiable` will become an `in T.Differential` parameter, carrying the partial derivative of some downstream term with regard to the return value. +- A differentiable `inout` parameter of type `T : IDifferentiable` will become an `inout DifferentialPair` parameter, where the original value of the argument, along with the downstream partial derivative with regard to the argument is passed as input to the backward propagation function as the original and derivative part of the pair. The propagated derivative with regard to this input parameter will be written back and replace the derivative part of the pair. The primal value part of the parameter will *not* be updated. - A differentiable return value of type `R` will become an additional `in R.Differential` parameter at the end of the backward propagation function parameter list, carrying the result derivative of a downstream term with regard to the return value of the original function. - A non-differentiable return value of type `NDR` will be dropped. - A non-differentiable `in` parameter of type `ND` will remain unchanged in the backward propagation function. - A non-differentiable `out` parameter of type `ND` will be removed from the parameter list of the backward propagation function. - A non-differentiable `inout` parameter of type `ND` will become an `in ND` parameter. +- Types implemented `IDifferentiablePtrType` work the same was as the forward-mode case. They can only be used with `in` parameters, and are converted into `DifferentialPtrPair` types. Their directions are not affected. For example consider the following original function: ```csharp struct T : IDifferentiable {...} struct R : IDifferentiable {...} +struct P : IDifferentiablePtrType {...} struct ND {} // Non differentiable [Differentiable] -R original(T p0, out T p1, inout T p2, ND p3, out ND p4, inout ND p5); +R original(T p0, out T p1, inout T p2, ND p3, out ND p4, inout ND p5, P p6); ``` The signature of its backward propagation function is: ```csharp @@ -418,6 +426,7 @@ void back_prop( inout DifferentialPair p2, ND p3, ND p5, + DifferentialPtrPair

p6, R.Differential dResult); ``` Note that although `p2` is still `inout` in the backward propagation function, the backward propagation function will only write propagated derivative to `p2.d` and will not modify `p2.p`. @@ -447,6 +456,7 @@ void back_prop( inout DifferentialPair p2, ND p3, ND p5, + DifferentialPtrPair

p6, R.Differential dResult) { ... @@ -468,6 +478,7 @@ void back_prop( inout DifferentialPair p2, ND p3, ND p5, + DifferentialPtrPair

p6, R.Differential dResult) { ... diff --git a/external/CMakeLists.txt b/external/CMakeLists.txt index e72e416073..d482d59e8e 100644 --- a/external/CMakeLists.txt +++ b/external/CMakeLists.txt @@ -70,6 +70,9 @@ target_include_directories(imgui INTERFACE "${CMAKE_CURRENT_LIST_DIR}/imgui") set(SLANG_RHI_SLANG_INCLUDE_DIR ${CMAKE_SOURCE_DIR}/include) set(SLANG_RHI_SLANG_BINARY_DIR ${CMAKE_BINARY_DIR}) set(SLANG_RHI_BUILD_TESTS OFF) +if(SLANG_ENABLE_DX_ON_VK) + set(SLANG_RHI_HAS_D3D12 ON) +endif() add_subdirectory(slang-rhi) # Tidy things up: diff --git a/external/slang-rhi b/external/slang-rhi index 54317882dd..b3b0e7384e 160000 --- a/external/slang-rhi +++ b/external/slang-rhi @@ -1 +1 @@ -Subproject commit 54317882dd20c131f8c9bcdbe844242763d58039 +Subproject commit b3b0e7384e8bd78280b9dcf97422f65ac877773e diff --git a/include/slang.h b/include/slang.h index 3024aa8844..3bcdcbba8c 100644 --- a/include/slang.h +++ b/include/slang.h @@ -852,6 +852,7 @@ extern "C" EmitIr, // bool ReportDownstreamTime, // bool ReportPerfBenchmark, // bool + ReportCheckpointIntermediates, // bool SkipSPIRVValidation, // bool SourceEmbedStyle, SourceEmbedName, diff --git a/source/slang-record-replay/util/emum-to-string.h b/source/slang-record-replay/util/emum-to-string.h index 7226edc04c..8c140cf3d6 100644 --- a/source/slang-record-replay/util/emum-to-string.h +++ b/source/slang-record-replay/util/emum-to-string.h @@ -149,6 +149,7 @@ namespace SlangRecord CASE(EmitIr); CASE(ReportDownstreamTime); CASE(ReportPerfBenchmark); + CASE(ReportCheckpointIntermediates); CASE(SkipSPIRVValidation); CASE(SourceEmbedStyle); CASE(SourceEmbedName); diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 03dda0fe5a..476279ab8d 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -285,6 +285,13 @@ interface IDifferentiable static Differential dmul(T, Differential); }; +__magic_type(DifferentiablePtrType) +interface IDifferentiablePtrType +{ + __builtin_requirement($( (int)BuiltinRequirementKind::DifferentialPtrType) ) + associatedtype Differential : IDifferentiablePtrType; +}; + /// Pair type that serves to wrap the primal and /// differential types of an arbitrary type T. @@ -357,6 +364,36 @@ struct DifferentialPair : IDifferentiable } }; +__generic +__magic_type(DifferentialPtrPairType) +__intrinsic_type($(kIROp_DifferentialPtrPairType)) +struct DifferentialPtrPair : IDifferentiablePtrType +{ + typedef DifferentialPtrPair Differential; + typedef T.Differential DifferentialElementType; + + __intrinsic_op($(kIROp_MakeDifferentialPtrPair)) + __init(T _primal, T.Differential _differential); + + property p : T + { + __intrinsic_op($(kIROp_DifferentialPtrPairGetPrimal)) + get; + } + + property v : T + { + __intrinsic_op($(kIROp_DifferentialPtrPairGetPrimal)) + get; + } + + property d : T.Differential + { + __intrinsic_op($(kIROp_DifferentialPtrPairGetDifferential)) + get; + } +}; + /// A type that uses a floating-point representation [sealed] @@ -396,6 +433,15 @@ interface IArray } } +interface IRWArray : IArray +{ + __subscript(int index)->T + { + get; + set; + } +} + // The "comma operator" is effectively just a generic function that returns its second // argument. The left-to-right evaluation order guaranteed by Slang then ensures that // `left` is evaluated before `right`. @@ -1148,21 +1194,15 @@ extension int16_t : IRangedValue __generic __magic_type(ArrayExpressionType) -struct Array : IArray +struct Array : IRWArray { __intrinsic_op($(kIROp_GetArrayLength)) int getCount(); - - __subscript(int index) -> T - { - __intrinsic_op($(kIROp_GetElement)) - get; - } } /// An `N` component vector with elements of type `T`. __generic __magic_type(VectorExpressionType) -struct vector : IArray +struct vector : IRWArray { /// The element type of the vector typedef T Element; @@ -1182,8 +1222,6 @@ struct vector : IArray [ForceInline] int getCount() { return N; } - - __subscript(int index) -> T { __intrinsic_op($(kIROp_GetElement)) get; } } const int kRowMajorMatrixLayout = $(SLANG_MATRIX_LAYOUT_ROW_MAJOR); @@ -1192,7 +1230,7 @@ const int kColumnMajorMatrixLayout = $(SLANG_MATRIX_LAYOUT_COLUMN_MAJOR); /// A matrix with `R` rows and `C` columns, with elements of type `T`. __generic __magic_type(MatrixExpressionType) -struct matrix : IArray> +struct matrix : IRWArray> { __intrinsic_op($(kIROp_MakeMatrixFromScalar)) __implicit_conversion($(kConversionCost_ScalarToMatrix)) @@ -1207,8 +1245,6 @@ struct matrix : IArray> [ForceInline] int getCount() { return R; } - - __subscript(int index) -> vector { __intrinsic_op($(kIROp_GetElement)) get; } } __intrinsic_op($(kIROp_Eql)) diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 56a249be7f..49c9dea299 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -5775,7 +5775,7 @@ bool all(matrix x) // Barrier for writes to all memory spaces (HLSL SM 5.0) __glsl_extension(GL_KHR_memory_scope_semantics) -[require(cuda_glsl_hlsl_metal_spirv, memorybarrier)] +[require(cuda_glsl_hlsl_metal_spirv_wgsl, memorybarrier)] void AllMemoryBarrier() { __target_switch @@ -5788,12 +5788,13 @@ void AllMemoryBarrier() { OpMemoryBarrier Device AcquireRelease|UniformMemory|WorkgroupMemory|ImageMemory; }; + case wgsl: __intrinsic_asm "storageBarrier(); textureBarrier(); workgroupBarrier();"; } } // Thread-group sync and barrier for writes to all memory spaces (HLSL SM 5.0) __glsl_extension(GL_KHR_memory_scope_semantics) -[require(cuda_glsl_hlsl_metal_spirv, memorybarrier)] +[require(cuda_glsl_hlsl_metal_spirv_wgsl, memorybarrier)] void AllMemoryBarrierWithGroupSync() { __target_switch @@ -5806,6 +5807,7 @@ void AllMemoryBarrierWithGroupSync() { OpControlBarrier Workgroup Device AcquireRelease|UniformMemory|WorkgroupMemory|ImageMemory; }; + case wgsl: __intrinsic_asm "storageBarrier(); textureBarrier(); workgroupBarrier();"; } } @@ -7540,7 +7542,7 @@ T determinant(matrix m) // Barrier for device memory __glsl_extension(GL_KHR_memory_scope_semantics) -[require(cuda_glsl_hlsl_metal_spirv, memorybarrier)] +[require(cuda_glsl_hlsl_metal_spirv_wgsl, memorybarrier)] void DeviceMemoryBarrier() { __target_switch @@ -7553,11 +7555,12 @@ void DeviceMemoryBarrier() { OpMemoryBarrier Device AcquireRelease|UniformMemory|ImageMemory; }; + case wgsl: __intrinsic_asm "storageBarrier(); textureBarrier(); workgroupBarrier();"; } } __glsl_extension(GL_KHR_memory_scope_semantics) -[require(cuda_glsl_hlsl_metal_spirv, memorybarrier)] +[require(cuda_glsl_hlsl_metal_spirv_wgsl, memorybarrier)] void DeviceMemoryBarrierWithGroupSync() { __target_switch @@ -7570,6 +7573,7 @@ void DeviceMemoryBarrierWithGroupSync() { OpControlBarrier Workgroup Device AcquireRelease|UniformMemory|ImageMemory; }; + case wgsl: __intrinsic_asm "storageBarrier(); textureBarrier(); workgroupBarrier();"; } } @@ -8932,7 +8936,7 @@ float2 GetRenderTargetSamplePosition(int Index) // Group memory barrier __glsl_extension(GL_KHR_memory_scope_semantics) -[require(cuda_glsl_hlsl_metal_spirv, memorybarrier)] +[require(cuda_glsl_hlsl_metal_spirv_wgsl, memorybarrier)] void GroupMemoryBarrier() { __target_switch @@ -8946,6 +8950,7 @@ void GroupMemoryBarrier() { OpMemoryBarrier Workgroup AcquireRelease|WorkgroupMemory }; + case wgsl: __intrinsic_asm "workgroupBarrier"; } } @@ -8967,7 +8972,7 @@ void __subgroupBarrier() } __glsl_extension(GL_KHR_memory_scope_semantics) -[require(cuda_glsl_hlsl_metal_spirv, memorybarrier)] +[require(cuda_glsl_hlsl_metal_spirv_wgsl, memorybarrier)] void GroupMemoryBarrierWithGroupSync() { __target_switch @@ -8981,6 +8986,7 @@ void GroupMemoryBarrierWithGroupSync() { OpControlBarrier Workgroup Workgroup AcquireRelease|WorkgroupMemory }; + case wgsl: __intrinsic_asm "workgroupBarrier"; } } @@ -20617,3 +20623,18 @@ uint64_t clockARB() }; } } + +extension StructuredBuffer : IArray +{ + int getCount() { uint count; uint stride; this.GetDimensions(count, stride); return count; } +} + +extension RWStructuredBuffer : IRWArray +{ + int getCount() { uint count; uint stride; this.GetDimensions(count, stride); return count; } +} + +extension RasterizerOrderedStructuredBuffer : IRWArray +{ + int getCount() { uint count; uint stride; this.GetDimensions(count, stride); return count; } +} \ No newline at end of file diff --git a/source/slang/slang-ast-base.h b/source/slang/slang-ast-base.h index 6f3789c7e4..9339f5dee6 100644 --- a/source/slang/slang-ast-base.h +++ b/source/slang/slang-ast-base.h @@ -757,6 +757,8 @@ class Expr : public SyntaxNode QualType type; + bool checked = false; + void accept(IExprVisitor* visitor, void* extra); }; diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp index 9879a41872..b66af34fa4 100644 --- a/source/slang/slang-ast-builder.cpp +++ b/source/slang/slang-ast-builder.cpp @@ -408,18 +408,32 @@ MatrixExpressionType* ASTBuilder::getMatrixType(Type* elementType, IntVal* rowCo DifferentialPairType* ASTBuilder::getDifferentialPairType( Type* valueType, - Witness* primalIsDifferentialWitness) + Witness* diffTypeWitness) { - Val* args[] = { valueType, primalIsDifferentialWitness }; + Val* args[] = { valueType, diffTypeWitness }; return as(getSpecializedBuiltinType(makeArrayView(args), "DifferentialPairType")); } +DifferentialPtrPairType* ASTBuilder::getDifferentialPtrPairType( + Type* valueType, + Witness* diffRefTypeWitness) +{ + Val* args[] = { valueType, diffRefTypeWitness }; + return as(getSpecializedBuiltinType(makeArrayView(args), "DifferentialPtrPairType")); +} + DeclRef ASTBuilder::getDifferentiableInterfaceDecl() { DeclRef declRef = DeclRef(getBuiltinDeclRef("DifferentiableType", nullptr)); return declRef; } +DeclRef ASTBuilder::getDifferentiableRefInterfaceDecl() +{ + DeclRef declRef = DeclRef(getBuiltinDeclRef("DifferentiablePtrType", nullptr)); + return declRef; +} + bool ASTBuilder::isDifferentiableInterfaceAvailable() { return (m_sharedASTBuilder->tryFindMagicDecl("DifferentiableType") != nullptr); @@ -459,6 +473,11 @@ Type* ASTBuilder::getDifferentiableInterfaceType() return DeclRefType::create(this, getDifferentiableInterfaceDecl()); } +Type* ASTBuilder::getDifferentiableRefInterfaceType() +{ + return DeclRefType::create(this, getDifferentiableRefInterfaceDecl()); +} + DeclRef ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, Val* genericArg) { auto decl = m_sharedASTBuilder->findMagicDecl(builtinMagicTypeName); diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h index b9b1f7ab85..08951513dc 100644 --- a/source/slang/slang-ast-builder.h +++ b/source/slang/slang-ast-builder.h @@ -489,10 +489,17 @@ class ASTBuilder : public RefObject DifferentialPairType* getDifferentialPairType( Type* valueType, - Witness* primalIsDifferentialWitness); + Witness* diffTypeWitness); + + DifferentialPtrPairType* getDifferentialPtrPairType( + Type* valueType, + Witness* diffRefTypeWitness); DeclRef getDifferentiableInterfaceDecl(); + DeclRef getDifferentiableRefInterfaceDecl(); + Type* getDifferentiableInterfaceType(); + Type* getDifferentiableRefInterfaceType(); bool isDifferentiableInterfaceAvailable(); diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index 8916db7b4a..665c4cb2c6 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -9,7 +9,7 @@ #include "slang-profile.h" #include "slang-type-system-shared.h" -#include "slang.h" +#include "../../include/slang.h" #include "../core/slang-semantic-version.h" @@ -1607,6 +1607,7 @@ namespace Slang DefaultInitializableConstructor, ///< The `IDefaultInitializable.__init()` method DifferentialType, ///< The `IDifferentiable.Differential` associated type requirement + DifferentialPtrType, ///< The `IDifferentiable.DifferentialPtr` associated type requirement DZeroFunc, ///< The `IDifferentiable.dzero` function requirement DAddFunc, ///< The `IDifferentiable.dadd` function requirement DMulFunc, ///< The `IDifferentiable.dmul` function requirement diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h index 37a2e34fe1..081d136ef1 100644 --- a/source/slang/slang-ast-type.h +++ b/source/slang/slang-ast-type.h @@ -462,11 +462,22 @@ class DifferentialPairType : public ArithmeticExpressionType Type* getPrimalType(); }; +class DifferentialPtrPairType : public ArithmeticExpressionType +{ + SLANG_AST_CLASS(DifferentialPtrPairType) + Type* getPrimalRefType(); +}; + class DifferentiableType : public BuiltinType { SLANG_AST_CLASS(DifferentiableType) }; +class DifferentiablePtrType : public BuiltinType +{ + SLANG_AST_CLASS(DifferentiablePtrType) +}; + class DefaultInitializableType : public BuiltinType { SLANG_AST_CLASS(DefaultInitializableType); diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp index e8020aa04c..68a55e5675 100644 --- a/source/slang/slang-ast-val.cpp +++ b/source/slang/slang-ast-val.cpp @@ -461,7 +461,7 @@ Val* DeclaredSubtypeWitness::_resolveImplOverride() ConversionCost DeclaredSubtypeWitness::_getOverloadResolutionCostOverride() { - return kConversionCost_GenericParamUpcast; + return kConversionCost_None; } Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int * ioDiff) @@ -611,7 +611,7 @@ Val* TransitiveSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, S ConversionCost TransitiveSubtypeWitness::_getOverloadResolutionCostOverride() { - return getSubToMid()->getOverloadResolutionCost() + getMidToSup()->getOverloadResolutionCost(); + return getSubToMid()->getOverloadResolutionCost() + getMidToSup()->getOverloadResolutionCost() + kConversionCost_GenericParamUpcast; } void TransitiveSubtypeWitness::_toTextOverride(StringBuilder& out) diff --git a/source/slang/slang-capabilities.capdef b/source/slang/slang-capabilities.capdef index 96f5996a0c..42a7db2133 100644 --- a/source/slang/slang-capabilities.capdef +++ b/source/slang/slang-capabilities.capdef @@ -322,6 +322,10 @@ alias cuda_glsl_hlsl_spirv = cuda | glsl | hlsl | spirv; /// [Compound] alias cuda_glsl_hlsl_metal_spirv = cuda | glsl | hlsl | metal | spirv; +/// CUDA, GLSL, HLSL, Metal, SPIRV and WGSL code-gen targets +/// [Compound] +alias cuda_glsl_hlsl_metal_spirv_wgsl = cuda | glsl | hlsl | metal | spirv | wgsl; + /// CUDA, GLSL, and SPIRV code-gen targets /// [Compound] alias cuda_glsl_spirv = cuda | glsl | spirv; diff --git a/source/slang/slang-check-conformance.cpp b/source/slang/slang-check-conformance.cpp index ffa0379962..9d9047e417 100644 --- a/source/slang/slang-check-conformance.cpp +++ b/source/slang/slang-check-conformance.cpp @@ -274,9 +274,14 @@ namespace Slang return isInterfaceType(type); } - bool SemanticsVisitor::isTypeDifferentiable(Type* type) + SubtypeWitness* SemanticsVisitor::isTypeDifferentiable(Type* type) { - return isSubtype(type, m_astBuilder->getDiffInterfaceType(), IsSubTypeOptions::None); + if (auto valueWitness = isSubtype(type, m_astBuilder->getDiffInterfaceType(), IsSubTypeOptions::None)) + return valueWitness; + else if (auto ptrWitness = isSubtype(type, m_astBuilder->getDifferentiableRefInterfaceType(), IsSubTypeOptions::None)) + return ptrWitness; + + return nullptr; } bool SemanticsVisitor::doesTypeHaveTag(Type* type, TypeTag tag) diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 5acd302158..32fe693b1f 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -4081,7 +4081,7 @@ namespace Slang void SemanticsVisitor::addModifiersToSynthesizedDecl( ConformanceCheckingContext* context, DeclRef requiredMemberDeclRef, - FunctionDeclBase* synthesized, + CallableDecl* synthesized, ThisExpr*& synThis) { // Required interface methods can be `static` or non-`static`, @@ -4234,13 +4234,13 @@ namespace Slang } } - FunctionDeclBase* SemanticsVisitor::synthesizeMethodSignatureForRequirementWitness( + CallableDecl* SemanticsVisitor::synthesizeMethodSignatureForRequirementWitness( ConformanceCheckingContext* context, - DeclRef requiredMemberDeclRef, + DeclRef requiredMemberDeclRef, List& synArgs, ThisExpr*& synThis) { - FunctionDeclBase* synFuncDecl = as(m_astBuilder->createByNodeType(requiredMemberDeclRef.getDecl()->astNodeType)); + CallableDecl* synFuncDecl = as(m_astBuilder->createByNodeType(requiredMemberDeclRef.getDecl()->astNodeType)); SLANG_ASSERT(synFuncDecl); synFuncDecl->ownedScope = m_astBuilder->create(); synFuncDecl->ownedScope->containerDecl = synFuncDecl; @@ -4381,8 +4381,8 @@ namespace Slang ThisExpr* synThis = nullptr; List synArgs; - auto synFuncDecl = synthesizeMethodSignatureForRequirementWitness( - context, requiredMemberDeclRef, synArgs, synThis); + auto synFuncDecl = as(synthesizeMethodSignatureForRequirementWitness( + context, requiredMemberDeclRef, synArgs, synThis)); auto resultType = synFuncDecl->returnType.type; @@ -4710,6 +4710,7 @@ namespace Slang // Synthesize the property name with a prefix to avoid name clashing. synPropertyDecl->nameAndLoc = requiredMemberDeclRef.getDecl()->nameAndLoc; synPropertyDecl->nameAndLoc.name = getName(String("$syn_property_") + getText(requiredMemberDeclRef.getName())); + synPropertyDecl->parentDecl = context->parentDecl; // The type of our synthesized property will be the expected type @@ -4727,43 +4728,176 @@ namespace Slang auto propertyType = getType(m_astBuilder, requiredMemberDeclRef); synPropertyDecl->type.type = propertyType; - // Our synthesized property will have an accessor declaration for - // each accessor of the requirement. + + // We start by constructing an expression that represents + // `this.name` where `name` is the name of the required + // member. The caller already passed in a `lookupResult` + // that should indicate all the declarations found by + // looking up `name`, so we can start with that. + // + // TODO: Note that there are many cases for member lookup + // that are not handled just by using `createLookupResultExpr` + // because they are currently being special-cased (the most + // notable cases are swizzles, as well as lookup of static + // members in types). + // + // The main result here is that we will not be able to synthesize + // a requirement for a built-in scalar/vector/matrix type to + // a property with a name like `.xy` based on the presence of + // swizles, even though it seems like such a thing should Just Work. + // + // If this is important we could "fix" it by allowing this + // code to dispatch to the special-case logic used when doing + // semantic checking for member expressions. + // + // Note: an alternative would be to change the stdlib declarations + // of vectors/matrices so that all the swizzles are defined as + // `property` declarations. There are some C++ math libraries (like GLM) + // that implement swizzle syntax by a similar approach of statically + // enumerating all possible swizzles. The down-side to such an + // approach is that the combinatorial space of swizzles is quite + // large (especially for matrices) so that supporting them via + // general-purpose language features is unlikely to be as efficient + // as special-case logic. + // + // We are going to synthesize an expression and then perform + // semantic checking on it, but if there are semantic errors + // we do *not* want to report them to the user as such, and + // instead want the result to be a failure to synthesize + // a valid witness. + // + // We will buffer up diagnostics into a temporary sink and + // then throw them away when we are done. + // + // TODO: This behavior might be something we want to make + // into a more fundamental capability of `DiagnosticSink` and/or + // `SemanticsVisitor` so that code can push/pop the emission + // of diagnostics more easily. + // + DiagnosticSink tempSink(getSourceManager(), nullptr); + SemanticsVisitor subVisitor(withSink(&tempSink)); + + // We need to create a `this` expression to be used in the body + // of the synthesized accessor. // - // TODO: If we ever start to support synthesis for subscript requirements, - // then we probably want to factor the accessor-related logic into - // a subroutine so that it can be shared between properties and subscripts. + // TODO: if we ever allow `static` properties or subscripts, + // we will need to handle that case here, by *not* creating + // a `this` expression. // - Dictionary, AccessorDecl*> mapRequiredAccessorToSynAccessor; - for( auto requiredAccessorDeclRef : getMembersOfType(m_astBuilder, requiredMemberDeclRef) ) + ThisExpr* synThis = m_astBuilder->create(); + synThis->scope = synPropertyDecl->ownedScope; + + // The type of `this` in our accessor will be the type for + // which we are synthesizing a conformance. + // + synThis->type.type = context->conformingType; + synThis->type.isLeftValue = true; + auto synMemberRef = subVisitor.createLookupResultExpr( + requiredMemberDeclRef.getName(), + lookupResult, + synThis, + requiredMemberDeclRef.getLoc(), + nullptr); + synMemberRef->loc = requiredMemberDeclRef.getLoc(); + + bool canSynAccessors = synthesizeAccessorRequirements( + context, + requiredMemberDeclRef, + propertyType, + synMemberRef, + synPropertyDecl, + witnessTable); + if (!canSynAccessors) + return false; + + + + + // The visibility of synthesized decl should be the min of the parent decl and the requirement. + if (requiredMemberDeclRef.getDecl()->findModifier()) { + auto requirementVisibility = getDeclVisibility(requiredMemberDeclRef.getDecl()); + auto thisVisibility = getDeclVisibility(context->parentDecl); + auto visibility = Math::Min(thisVisibility, requirementVisibility); + addVisibilityModifier(m_astBuilder, synPropertyDecl, visibility); + } + return true; + } + + bool SemanticsVisitor::trySynthesizeWrapperTypePropertyRequirementWitness( + ConformanceCheckingContext* context, + DeclRef requiredMemberDeclRef, + RefPtr witnessTable) + { + // We are synthesizing a property requirement for a wrapper type: + // + // interface IFoo { property value : int { get; set; } } + // struct Foo : IFoo = FooImpl; + // + // We need to synthesize Foo to: + // + // struct Foo : IFoo + // { + // FooImpl inner; + // property value : int { get { return inner.value; } + // set { inner.value = newValue; } + // } + // } + // + // To do so, we need to grab the witness table of FooImpl:IFoo, and create + // wrapper property in Foo that forwards the accessors to the inner object. + // + // We get started by constructing a synthesized `PropertyDecl`. + // + auto synPropertyDecl = m_astBuilder->create(); + synPropertyDecl->parentDecl = context->parentDecl; + + // Synthesize the property name with a prefix to avoid name clashing. + // + synPropertyDecl->nameAndLoc = requiredMemberDeclRef.getDecl()->nameAndLoc; + synPropertyDecl->nameAndLoc.name = getName(String("$syn_property_") + getText(requiredMemberDeclRef.getName())); + + // Find the witness that FooImpl : IFoo. + auto aggTypeDecl = as(context->parentDecl); + auto innerType = aggTypeDecl->wrappedType.type; + DeclRef innerProperty; + auto innerWitness = tryGetSubtypeWitness(innerType, witnessTable->baseType); + if (!innerWitness) + return false; + + for (auto requiredAccessorDeclRef : getMembersOfType(m_astBuilder, requiredMemberDeclRef)) + { + auto innerEntry = tryLookUpRequirementWitness(m_astBuilder, innerWitness, requiredAccessorDeclRef.getDecl()); + if (innerEntry.getFlavor() != RequirementWitness::Flavor::declRef) + return false; + auto innerAccessorDeclRef = as(innerEntry.getDeclRef()); + if (!innerAccessorDeclRef) + return false; + // The synthesized accessor will be an AST node of the same class as // the required accessor. // - auto synAccessorDecl = (AccessorDecl*) m_astBuilder->createByNodeType(requiredAccessorDeclRef.getDecl()->astNodeType); + auto synAccessorDecl = (AccessorDecl*)m_astBuilder->createByNodeType(requiredAccessorDeclRef.getDecl()->astNodeType); synAccessorDecl->ownedScope = m_astBuilder->create(); synAccessorDecl->ownedScope->containerDecl = synAccessorDecl; synAccessorDecl->ownedScope->parent = getScope(context->parentDecl); - // Whatever the required accessor returns, that is what our synthesized accessor will return. + // The return type should be the same as the inner object's accessor return type. // - synAccessorDecl->returnType.type = getResultType(m_astBuilder, requiredAccessorDeclRef); + synAccessorDecl->returnType.type = getResultType(m_astBuilder, innerAccessorDeclRef); - // Similarly, our synthesized accessor will have parameters matching those of the requirement. - // - // Note: in practice we expect that only `set` accessors will have any parameters, - // and they will only have a single parameter. + // Similarly, our synthesized accessor will have parameters matching those of the inner accessor. // List synArgs; - for( auto requiredParamDeclRef : getParameters(m_astBuilder, requiredAccessorDeclRef) ) + for (auto innerParamDeclRef : getParameters(m_astBuilder, innerAccessorDeclRef)) { - auto paramType = getType(m_astBuilder, requiredParamDeclRef); + auto paramType = getType(m_astBuilder, innerParamDeclRef); // The synthesized parameter will ahve the same name and // type as the parameter of the requirement. // auto synParamDecl = m_astBuilder->create(); - synParamDecl->nameAndLoc = requiredParamDeclRef.getDecl()->nameAndLoc; + synParamDecl->nameAndLoc = innerParamDeclRef.getDecl()->nameAndLoc; synParamDecl->type.type = paramType; // We need to add the parameter as a child declaration of @@ -4781,163 +4915,35 @@ namespace Slang synArgs.add(synArg); } - // We need to create a `this` expression to be used in the body - // of the synthesized accessor. - // - // TODO: if we ever allow `static` properties or subscripts, - // we will need to handle that case here, by *not* creating - // a `this` expression. - // - ThisExpr* synThis = m_astBuilder->create(); - synThis->scope = synAccessorDecl->ownedScope; - - // The type of `this` in our accessor will be the type for - // which we are synthesizing a conformance. - // - synThis->type.type = context->conformingType; - - // A `get` accessor should default to an immutable `this`, - // while other accessors default to mutable `this`. - // - // TODO: If we ever add other kinds of accessors, we will - // need to check that this assumption stays valid. - // - synThis->type.isLeftValue = true; - if(as(requiredAccessorDeclRef)) - synThis->type.isLeftValue = false; - - // If the accessor requirement is `[nonmutating]` then our - // synthesized accessor should be too, and also the `this` - // parameter should *not* be an l-value. - // - if( requiredAccessorDeclRef.getDecl()->hasModifier() ) - { - synThis->type.isLeftValue = false; - - auto synAttr = m_astBuilder->create(); - synAccessorDecl->modifiers.first = synAttr; - } - // - // Note: we don't currently support `[mutating] get` accessors, - // but the desired behavior in that case is clear, so we go - // ahead and future-proof this code a bit: - // - else if( requiredAccessorDeclRef.getDecl()->hasModifier() ) - { - synThis->type.isLeftValue = true; - - auto synAttr = m_astBuilder->create(); - synAccessorDecl->modifiers.first = synAttr; - } - else if (requiredAccessorDeclRef.getDecl()->hasModifier()) - { - synThis->type.isLeftValue = true; - - auto synAttr = m_astBuilder->create(); - synAccessorDecl->modifiers.first = synAttr; - } - else if (requiredAccessorDeclRef.getDecl()->hasModifier()) - { - auto synAttr = m_astBuilder->create(); - synAccessorDecl->modifiers.first = synAttr; - } - // We are going to synthesize an expression and then perform - // semantic checking on it, but if there are semantic errors - // we do *not* want to report them to the user as such, and - // instead want the result to be a failure to synthesize - // a valid witness. - // - // We will buffer up diagnostics into a temporary sink and - // then throw them away when we are done. - // - // TODO: This behavior might be something we want to make - // into a more fundamental capability of `DiagnosticSink` and/or - // `SemanticsVisitor` so that code can push/pop the emission - // of diagnostics more easily. - // - DiagnosticSink tempSink(getSourceManager(), nullptr); - SemanticsVisitor subVisitor(withSink(&tempSink)); - - // We start by constructing an expression that represents - // `this.name` where `name` is the name of the required - // member. The caller already passed in a `lookupResult` - // that should indicate all the declarations found by - // looking up `name`, so we can start with that. - // - // TODO: Note that there are many cases for member lookup - // that are not handled just by using `createLookupResultExpr` - // because they are currently being special-cased (the most - // notable cases are swizzles, as well as lookup of static - // members in types). - // - // The main result here is that we will not be able to synthesize - // a requirement for a built-in scalar/vector/matrix type to - // a property with a name like `.xy` based on the presence of - // swizles, even though it seems like such a thing should Just Work. - // - // If this is important we could "fix" it by allowing this - // code to dispatch to the special-case logic used when doing - // semantic checking for member expressions. - // - // Note: an alternative would be to change the stdlib declarations - // of vectors/matrices so that all the swizzles are defined as - // `property` declarations. There are some C++ math libraries (like GLM) - // that implement swizzle syntax by a similar approach of statically - // enumerating all possible swizzles. The down-side to such an - // approach is that the combinatorial space of swizzles is quite - // large (especially for matrices) so that supporting them via - // general-purpose language features is unlikely to be as efficient - // as special-case logic. - // - auto synMemberRef = subVisitor.createLookupResultExpr( - requiredMemberDeclRef.getName(), - lookupResult, - synThis, - requiredMemberDeclRef.getLoc(), - nullptr); - synMemberRef->loc = requiredMemberDeclRef.getLoc(); - + // Now synthesize the body of the property accessor. // The body of the accessor will depend on the class of the accessor // we are synthesizing (e.g., `get` vs. `set`). // Stmt* synBodyStmt = nullptr; - if( as(requiredAccessorDeclRef) ) + auto propertyRef = m_astBuilder->create(); + propertyRef->scope = synAccessorDecl->ownedScope; + auto base = m_astBuilder->create(); + base->scope = propertyRef->scope; + base->name = getName("inner"); + propertyRef->baseExpression = base; + innerProperty = innerAccessorDeclRef.getParent(); + propertyRef->name = getParentDecl(innerAccessorDeclRef.getDecl())->getName(); + auto checkedPropertyRefExpr = CheckExpr(propertyRef); + + if (as(requiredAccessorDeclRef)) { - // A `get` accessor will simply perform: - // - // return this.name; - // - // which involves coercing the member access `this.name` to - // the expected type of the property. - // - auto coercedMemberRef = subVisitor.coerce(CoercionSite::Return, propertyType, synMemberRef); auto synReturn = m_astBuilder->create(); - synReturn->expression = coercedMemberRef; + synReturn->expression = checkedPropertyRefExpr; synBodyStmt = synReturn; } - else if( as(requiredAccessorDeclRef) ) + else if (as(requiredAccessorDeclRef)) { - // We expect all `set` accessors to have a single argument, - // but we will defensively bail out if that is somehow - // not the case. - // - SLANG_ASSERT(synArgs.getCount() == 1); - if(synArgs.getCount() != 1) - return false; - - // A `set` accessor will simply perform: - // - // this.name = newValue; - // - // which involves creating and checking an assignment - // expression. - auto synAssign = m_astBuilder->create(); - synAssign->left = synMemberRef; + synAssign->left = checkedPropertyRefExpr; synAssign->right = synArgs[0]; - auto synCheckedAssign = subVisitor.checkAssignWithCheckedOperands(synAssign); + auto synCheckedAssign = checkAssignWithCheckedOperands(synAssign); auto synExprStmt = m_astBuilder->create(); synExprStmt->expression = synCheckedAssign; @@ -4953,93 +4959,83 @@ namespace Slang return false; } - // We bail out if we ran into any errors (meaning that the synthesized - // accessor is not usable). - // - // TODO: If there were *warnings* emitted to the sink, it would probably - // be good to show those warnings to the user, since they might indicate - // real issues. E.g., with the current logic a `float` field could - // satisfying an `int` property requirement, but the user would probably - // want to be warned when they do such a thing. - // - if(tempSink.getErrorCount() != 0) - return false; - + addModifier(synAccessorDecl, m_astBuilder->create()); synAccessorDecl->body = synBodyStmt; synAccessorDecl->parentDecl = synPropertyDecl; synPropertyDecl->members.add(synAccessorDecl); - // If synthesis of an accessor worked, then we will record it into - // a local dictionary. We do *not* install the accessor into the - // witness table yet, because it is possible that synthesis will - // succeed for some accessors but not others, and we don't want - // to leave the witness table in a state where a requirement is - // "partially satisfied." + // Register the synthesized accessor. // - mapRequiredAccessorToSynAccessor.add(requiredAccessorDeclRef, synAccessorDecl); + witnessTable->add(requiredAccessorDeclRef.getDecl(), RequirementWitness(makeDeclRef(synAccessorDecl))); } - synPropertyDecl->parentDecl = context->parentDecl; + // The type of our synthesized property will be the same as the inner property. + // + auto propertyType = getType(m_astBuilder, as(innerProperty)); + synPropertyDecl->type.type = propertyType; - // The visibility of synthesized decl should be the min of the parent decl and the requirement. - if (requiredMemberDeclRef.getDecl()->findModifier()) + // The visibility of synthesized decl should be the same as the inner requirement + if (innerProperty.getDecl()->findModifier()) { - auto requirementVisibility = getDeclVisibility(requiredMemberDeclRef.getDecl()); - auto thisVisibility = getDeclVisibility(context->parentDecl); - auto visibility = Math::Min(thisVisibility, requirementVisibility); - addVisibilityModifier(m_astBuilder, synPropertyDecl, visibility); + auto vis = getDeclVisibility(innerProperty.getDecl()); + addVisibilityModifier(m_astBuilder, synPropertyDecl, vis); } - // Once our synthesized declaration is complete, we need - // to install it as the witness that satifies the given - // requirement. - // - // Subsequent code generation should not be able to tell the - // difference between our synthetic property and a hand-written - // one with the same behavior. - // - for(auto& [key, value] : mapRequiredAccessorToSynAccessor) - { - witnessTable->add(key.getDecl(), RequirementWitness(makeDeclRef(value))); - } + context->parentDecl->addMember(synPropertyDecl); witnessTable->add(requiredMemberDeclRef.getDecl(), RequirementWitness(makeDeclRef(synPropertyDecl))); return true; } - bool SemanticsVisitor::trySynthesizeWrapperTypePropertyRequirementWitness( + bool SemanticsVisitor::trySynthesizeAssociatedTypeRequirementWitness( ConformanceCheckingContext* context, - DeclRef requiredMemberDeclRef, + LookupResult const& inLookupResult, + DeclRef requiredMemberDeclRef, RefPtr witnessTable) { - // We are synthesizing a property requirement for a wrapper type: - // - // interface IFoo { property value : int { get; set; } } - // struct Foo : IFoo = FooImpl; - // - // We need to synthesize Foo to: - // - // struct Foo : IFoo - // { - // FooImpl inner; - // property value : int { get { return inner.value; } - // set { inner.value = newValue; } - // } - // } - // - // To do so, we need to grab the witness table of FooImpl:IFoo, and create - // wrapper property in Foo that forwards the accessors to the inner object. - // - // We get started by constructing a synthesized `PropertyDecl`. - // - auto synPropertyDecl = m_astBuilder->create(); - synPropertyDecl->parentDecl = context->parentDecl; + SLANG_UNUSED(inLookupResult); - // Synthesize the property name with a prefix to avoid name clashing. - // - synPropertyDecl->nameAndLoc = requiredMemberDeclRef.getDecl()->nameAndLoc; - synPropertyDecl->nameAndLoc.name = getName(String("$syn_property_") + getText(requiredMemberDeclRef.getName())); + // The only case we can synthesize for now is when the conformant type + // is a wrapper type. + if (!isWrapperTypeDecl(context->parentDecl)) + return false; + auto aggTypeDecl = as(context->parentDecl); + auto lookupResult = lookUpMember( + m_astBuilder, + this, + requiredMemberDeclRef.getName(), + aggTypeDecl->wrappedType.type, + aggTypeDecl->ownedScope, + LookupMask::Default, + LookupOptions::IgnoreBaseInterfaces); + if (!lookupResult.isValid() || lookupResult.isOverloaded()) + return false; + auto assocType = DeclRefType::create(m_astBuilder, lookupResult.item.declRef); + witnessTable->add(requiredMemberDeclRef.getDecl(), assocType); + for (auto typeConstraintDecl : getMembersOfType(m_astBuilder, requiredMemberDeclRef)) + { + auto witness = tryGetSubtypeWitness(assocType, getSup(m_astBuilder, typeConstraintDecl)); + if (!witness) + return false; + witnessTable->add(typeConstraintDecl.getDecl(), witness); + } + return true; + } + + bool SemanticsVisitor::trySynthesizeAssociatedConstantRequirementWitness( + ConformanceCheckingContext* context, + LookupResult const& inLookupResult, + DeclRef requiredMemberDeclRef, + RefPtr witnessTable) + { + SLANG_UNUSED(inLookupResult); + + // The only case we can synthesize for now is when the conformant type + // is a wrapper type, i.e. + // struct Foo:IFoo = FooImpl; + if (!isWrapperTypeDecl(context->parentDecl)) + return false; // Find the witness that FooImpl : IFoo. auto aggTypeDecl = as(context->parentDecl); @@ -5049,15 +5045,24 @@ namespace Slang if (!innerWitness) return false; + auto witness = tryLookUpRequirementWitness(m_astBuilder, innerWitness, requiredMemberDeclRef.getDecl()); + if (witness.getFlavor() != RequirementWitness::Flavor::val) + return false; + witnessTable->add(requiredMemberDeclRef.getDecl(), witness.getVal()); + return true; + } + + bool SemanticsVisitor::synthesizeAccessorRequirements( + ConformanceCheckingContext* context, + DeclRef requiredMemberDeclRef, + Type* resultType, + Expr* synBoundStorageExpr, + ContainerDecl* synAccesorContainer, + RefPtr witnessTable) + { + Dictionary, AccessorDecl*> mapRequiredAccessorToSynAccessor; for (auto requiredAccessorDeclRef : getMembersOfType(m_astBuilder, requiredMemberDeclRef)) { - auto innerEntry = tryLookUpRequirementWitness(m_astBuilder, innerWitness, requiredAccessorDeclRef.getDecl()); - if (innerEntry.getFlavor() != RequirementWitness::Flavor::declRef) - return false; - auto innerAccessorDeclRef = as(innerEntry.getDeclRef()); - if (!innerAccessorDeclRef) - return false; - // The synthesized accessor will be an AST node of the same class as // the required accessor. // @@ -5066,68 +5071,160 @@ namespace Slang synAccessorDecl->ownedScope->containerDecl = synAccessorDecl; synAccessorDecl->ownedScope->parent = getScope(context->parentDecl); - // The return type should be the same as the inner object's accessor return type. + // Whatever the required accessor returns, that is what our synthesized accessor will return. // - synAccessorDecl->returnType.type = getResultType(m_astBuilder, innerAccessorDeclRef); + synAccessorDecl->returnType.type = resultType; - // Similarly, our synthesized accessor will have parameters matching those of the inner accessor. + // Similarly, our synthesized accessor will have parameters matching those of the requirement. + // + // Note: in practice we expect that only `set` accessors will have any parameters, + // and they will only have a single parameter. // List synArgs; - for (auto innerParamDeclRef : getParameters(m_astBuilder, innerAccessorDeclRef)) + for (auto requiredParamDeclRef : getParameters(m_astBuilder, requiredAccessorDeclRef)) { - auto paramType = getType(m_astBuilder, innerParamDeclRef); + auto paramType = getType(m_astBuilder, requiredParamDeclRef); // The synthesized parameter will ahve the same name and // type as the parameter of the requirement. // auto synParamDecl = m_astBuilder->create(); - synParamDecl->nameAndLoc = innerParamDeclRef.getDecl()->nameAndLoc; + synParamDecl->nameAndLoc = requiredParamDeclRef.getDecl()->nameAndLoc; synParamDecl->type.type = paramType; - // We need to add the parameter as a child declaration of - // the accessor we are building. - // - synParamDecl->parentDecl = synAccessorDecl; - synAccessorDecl->members.add(synParamDecl); + // We need to add the parameter as a child declaration of + // the accessor we are building. + // + synParamDecl->parentDecl = synAccessorDecl; + synAccessorDecl->members.add(synParamDecl); + + // For each paramter, we will create an argument expression + // to represent it in the body of the accessor. + // + auto synArg = m_astBuilder->create(); + synArg->declRef = makeDeclRef(synParamDecl); + synArg->type = paramType; + synArgs.add(synArg); + } + + // We need to create a `this` expression to be used in the body + // of the synthesized accessor. + // + // TODO: if we ever allow `static` properties or subscripts, + // we will need to handle that case here, by *not* creating + // a `this` expression. + // + ThisExpr* synThis = m_astBuilder->create(); + synThis->scope = synAccessorDecl->ownedScope; + + // The type of `this` in our accessor will be the type for + // which we are synthesizing a conformance. + // + synThis->type.type = context->conformingType; + + // A `get` accessor should default to an immutable `this`, + // while other accessors default to mutable `this`. + // + // TODO: If we ever add other kinds of accessors, we will + // need to check that this assumption stays valid. + // + synThis->type.isLeftValue = true; + if (as(requiredAccessorDeclRef)) + synThis->type.isLeftValue = false; + + // If the accessor requirement is `[nonmutating]` then our + // synthesized accessor should be too, and also the `this` + // parameter should *not* be an l-value. + // + if (requiredAccessorDeclRef.getDecl()->hasModifier()) + { + synThis->type.isLeftValue = false; + + auto synAttr = m_astBuilder->create(); + synAccessorDecl->modifiers.first = synAttr; + } + // + // Note: we don't currently support `[mutating] get` accessors, + // but the desired behavior in that case is clear, so we go + // ahead and future-proof this code a bit: + // + else if (requiredAccessorDeclRef.getDecl()->hasModifier()) + { + synThis->type.isLeftValue = true; + + auto synAttr = m_astBuilder->create(); + synAccessorDecl->modifiers.first = synAttr; + } + else if (requiredAccessorDeclRef.getDecl()->hasModifier()) + { + synThis->type.isLeftValue = true; - // For each paramter, we will create an argument expression - // to represent it in the body of the accessor. - // - auto synArg = m_astBuilder->create(); - synArg->declRef = makeDeclRef(synParamDecl); - synArg->type = paramType; - synArgs.add(synArg); + auto synAttr = m_astBuilder->create(); + synAccessorDecl->modifiers.first = synAttr; + } + else if (requiredAccessorDeclRef.getDecl()->hasModifier()) + { + auto synAttr = m_astBuilder->create(); + synAccessorDecl->modifiers.first = synAttr; } + // We are going to synthesize an expression and then perform + // semantic checking on it, but if there are semantic errors + // we do *not* want to report them to the user as such, and + // instead want the result to be a failure to synthesize + // a valid witness. + // + // We will buffer up diagnostics into a temporary sink and + // then throw them away when we are done. + // + // TODO: This behavior might be something we want to make + // into a more fundamental capability of `DiagnosticSink` and/or + // `SemanticsVisitor` so that code can push/pop the emission + // of diagnostics more easily. + // + DiagnosticSink tempSink(getSourceManager(), nullptr); + SemanticsVisitor subVisitor(withSink(&tempSink)); - // Now synthesize the body of the property accessor. // The body of the accessor will depend on the class of the accessor // we are synthesizing (e.g., `get` vs. `set`). // Stmt* synBodyStmt = nullptr; - auto propertyRef = m_astBuilder->create(); - propertyRef->scope = synAccessorDecl->ownedScope; - auto base = m_astBuilder->create(); - base->scope = propertyRef->scope; - base->name = getName("inner"); - propertyRef->baseExpression = base; - innerProperty = innerAccessorDeclRef.getParent(); - propertyRef->name = getParentDecl(innerAccessorDeclRef.getDecl())->getName(); - auto checkedPropertyRefExpr = CheckExpr(propertyRef); - if (as(requiredAccessorDeclRef)) { + // A `get` accessor will simply perform: + // + // return this.name; + // + // which involves coercing the member access `this.name` to + // the expected type of the property. + // + auto coercedMemberRef = subVisitor.coerce(CoercionSite::Return, resultType, synBoundStorageExpr); auto synReturn = m_astBuilder->create(); - synReturn->expression = checkedPropertyRefExpr; + synReturn->expression = coercedMemberRef; synBodyStmt = synReturn; } else if (as(requiredAccessorDeclRef)) { + // We expect all `set` accessors to have a single argument, + // but we will defensively bail out if that is somehow + // not the case. + // + SLANG_ASSERT(synArgs.getCount() == 1); + if (synArgs.getCount() != 1) + return false; + + // A `set` accessor will simply perform: + // + // this.name = newValue; + // + // which involves creating and checking an assignment + // expression. + auto synAssign = m_astBuilder->create(); - synAssign->left = checkedPropertyRefExpr; + synAssign->left = synBoundStorageExpr; synAssign->right = synArgs[0]; - auto synCheckedAssign = checkAssignWithCheckedOperands(synAssign); + auto synCheckedAssign = subVisitor.checkAssignWithCheckedOperands(synAssign); auto synExprStmt = m_astBuilder->create(); synExprStmt->expression = synCheckedAssign; @@ -5143,96 +5240,221 @@ namespace Slang return false; } - addModifier(synAccessorDecl, m_astBuilder->create()); + // We bail out if we ran into any errors (meaning that the synthesized + // accessor is not usable). + // + // TODO: If there were *warnings* emitted to the sink, it would probably + // be good to show those warnings to the user, since they might indicate + // real issues. E.g., with the current logic a `float` field could + // satisfying an `int` property requirement, but the user would probably + // want to be warned when they do such a thing. + // + if (tempSink.getErrorCount() != 0) + return false; + synAccessorDecl->body = synBodyStmt; - synAccessorDecl->parentDecl = synPropertyDecl; - synPropertyDecl->members.add(synAccessorDecl); + synAccessorDecl->parentDecl = synAccesorContainer; + synAccesorContainer->members.add(synAccessorDecl); - // Register the synthesized accessor. + // If synthesis of an accessor worked, then we will record it into + // a local dictionary. We do *not* install the accessor into the + // witness table yet, because it is possible that synthesis will + // succeed for some accessors but not others, and we don't want + // to leave the witness table in a state where a requirement is + // "partially satisfied." // - witnessTable->add(requiredAccessorDeclRef.getDecl(), RequirementWitness(makeDeclRef(synAccessorDecl))); + mapRequiredAccessorToSynAccessor.add(requiredAccessorDeclRef, synAccessorDecl); } - // The type of our synthesized property will be the same as the inner property. + // Once our synthesized declaration is complete, we need + // to install it as the witness that satifies the given + // requirement. // - auto propertyType = getType(m_astBuilder, as(innerProperty)); - synPropertyDecl->type.type = propertyType; - - // The visibility of synthesized decl should be the same as the inner requirement - if (innerProperty.getDecl()->findModifier()) + // Subsequent code generation should not be able to tell the + // difference between our synthetic property and a hand-written + // one with the same behavior. + // + for (auto& [key, value] : mapRequiredAccessorToSynAccessor) { - auto vis = getDeclVisibility(innerProperty.getDecl()); - addVisibilityModifier(m_astBuilder, synPropertyDecl, vis); + witnessTable->add(key.getDecl(), RequirementWitness(getDefaultDeclRef(value))); } - - context->parentDecl->addMember(synPropertyDecl); witnessTable->add(requiredMemberDeclRef.getDecl(), - RequirementWitness(makeDeclRef(synPropertyDecl))); + RequirementWitness(getDefaultDeclRef(synAccesorContainer))); return true; } - bool SemanticsVisitor::trySynthesizeAssociatedTypeRequirementWitness( + bool SemanticsVisitor::trySynthesizeWrapperTypeSubscriptRequirementWitness( ConformanceCheckingContext* context, - LookupResult const& inLookupResult, - DeclRef requiredMemberDeclRef, + DeclRef requiredMemberDeclRef, RefPtr witnessTable) { - SLANG_UNUSED(inLookupResult); + // We are synthesizing the subscript requirement for a wrapper type: + // struct Wrapper + // { + // Inner inner; + // subscript(int index)->int { get { return inner[index]; } + // set { inner[index] = newValue; } + // } + // } + // + // // Find the witness that FooImpl : IFoo. + auto aggTypeDecl = as(context->parentDecl); + auto innerType = aggTypeDecl->wrappedType.type; + DeclRef innerProperty; + auto innerWitness = tryGetSubtypeWitness(innerType, witnessTable->baseType); + if (!innerWitness) + return false; + // + List synArgs; + ThisExpr* synThis; + auto synSubscriptDecl = synthesizeMethodSignatureForRequirementWitness(context, requiredMemberDeclRef, + synArgs, synThis); + auto declType = getType(m_astBuilder, getDefaultDeclRef(synSubscriptDecl).as()); + synThis->checked = true; - // The only case we can synthesize for now is when the conformant type - // is a wrapper type. - if (!isWrapperTypeDecl(context->parentDecl)) + // Form a `this[args...]` expression that we will use to coerce from + // in the synthesized subscript accessors. + // + synSubscriptDecl->parentDecl = context->parentDecl; + DiagnosticSink tempSink(getSourceManager(), nullptr); + SemanticsVisitor subVisitor(withSink(&tempSink)); + auto base = m_astBuilder->create(); + base->scope = synThis->scope; + base->name = getName("inner"); + + IndexExpr* indexExpr = m_astBuilder->create(); + indexExpr->baseExpression = base; + indexExpr->indexExprs = _Move(synArgs); + auto synBaseStorageExpr = subVisitor.CheckTerm(indexExpr); + + if (tempSink.getErrorCount() != 0) return false; - auto aggTypeDecl = as(context->parentDecl); - auto lookupResult = lookUpMember( - m_astBuilder, - this, - requiredMemberDeclRef.getName(), - aggTypeDecl->wrappedType.type, - aggTypeDecl->ownedScope, - LookupMask::Default, - LookupOptions::IgnoreBaseInterfaces); - if (!lookupResult.isValid() || lookupResult.isOverloaded()) + + // Our synthesized subscript will have an accessor declaration for + // each accessor of the requirement. + // + bool canSynAccessors = synthesizeAccessorRequirements( + context, + requiredMemberDeclRef, + declType, + synBaseStorageExpr, + synSubscriptDecl, witnessTable); + if (!canSynAccessors) return false; - auto assocType = DeclRefType::create(m_astBuilder, lookupResult.item.declRef); - witnessTable->add(requiredMemberDeclRef.getDecl(), assocType); - for (auto typeConstraintDecl : getMembersOfType(m_astBuilder, requiredMemberDeclRef)) + + synSubscriptDecl->parentDecl = context->parentDecl; + + // The visibility of synthesized decl should be the min of the parent decl and the requirement. + if (requiredMemberDeclRef.getDecl()->findModifier()) { - auto witness = tryGetSubtypeWitness(assocType, getSup(m_astBuilder, typeConstraintDecl)); - if (!witness) - return false; - witnessTable->add(typeConstraintDecl.getDecl(), witness); + auto requirementVisibility = getDeclVisibility(requiredMemberDeclRef.getDecl()); + auto thisVisibility = getDeclVisibility(context->parentDecl); + auto visibility = Math::Min(thisVisibility, requirementVisibility); + addVisibilityModifier(m_astBuilder, synSubscriptDecl, visibility); } + return true; } - bool SemanticsVisitor::trySynthesizeAssociatedConstantRequirementWitness( + bool SemanticsVisitor::trySynthesizeSubscriptRequirementWitness( ConformanceCheckingContext* context, - LookupResult const& inLookupResult, - DeclRef requiredMemberDeclRef, + DeclRef requiredMemberDeclRef, RefPtr witnessTable) { - SLANG_UNUSED(inLookupResult); + if (isWrapperTypeDecl(context->parentDecl)) + return trySynthesizeWrapperTypeSubscriptRequirementWitness(context, requiredMemberDeclRef, witnessTable); - // The only case we can synthesize for now is when the conformant type - // is a wrapper type, i.e. - // struct Foo:IFoo = FooImpl; - if (!isWrapperTypeDecl(context->parentDecl)) - return false; + // The situation here is that the context of an inheritance + // declaration didn't provide an exact match for a required + // subscript. E.g.: + // + // interface ICell { subscript(int index)->int {get;} } + // struct MyCell : ICell + // { + // subscript(uint index)->int {ref;} + // } + // + // It is clear in this case that the `MyCell` type *can* + // satisfy the signature required by `ICell`, if we consider + // all the allowed type coercion rules, and use `ref` accessor + // to implement `get`. + // + // The approach in this function will be to construct a + // synthesized `subscript` along the lines of: + // + // struct MyCell ... + // { + // ... + // subscript(int index)->int {get;} + // { + // get { return this.origianl_subscript[index]; } + // } + // } + // + // That is, we construct a `subscript` with the correct type + // and with an accessor for each requirement, where the accesors + // all try to dispatch to the original subscript decl. + // + // If those synthesized accessors all type-check, then we can + // say that the type must satisfy the requirement structurally, + // even if there isn't an exact signature match. More + // importantly, the `property` we just synthesized can be + // used as a witness to the fact that the requirement is + // satisfied. + // + // The big-picture flow of the logic here is similar to + // `trySynthesizePropertyRequirementWitness()` above, and we + // will not comment this code as exhaustively, under the + // assumption that readers of the code don't benefit from + // having the exact same information stated twice. + // - // Find the witness that FooImpl : IFoo. - auto aggTypeDecl = as(context->parentDecl); - auto innerType = aggTypeDecl->wrappedType.type; - DeclRef innerProperty; - auto innerWitness = tryGetSubtypeWitness(innerType, witnessTable->baseType); - if (!innerWitness) + List synArgs; + ThisExpr* synThis; + auto synSubscriptDecl = synthesizeMethodSignatureForRequirementWitness(context, requiredMemberDeclRef, + synArgs, synThis); + synThis->type.isLeftValue = true; + synThis->checked = true; + synSubscriptDecl->parentDecl = context->parentDecl; + + auto declType = getType(m_astBuilder, getDefaultDeclRef(synSubscriptDecl).as()); + + // Form a `this[args...]` expression that we will use to coerce from + // in the synthesized subscript accessors. + // + DiagnosticSink tempSink(getSourceManager(), nullptr); + SemanticsVisitor subVisitor(withSink(&tempSink)); + IndexExpr* indexExpr = m_astBuilder->create(); + indexExpr->baseExpression = synThis; + indexExpr->indexExprs = _Move(synArgs); + auto synBaseStorageExpr = subVisitor.CheckTerm(indexExpr); + + if (tempSink.getErrorCount() != 0) return false; - auto witness = tryLookUpRequirementWitness(m_astBuilder, innerWitness, requiredMemberDeclRef.getDecl()); - if (witness.getFlavor() != RequirementWitness::Flavor::val) + // Our synthesized subscript will have an accessor declaration for + // each accessor of the requirement. + // + bool canSynAccessors = synthesizeAccessorRequirements( + context, + requiredMemberDeclRef, + declType, + synBaseStorageExpr, + synSubscriptDecl, witnessTable); + if (!canSynAccessors) return false; - witnessTable->add(requiredMemberDeclRef.getDecl(), witness.getVal()); + + + // The visibility of synthesized decl should be the min of the parent decl and the requirement. + if (requiredMemberDeclRef.getDecl()->findModifier()) + { + auto requirementVisibility = getDeclVisibility(requiredMemberDeclRef.getDecl()); + auto thisVisibility = getDeclVisibility(context->parentDecl); + auto visibility = Math::Min(thisVisibility, requirementVisibility); + addVisibilityModifier(m_astBuilder, synSubscriptDecl, visibility); + } + return true; } @@ -5314,6 +5536,14 @@ namespace Slang witnessTable); } + if (auto requiredSubscriptDeclRef = requiredMemberDeclRef.as()) + { + return trySynthesizeSubscriptRequirementWitness( + context, + requiredSubscriptDeclRef, + witnessTable); + } + if (auto requiredAssocTypeDeclRef = requiredMemberDeclRef.as()) { if (auto builtinAttr = requiredAssocTypeDeclRef.getDecl()->findModifier()) @@ -5852,9 +6082,28 @@ namespace Slang !(requiredMemberDeclRef.as() && getInner(requiredMemberDeclRef.as())->hasModifier())) { - getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, subType, requiredMemberDeclRef); - getSink()->diagnose(requiredMemberDeclRef, Diagnostics::seeDeclarationOf, requiredMemberDeclRef); - return false; + // If we failed to look up a member with the name of the + // requirement, it may be possible that we can still synthesis the + // implementation if this is one of the known builtin requirements. + // Otherwise, report diagnostic now. + + if (requiredMemberDeclRef.getDecl()->hasModifier() || + (requiredMemberDeclRef.as() && + getInner(requiredMemberDeclRef.as())->hasModifier())) + { + } + else if (requiredMemberDeclRef.as() && + (as(context->conformingType) || + as(context->conformingType) || + as(context->conformingType))) + { + } + else + { + getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, subType, requiredMemberDeclRef); + getSink()->diagnose(requiredMemberDeclRef, Diagnostics::seeDeclarationOf, requiredMemberDeclRef); + return false; + } } } @@ -10615,7 +10864,8 @@ namespace Slang bool isDiffParam = (!param->findModifier()); if (isDiffParam) { - if (auto pairType = as(visitor->getDifferentialPairType(param->getType()))) + auto diffPair = visitor->getDifferentialPairType(param->getType()); + if (auto pairType = as(diffPair)) { arg->type.type = pairType; arg->type.isLeftValue = true; @@ -10636,6 +10886,11 @@ namespace Slang direction = ParameterDirection::kParameterDirection_InOut; } } + else if (auto refPairType = as(diffPair)) + { + // no need to change direction of ref-pairs. + arg->type.type = refPairType; + } else { isDiffParam = false; diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 5c3637b762..28ca8c98cc 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1140,7 +1140,8 @@ namespace Slang { if (auto builtinRequirement = declRefType->getDeclRef().getDecl()->findModifier()) { - if (builtinRequirement->kind == BuiltinRequirementKind::DifferentialType) + if (builtinRequirement->kind == BuiltinRequirementKind::DifferentialType + || builtinRequirement->kind == BuiltinRequirementKind::DifferentialPtrType) { // We are trying to get differential type from a differential type. // The result is itself. @@ -1148,7 +1149,10 @@ namespace Slang } } type = resolveType(type); - if (const auto witness = as(tryGetInterfaceConformanceWitness(type, builder->getDifferentiableInterfaceType()))) + auto witness = as(tryGetInterfaceConformanceWitness(type, builder->getDifferentiableInterfaceType())); + if (!witness) + witness = as(tryGetInterfaceConformanceWitness(type, builder->getDifferentiableRefInterfaceType())); + if (witness) { auto diffTypeLookupResult = lookUpMember( getASTBuilder(), @@ -1376,6 +1380,13 @@ namespace Slang { addDifferentiableTypeToDiffTypeRegistry((DeclRefType*)type, subtypeWitness); } + + if (auto subtypeWitness = as( + tryGetInterfaceConformanceWitness(type, getASTBuilder()->getDifferentiableRefInterfaceType()))) + { + addDifferentiableTypeToDiffTypeRegistry((DeclRefType*)type, subtypeWitness); + } + if (auto aggTypeDeclRef = declRefType->getDeclRef().as()) { foreachDirectOrExtensionMemberOfType(this, aggTypeDeclRef, [&](DeclRef member) @@ -1410,7 +1421,15 @@ namespace Slang Expr* SemanticsVisitor::CheckTerm(Expr* term) { + // If we have already checked the expr, don't check again. + if (term->checked) + { + return term; + } + auto checkedTerm = _CheckTerm(term); + checkedTerm->checked = true; + // Differentiable type checking. // TODO: This can be super slow. if (this->m_parentFunc && @@ -2904,18 +2923,25 @@ namespace Slang return m_astBuilder->getExpandType(diffPairEachType, makeArrayViewSingle(primalType)); } } + // Get a reference to the builtin 'IDifferentiable' interface auto differentiableInterface = getASTBuilder()->getDifferentiableInterfaceType(); + auto differentiableRefInterface = getASTBuilder()->getDifferentiableRefInterfaceType(); - auto conformanceWitness = as(isSubtype(primalType, differentiableInterface, IsSubTypeOptions::None)); // Check if the provided type inherits from IDifferentiable. // If not, return the original type. - if (conformanceWitness) + if (auto conformanceWitness = isTypeDifferentiable(primalType)) { - return m_astBuilder->getDifferentialPairType(primalType, conformanceWitness); + if (conformanceWitness->getSup() == differentiableInterface) + { + return m_astBuilder->getDifferentialPairType(primalType, conformanceWitness); + } + else if (conformanceWitness->getSup() == differentiableRefInterface) + { + return m_astBuilder->getDifferentialPtrPairType(primalType, conformanceWitness); + } } - else - return primalType; + return primalType; } Type* SemanticsVisitor::getForwardDiffFuncType(FuncType* originalType) diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index e56082aab9..334737c97e 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1748,7 +1748,7 @@ namespace Slang void addModifiersToSynthesizedDecl( ConformanceCheckingContext* context, DeclRef requirement, - FunctionDeclBase* synthesized, + CallableDecl* synthesized, ThisExpr* &synThis); void addRequiredParamsToSynthesizedDecl( @@ -1756,9 +1756,9 @@ namespace Slang CallableDecl* synthesized, List& synArgs); - FunctionDeclBase* synthesizeMethodSignatureForRequirementWitness( + CallableDecl* synthesizeMethodSignatureForRequirementWitness( ConformanceCheckingContext* context, - DeclRef requiredMemberDeclRef, + DeclRef requiredMemberDeclRef, List& synArgs, ThisExpr*& synThis); @@ -1769,6 +1769,14 @@ namespace Slang List& synGenericArgs, ThisExpr*& synThis); + bool synthesizeAccessorRequirements( + ConformanceCheckingContext* context, + DeclRef requiredMemberDeclRef, + Type* resultType, + Expr* synBoundStorageExpr, + ContainerDecl* synAccesorContainer, + RefPtr witnessTable); + void _addMethodWitness( WitnessTable* witnessTable, DeclRef requirement, @@ -1806,6 +1814,16 @@ namespace Slang DeclRef requiredMemberDeclRef, RefPtr witnessTable); + bool trySynthesizeSubscriptRequirementWitness( + ConformanceCheckingContext* context, + DeclRef requiredMemberDeclRef, + RefPtr witnessTable); + + bool trySynthesizeWrapperTypeSubscriptRequirementWitness( + ConformanceCheckingContext* context, + DeclRef requiredMemberDeclRef, + RefPtr witnessTable); + bool trySynthesizeAssociatedTypeRequirementWitness( ConformanceCheckingContext* context, LookupResult const& lookupResult, @@ -2203,7 +2221,7 @@ namespace Slang bool isValidGenericConstraintType(Type* type); - bool isTypeDifferentiable(Type* type); + SubtypeWitness* isTypeDifferentiable(Type* type); bool doesTypeHaveTag(Type* type, TypeTag tag); diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 9b0b56ee24..8d89993ba2 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -1200,28 +1200,6 @@ namespace Slang return parent; } - void countDistanceToGloablScope(DeclRef const& leftDecl, - DeclRef const& rightDecl, - int& leftDistance, int& rightDistance) - { - leftDistance = 0; - rightDistance = 0; - - DeclRef decl = leftDecl; - while(decl) - { - leftDistance++; - decl = decl.getParent(); - } - - decl = rightDecl; - while(decl) - { - rightDistance++; - decl = decl.getParent(); - } - } - // Returns -1 if left is preferred, 1 if right is preferred, and 0 if they are equal. // int SemanticsVisitor::CompareLookupResultItems( @@ -1347,23 +1325,6 @@ namespace Slang } } - // We need to consider the distance of the declarations to the global scope to resolve this case: - // float f(float x); - // struct S - // { - // float f(float x); - // float g(float y) { return f(y); } // will call S::f() instead of ::f() - // } - // We don't need to know the call site of 'f(y)', but only need to count the two candidates' distance to the global scope, - // because this function will only choose the valid candidates. So if there is situation like this: - // void main() { S s; s.f(1.0);} or - // struct T { float g(y) { f(y); } }, there won't be ambiguity. - // So we just need to count which declaration is farther from the global scope and favor the farther one. - int leftDistance = 0; - int rightDistance = 0; - countDistanceToGloablScope(left.declRef, right.declRef, leftDistance, rightDistance); - if (leftDistance != rightDistance) - return leftDistance > rightDistance ? -1 : 1; // TODO: We should generalize above rules such that in a tie a declaration // A::m is better than B::m when all other factors are equal and @@ -1479,6 +1440,70 @@ namespace Slang return 0; } + int getScopeRank(DeclRef const& left, + DeclRef const& right, Slang::Scope* referenceSiteScope) + { + if (!referenceSiteScope) + return 0; + + DeclRef prefixDecl = referenceSiteScope->containerDecl; + + // Hold the path from reference site to the root + // key: Decl node, value: distance from reference site + Dictionary refPath; + for (auto node = prefixDecl; node != nullptr; node = node.getParent()) + { + Decl* key = node.getDecl(); + uint32_t value = (uint32_t)refPath.getCount(); + refPath.add(key, value); + } + + // find the common prefix decl of reference site and left + int leftDistance = 0; + int rightDistance = 0; + auto distanceToCommonPrefix = [](DeclRef const& candidate, Dictionary refPath) -> int + { + uint32_t distanceToReferenceSite = 0; + uint32_t distanceToCandidate = 0; + + // Sanity check + if (candidate.getDecl() == nullptr) + return -1; + + // search from candidate to root, once we found the first node in the reference path, that is the first + // common prefix, and we can stop searching. + for (auto node = candidate; node != nullptr; node = node.getParent()) + { + Decl* key = node.getDecl(); + if (refPath.tryGetValue(key, distanceToReferenceSite)) + { + break; + } + distanceToCandidate++; + } + + // If we don't find the common prefix, there must be something wrong, return the max value. + if (distanceToReferenceSite == 0) + return -1; + + return distanceToReferenceSite + distanceToCandidate; + }; + + leftDistance = distanceToCommonPrefix(left, refPath); + rightDistance = distanceToCommonPrefix(right, refPath); + + if (leftDistance == rightDistance) + return 0; + + if (leftDistance == -1) + return 1; + + if (rightDistance == -1) + return -1; + + return leftDistance < rightDistance ? -1 : 1; + } + int SemanticsVisitor::CompareOverloadCandidates( OverloadCandidate* left, OverloadCandidate* right) @@ -1553,6 +1578,7 @@ namespace Slang auto itemDiff = CompareLookupResultItems(left->item, right->item); if(itemDiff) return itemDiff; + auto specificityDiff = compareOverloadCandidateSpecificity(left->item, right->item); if(specificityDiff) return specificityDiff; @@ -1562,6 +1588,42 @@ namespace Slang if (externExportDiff) return externExportDiff; + // We need to consider the distance of the declarations to the global scope to resolve this case: + // float f(float x); + // struct S + // { + // float f(float x); + // float g(float y) { return f(y); } // will call S::f() instead of ::f() + // } + // we will count the distance from the reference site to the declaration in the scope tree. + + // NOTE: We CAN'T do this for the generic function, because generic lookup is little bit complicated. + // It will go through multiple passes of candidates compare. + // In the first pass, it will lookup all the generic candidates that matches the generic parameter only, + // e.g., the following generic functions are totally different, but they will be selected as candidates + // because the function name and the generic parameters are the same: + // void func(Z0 a, Z1 b); + // void func(Z0 a, Z1 b, Z0 c); + // void func(Z0 a, Z1 b, Z0 c, Z1 d); + // + // So in this case, we should not consider the scope rank and overload rank at all, because there is only + // one of above candidates is valid, and the rank calculation doesn't consider the correctness of the + // candidates, so it could select the wrong candidate. + // + // In the next pass, the lookup system will match the input parameters in those candidates to find out the valid + // match, the "flavor" field will become "Func" or "Expr". So the rank calculation can be applied. + if (left->flavor == OverloadCandidate::Flavor::Generic || + left->flavor == OverloadCandidate::Flavor::UnspecializedGeneric || + right->flavor == OverloadCandidate::Flavor::Generic || + right->flavor == OverloadCandidate::Flavor::UnspecializedGeneric) + { + return 0; + } + + auto scopeRank = getScopeRank(left->item.declRef, right->item.declRef, this->m_outerScope); + if (scopeRank) + return scopeRank; + // If we reach here, we will attempt to use overload rank to break the ties. auto overloadRankDiff = getOverloadRank(right->item.declRef) - getOverloadRank(left->item.declRef); if (overloadRankDiff) diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp index 541085b4ee..c89d94c807 100644 --- a/source/slang/slang-compiler.cpp +++ b/source/slang/slang-compiler.cpp @@ -2451,12 +2451,16 @@ namespace Slang return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::DumpIr); } + bool CodeGenContext::shouldReportCheckpointIntermediates() + { + return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::ReportCheckpointIntermediates); + } + bool CodeGenContext::shouldDumpIntermediates() { return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::DumpIntermediates); } - bool CodeGenContext::shouldTrackLiveness() { return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::TrackLiveness); diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index 0c788ae182..4b20d1f763 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -2728,6 +2728,7 @@ namespace Slang bool shouldValidateIR(); bool shouldDumpIR(); + bool shouldReportCheckpointIntermediates(); bool shouldTrackLiveness(); diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 0c9f7acf8f..179c913ac5 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -895,6 +895,12 @@ DIAGNOSTIC(58002, Error, unhandledGLSLSSBOType, "Unhandled GLSL Shader Storage B DIAGNOSTIC(58003, Error, inconsistentPointerAddressSpace, "'$0': use of pointer with inconsistent address space.") +// Autodiff checkpoint reporting +DIAGNOSTIC(-1, Note, reportCheckpointIntermediates, "checkpointing context of $1 bytes associated with function: '$0'") +DIAGNOSTIC(-1, Note, reportCheckpointVariable, "$0 bytes ($1) used to checkpoint the following item:") +DIAGNOSTIC(-1, Note, reportCheckpointCounter, "$0 bytes ($1) used for a loop counter here:") +DIAGNOSTIC(-1, Note, reportCheckpointNone, "no checkpoint contexts to report") + // // 8xxxx - Issues specific to a particular library/technology/platform/etc. // diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index cdd2ca5b66..6e3556064e 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -34,6 +34,7 @@ #include "slang-ir-wgsl-legalize.h" #include "slang-ir-insts.h" #include "slang-ir-inline.h" +#include "slang-ir-layout.h" #include "slang-ir-legalize-array-return-type.h" #include "slang-ir-legalize-mesh-outputs.h" #include "slang-ir-legalize-varying-params.h" @@ -214,6 +215,68 @@ static void dumpIRIfEnabled( } } +static void reportCheckpointIntermediates(CodeGenContext* codeGenContext, DiagnosticSink* sink, IRModule* irModule) +{ + // Report checkpointing information + CompilerOptionSet& optionSet = codeGenContext->getTargetProgram()->getOptionSet(); + SourceManager* sourceManager = sink->getSourceManager(); + + SourceWriter typeWriter(sourceManager, LineDirectiveMode::None, nullptr); + + CLikeSourceEmitter::Desc description; + description.codeGenContext = codeGenContext; + description.sourceWriter = &typeWriter; + + CPPSourceEmitter emitter(description); + + int nonEmptyStructs = 0; + for (auto inst : irModule->getGlobalInsts()) + { + IRStructType *structType = as(inst); + if (!structType) + continue; + + auto checkpointDecoration = structType->findDecoration(); + if (!checkpointDecoration) + continue; + + IRSizeAndAlignment structSize; + getNaturalSizeAndAlignment(optionSet, structType, &structSize); + + // Reporting happens before empty structs are optimized out + // and we still want to keep the checkpointing decorations, + // so we end up needing to check for non-zero-ness + if (structSize.size == 0) + continue; + + auto func = checkpointDecoration->getSourceFunction(); + sink->diagnose(structType, Diagnostics::reportCheckpointIntermediates, func, structSize.size); + nonEmptyStructs++; + + for (auto field : structType->getFields()) + { + IRType *fieldType = field->getFieldType(); + IRSizeAndAlignment fieldSize; + getNaturalSizeAndAlignment(optionSet, fieldType, &fieldSize); + if (fieldSize.size == 0) + continue; + + typeWriter.clearContent(); + emitter.emitType(fieldType); + + sink->diagnose(field->sourceLoc, + field->findDecoration() + ? Diagnostics::reportCheckpointCounter + : Diagnostics::reportCheckpointVariable, + fieldSize.size, + typeWriter.getContent()); + } + } + + if (nonEmptyStructs == 0) + sink->diagnose(SourceLoc(), Diagnostics::reportCheckpointNone); +} + struct LinkingAndOptimizationOptions { bool shouldLegalizeExistentialAndResourceTypes = true; @@ -767,6 +830,10 @@ Result linkAndOptimizeIR( break; } + // Report checkpointing information + if (codeGenContext->shouldReportCheckpointIntermediates()) + reportCheckpointIntermediates(codeGenContext, sink, irModule); + if (requiredLoweringPassSet.autodiff) finalizeAutoDiffPass(targetProgram, irModule); diff --git a/source/slang/slang-ir-addr-inst-elimination.cpp b/source/slang/slang-ir-addr-inst-elimination.cpp index 8a48936d7e..b55f6b93d8 100644 --- a/source/slang/slang-ir-addr-inst-elimination.cpp +++ b/source/slang/slang-ir-addr-inst-elimination.cpp @@ -69,30 +69,28 @@ struct AddressInstEliminationContext } } - void transformLoadAddr(IRUse* use) + void transformLoadAddr(IRBuilder& builder, IRUse* use) { auto addr = use->get(); auto load = as(use->getUser()); - IRBuilder builder(module); builder.setInsertBefore(use->getUser()); auto value = getValue(builder, addr); load->replaceUsesWith(value); load->removeAndDeallocate(); } - void transformStoreAddr(IRUse* use) + void transformStoreAddr(IRBuilder& builder, IRUse* use) { auto addr = use->get(); auto store = as(use->getUser()); - IRBuilder builder(module); builder.setInsertBefore(use->getUser()); storeValue(builder, addr, store->getVal()); store->removeAndDeallocate(); } - void transformCallAddr(IRUse* use) + void transformCallAddr(IRBuilder& builder, IRUse* use) { auto addr = use->get(); auto call = as(use->getUser()); @@ -103,7 +101,6 @@ struct AddressInstEliminationContext return; } - IRBuilder builder(module); builder.setInsertBefore(call); auto tempVar = builder.emitVar(cast(addr->getFullType())->getValueType()); @@ -155,17 +152,20 @@ struct AddressInstEliminationContext use = nextUse; continue; } + + IRBuilder transformBuilder(module); + IRBuilderSourceLocRAII sourceLocationScope(&transformBuilder, use->getUser()->sourceLoc); switch (use->getUser()->getOp()) { case kIROp_Load: - transformLoadAddr(use); + transformLoadAddr(transformBuilder, use); break; case kIROp_Store: - transformStoreAddr(use); + transformStoreAddr(transformBuilder, use); break; case kIROp_Call: - transformCallAddr(use); + transformCallAddr(transformBuilder, use); break; case kIROp_GetElementPtr: case kIROp_FieldAddress: diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index fe7c77ba06..609bcd8a33 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -336,8 +336,8 @@ InstPair ForwardDiffTranscriber::transcribeLoad(IRBuilder* builder, IRLoad* orig { auto origPtr = origLoad->getPtr(); auto primalPtr = lookupPrimalInst(builder, origPtr, nullptr); - auto primalPtrType = as(primalPtr->getFullType()); - if (primalPtrType) + + if (auto primalPtrType = as(primalPtr->getFullType())) { if (auto diffPairType = as(primalPtrType->getValueType())) { @@ -355,6 +355,18 @@ InstPair ForwardDiffTranscriber::transcribeLoad(IRBuilder* builder, IRLoad* orig (IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType(builder, diffPairType), load); return InstPair(primalElement, diffElement); } + else if (auto diffPtrPairType = as(primalPtrType->getValueType())) + { + auto load = builder->emitLoad(primalPtr); + builder->markInstAsPrimal(load); + + auto primalElement = builder->emitDifferentialPtrPairGetPrimal(load); + auto diffElement = builder->emitDifferentialPtrPairGetDifferential( + (IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType(builder, diffPtrPairType), load); + builder->markInstAsPrimal(primalElement); + builder->markInstAsPrimal(diffElement); + return InstPair(primalElement, diffElement); + } } auto primalLoad = maybeCloneForPrimalInst(builder, origLoad); @@ -387,6 +399,16 @@ InstPair ForwardDiffTranscriber::transcribeStore(IRBuilder* builder, IRStore* or auto store = builder->emitStore(primalStoreLocation, valToStore); builder->markInstAsMixedDifferential(store, diffPairType); + return InstPair(store, nullptr); + } + else if (auto diffRefPairType = as(primalLocationPtrType->getValueType())) + { + auto valToStore = builder->emitMakeDifferentialPtrPair(diffRefPairType, primalStoreVal, diffStoreVal); + builder->markInstAsPrimal(valToStore); + + auto store = builder->emitStore(primalStoreLocation, valToStore); + builder->markInstAsPrimal(store); + return InstPair(store, nullptr); } } @@ -404,7 +426,7 @@ InstPair ForwardDiffTranscriber::transcribeStore(IRBuilder* builder, IRStore* or // Default case, storing the entire type (and not a member) diffStore = as( builder->emitStore(diffStoreLocation, diffStoreVal)); - + markDiffTypeInst(builder, diffStore, primalStoreVal->getDataType()); return InstPair(primalStore, diffStore); } @@ -696,14 +718,16 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig if (auto pairType = tryGetDiffPairType(&argBuilder, primalType)) { auto pairPtrType = as(pairType); - auto pairValType = as( + + auto pairValType = as( pairPtrType ? pairPtrType->getValueType() : pairType); + auto diffType = differentiableTypeConformanceContext.getDiffTypeFromPairType(&argBuilder, pairValType); if (auto ptrParamType = as(diffParamType)) { // Create temp var to pass in/out arguments. auto srcVar = argBuilder.emitVar(pairValType); - argBuilder.markInstAsMixedDifferential(srcVar, pairValType->getValueType()); + markDiffPairTypeInst(&argBuilder, srcVar, pairValType); auto diffArg = findOrTranscribeDiffInst(&argBuilder, origArg); if (ptrParamType->getOp() == kIROp_InOutType) @@ -716,28 +740,28 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig else { diffArgVal = argBuilder.emitLoad(diffArg); - argBuilder.markInstAsDifferential(diffArgVal, pairValType->getValueType()); + markDiffTypeInst(&argBuilder, diffArgVal, pairValType->getValueType()); } auto initVal = argBuilder.emitMakeDifferentialPair(pairValType, primalVal, diffArgVal); - argBuilder.markInstAsMixedDifferential(initVal, primalType); + markDiffPairTypeInst(&argBuilder, initVal, pairValType); auto store = argBuilder.emitStore(srcVar, initVal); - argBuilder.markInstAsMixedDifferential(store, primalType); + markDiffPairTypeInst(&argBuilder, store, pairValType); } if (as(ptrParamType)) { // Read back new value. auto newVal = afterBuilder.emitLoad(srcVar); - afterBuilder.markInstAsMixedDifferential(newVal, pairValType->getValueType()); + markDiffPairTypeInst(&afterBuilder, newVal, pairValType); auto newPrimalVal = afterBuilder.emitDifferentialPairGetPrimal(pairValType->getValueType(), newVal); afterBuilder.emitStore(primalArg, newPrimalVal); if (diffArg) { auto newDiffVal = afterBuilder.emitDifferentialPairGetDifferential((IRType*)diffType, newVal); - afterBuilder.markInstAsDifferential(newDiffVal, pairValType->getValueType()); + markDiffTypeInst(&afterBuilder, newDiffVal, pairValType->getValueType()); auto storeInst = afterBuilder.emitStore(diffArg, newDiffVal); - afterBuilder.markInstAsDifferential(storeInst, pairValType->getValueType()); + markDiffTypeInst(&afterBuilder, storeInst, pairValType->getValueType()); } } args.add(srcVar); @@ -753,7 +777,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig SLANG_RELEASE_ASSERT(diffArg); auto diffPair = argBuilder.emitMakeDifferentialPair(pairType, primalArg, diffArg); - argBuilder.markInstAsMixedDifferential(diffPair, pairType); + markDiffPairTypeInst(&argBuilder, diffPair, pairType); args.add(diffPair); continue; @@ -779,12 +803,13 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig diffCallee, args); placeholderCall->removeAndDeallocate(); + argBuilder.markInstAsMixedDifferential(callInst, diffReturnType); argBuilder.addAutoDiffOriginalValueDecoration(callInst, primalCallee); *builder = afterBuilder; - if (diffReturnType->getOp() == kIROp_DifferentialPairType) + if (as(diffReturnType) || as(diffReturnType)) { IRInst* primalResultValue = afterBuilder.emitDifferentialPairGetPrimal(callInst); auto diffType = differentiateType(&afterBuilder, origCall->getFullType()); @@ -1751,12 +1776,13 @@ InstPair ForwardDiffTranscriber::transcribeFunc(IRBuilder* inBuilder, IRFunc* pr IRInst* valToStore = nullptr; if (writeBack.value.differential) { + auto pairValType = cast(param->getFullType())->getValueType(); auto diffVal = builder.emitLoad(writeBack.value.differential); - builder.markInstAsDifferential(diffVal, primalVal->getFullType()); + markDiffTypeInst(&builder, diffVal, primalVal->getFullType()); - valToStore = builder.emitMakeDifferentialPair(cast(param->getFullType())->getValueType(), - primalVal, diffVal); - builder.markInstAsMixedDifferential(valToStore, valToStore->getFullType()); + valToStore = builder.emitMakeDifferentialPair(pairValType, primalVal, diffVal); + + markDiffPairTypeInst(&builder, valToStore, pairValType); } else { @@ -1767,7 +1793,7 @@ InstPair ForwardDiffTranscriber::transcribeFunc(IRBuilder* inBuilder, IRFunc* pr if (writeBack.value.differential) { - builder.markInstAsMixedDifferential(storeInst, valToStore->getFullType()); + markDiffPairTypeInst(&builder, storeInst, valToStore->getFullType()); } } } @@ -2043,24 +2069,25 @@ InstPair ForwardDiffTranscriber::transcribeFuncParam(IRBuilder* builder, IRParam SLANG_ASSERT(diffPairParam); - if (auto pairType = as(diffPairType)) + if (as(diffPairType) || as(diffPairType)) { return InstPair( builder->emitDifferentialPairGetPrimal(diffPairParam), builder->emitDifferentialPairGetDifferential( - (IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType(builder, pairType), + (IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType( + builder, + as(diffPairType)), diffPairParam)); } else if (auto pairPtrType = as(diffPairType)) { - auto ptrInnerPairType = as(pairPtrType->getValueType()); + auto ptrInnerPairType = as(pairPtrType->getValueType()); // Make a local copy of the parameter for primal and diff parts. auto primal = builder->emitVar(ptrInnerPairType->getValueType()); auto diffType = differentiateType(builder, cast(origParam->getDataType())->getValueType()); auto diff = builder->emitVar(diffType); - builder->markInstAsDifferential( - diff, builder->getPtrType(ptrInnerPairType->getValueType())); + markDiffTypeInst(builder, diff, builder->getPtrType(ptrInnerPairType->getValueType())); IRInst* primalInitVal = nullptr; IRInst* diffInitVal = nullptr; @@ -2072,17 +2099,18 @@ InstPair ForwardDiffTranscriber::transcribeFuncParam(IRBuilder* builder, IRParam else { auto initVal = builder->emitLoad(diffPairParam); - builder->markInstAsMixedDifferential(initVal, ptrInnerPairType); + markDiffPairTypeInst(builder, initVal, ptrInnerPairType); primalInitVal = builder->emitDifferentialPairGetPrimal(initVal); diffInitVal = builder->emitDifferentialPairGetDifferential(diffType, initVal); } - builder->markInstAsDifferential(diffInitVal, ptrInnerPairType->getValueType()); + markDiffTypeInst(builder, diffInitVal, ptrInnerPairType->getValueType()); + builder->emitStore(primal, primalInitVal); auto diffStore = builder->emitStore(diff, diffInitVal); - builder->markInstAsDifferential(diffStore, ptrInnerPairType->getValueType()); + markDiffTypeInst(builder, diffStore, ptrInnerPairType->getValueType()); mapInOutParamToWriteBackValue[diffPairParam] = InstPair(primal, diff); return InstPair(primal, diff); diff --git a/source/slang/slang-ir-autodiff-pairs.cpp b/source/slang/slang-ir-autodiff-pairs.cpp index 7fc8ebbe65..3a6d52bead 100644 --- a/source/slang/slang-ir-autodiff-pairs.cpp +++ b/source/slang/slang-ir-autodiff-pairs.cpp @@ -107,10 +107,13 @@ struct DiffPairLoweringPass : InstPassBase case kIROp_DifferentialPairGetPrimal: case kIROp_DifferentialPairGetDifferentialUserCode: case kIROp_DifferentialPairGetPrimalUserCode: + case kIROp_DifferentialPtrPairGetDifferential: + case kIROp_DifferentialPtrPairGetPrimal: lowerPairAccess(builder, inst); break; case kIROp_MakeDifferentialPairUserCode: + case kIROp_MakeDifferentialPtrPair: lowerMakePair(builder, inst); break; diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp index 9fe4ec70b6..2881abe3eb 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.cpp +++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp @@ -3,8 +3,9 @@ #include "slang-ir-autodiff-region.h" #include "slang-ir-simplify-cfg.h" #include "slang-ir-util.h" -#include "../core/slang-func-ptr.h" +#include "slang-ir-insts.h" #include "slang-ir.h" +#include "../core/slang-func-ptr.h" namespace Slang { @@ -891,6 +892,16 @@ void applyToInst( } } SLANG_ASSERT(replacement); + + // If the replacement and inst are not the exact same type, use an int-cast + // (e.g. uint vs. int) + // + if (replacement->getDataType() != inst->getDataType()) + { + setInsertAfterOrdinaryInst(builder, replacement); + replacement = builder->emitCast(inst->getDataType(), replacement); + } + cloneCtx->cloneEnv.mapOldValToNew[inst] = replacement; cloneCtx->registerClonedInst(builder, inst, replacement); return; @@ -1092,7 +1103,8 @@ IRType* getTypeForLocalStorage( IRVar* emitIndexedLocalVar( IRBlock* varBlock, IRType* baseType, - const List& defBlockIndices) + const List& defBlockIndices, + SourceLoc location) { // Cannot store pointers. Case should have been handled by now. SLANG_RELEASE_ASSERT(!as(baseType)); @@ -1101,6 +1113,8 @@ IRVar* emitIndexedLocalVar( SLANG_RELEASE_ASSERT(!as(baseType)); IRBuilder varBuilder(varBlock->getModule()); + IRBuilderSourceLocRAII sourceLocationScope(&varBuilder, location); + varBuilder.setInsertBefore(varBlock->getFirstOrdinaryInst()); IRType* varType = getTypeForLocalStorage(&varBuilder, baseType, defBlockIndices); @@ -1179,9 +1193,14 @@ IRVar* storeIndexedValue( IRInst* instToStore, const List& defBlockIndices) { - IRVar* localVar = emitIndexedLocalVar(defaultVarBlock, instToStore->getDataType(), defBlockIndices); + IRVar* localVar = emitIndexedLocalVar(defaultVarBlock, + instToStore->getDataType(), + defBlockIndices, + instToStore->sourceLoc); - IRInst* addr = emitIndexedStoreAddressForVar(builder, localVar, defBlockIndices); + IRInst* addr = emitIndexedStoreAddressForVar(builder, + localVar, + defBlockIndices); builder->emitStore(addr, instToStore); @@ -1574,12 +1593,16 @@ RefPtr ensurePrimalAvailability( // region, that means there's no need to allocate a fully indexed var. // defBlockIndices = maybeTrimIndices(defBlockIndices, indexedBlockInfo, outOfScopeUses); - - IRVar* localVar = storeIndexedValue( - &builder, - varBlock, - builder.emitLoad(varToStore), - defBlockIndices); + + IRVar* localVar = nullptr; + { + IRBuilderSourceLocRAII sourceLocationScope(&builder, varToStore->sourceLoc); + localVar = storeIndexedValue( + &builder, + varBlock, + builder.emitLoad(varToStore), + defBlockIndices); + } for (auto use : outOfScopeUses) { @@ -1626,6 +1649,8 @@ RefPtr ensurePrimalAvailability( } else { + IRBuilderSourceLocRAII sourceLocationScope(&builder, instToStore->sourceLoc); + // Handle the special case of loop counters. // The only case where there will be a reference of primal loop counter from rev blocks // is the start of a loop in the reverse code. Since loop counters are not considered a @@ -1643,6 +1668,8 @@ RefPtr ensurePrimalAvailability( setInsertAfterOrdinaryInst(&builder, instToStore); auto localVar = storeIndexedValue(&builder, varBlock, instToStore, defBlockIndices); + if (isLoopCounter) + builder.addLoopCounterDecoration(localVar); for (auto use : outOfScopeUses) { @@ -1728,6 +1755,8 @@ static IRBlock* getUpdateBlock(IRLoop* loop) void lowerIndexedRegion(IRLoop*& primalLoop, IRLoop*& diffLoop, IRInst*& primalCountParam, IRInst*& diffCountParam) { IRBuilder builder(primalLoop); + IRBuilderSourceLocRAII sourceLocationScope(&builder, primalLoop->sourceLoc); + primalCountParam = nullptr; // Grab first primal block. @@ -1899,8 +1928,7 @@ RefPtr applyCheckpointPolicy(IRGlobalValueWithCode* func) // Legalize the primal inst accesses by introducing local variables / arrays and emitting // necessary load/store logic. // - primalsInfo = ensurePrimalAvailability(primalsInfo, func, indexedBlockInfo); - return primalsInfo; + return ensurePrimalAvailability(primalsInfo, func, indexedBlockInfo); } void DefaultCheckpointPolicy::preparePolicy(IRGlobalValueWithCode* func) @@ -1980,6 +2008,7 @@ static bool shouldStoreInst(IRInst* inst) case kIROp_MakeArrayFromElement: case kIROp_MakeDifferentialPair: case kIROp_MakeDifferentialPairUserCode: + case kIROp_MakeDifferentialPtrPair: case kIROp_MakeOptionalNone: case kIROp_MakeOptionalValue: case kIROp_MakeExistential: @@ -1987,6 +2016,8 @@ static bool shouldStoreInst(IRInst* inst) case kIROp_DifferentialPairGetPrimal: case kIROp_DifferentialPairGetDifferentialUserCode: case kIROp_DifferentialPairGetPrimalUserCode: + case kIROp_DifferentialPtrPairGetDifferential: + case kIROp_DifferentialPtrPairGetPrimal: case kIROp_ExtractExistentialValue: case kIROp_ExtractExistentialType: case kIROp_ExtractExistentialWitnessTable: diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 35a197f29b..169dd31ee4 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -152,6 +152,16 @@ namespace Slang builder->emitBlock(); params = _defineFuncParams(builder, as(existingPrimalFunc)); params.removeLast(); + + // Unwrap any ref pairs. We need this special case for trivial funcs. + for (Int i = 0; i < params.getCount(); i++) + { + if (as(params[i]->getDataType())) + { + params[i] = builder->emitDifferentialPtrPairGetPrimal(params[i]); + } + } + IRInst* originalFuncRefFromPrimalFunc = originalFunc; if (originalGeneric) originalFuncRefFromPrimalFunc = maybeSpecializeWithGeneric(*builder, originalGeneric, existingPriamlFuncGeneric); @@ -266,7 +276,20 @@ namespace Slang if (auto primalNoDiffType = _getPrimalTypeFromNoDiffType(this, builder, paramType)) return primalNoDiffType; - return (IRType*)findOrTranscribePrimalInst(builder, paramType); + auto primalType = (IRType*)findOrTranscribePrimalInst(builder, paramType); + + // Differentiable pointer types are treated as primal pairs, since they aren't involved in the transposition + // process. + // + if (differentiableTypeConformanceContext.isDifferentiablePtrType(primalType)) + { + auto diffPairType = tryGetDiffPairType(builder, primalType); + SLANG_ASSERT(diffPairType); + + return diffPairType; + } + + return primalType; } IRType* BackwardDiffTranscriberBase::transcribeParamTypeForPropagateFunc(IRBuilder* builder, IRType* paramType) @@ -292,7 +315,7 @@ namespace Slang auto diffPairType = tryGetDiffPairType(builder, paramType); if (diffPairType) { - if (!as(diffPairType)) + if (!as(diffPairType) && !as(diffPairType)) return builder->getInOutType(diffPairType); return diffPairType; } @@ -403,8 +426,11 @@ namespace Slang List primalTypes, propagateTypes; IRType* primalResultType = transcribeParamTypeForPrimalFunc(&builder, origFuncType->getResultType()); + IRParam *currentParam = origFunc->getFirstParam(); for (UInt i = 0; i < origFuncType->getParamCount(); i++) { + IRBuilderSourceLocRAII sourceLocationScope(&builder, currentParam->sourceLoc); + auto primalParamType = transcribeParamTypeForPrimalFunc(&builder, origFuncType->getParamType(i)); auto propagateParamType = transcribeParamTypeForPropagateFunc(&builder, origFuncType->getParamType(i)); if (propagateParamType) @@ -453,6 +479,7 @@ namespace Slang primalArgs.add(var); } primalTypes.add(primalParamType); + currentParam = currentParam->getNextParam(); } // Add dOut argument to propagateArgs. @@ -588,6 +615,8 @@ namespace Slang autoDiffSharedContext->transcriberSet.forwardTranscriber); auto oldCount = autoDiffSharedContext->followUpFunctionsToTranscribe.getCount(); IRFunc* fwdDiffFunc = as(getGenericReturnVal(fwdTranscriber.transcribe(builder, primalOuterParent))); + fwdDiffFunc->sourceLoc = primalFunc->sourceLoc; + SLANG_ASSERT(fwdDiffFunc); auto newCount = autoDiffSharedContext->followUpFunctionsToTranscribe.getCount(); for (auto i = oldCount; i < newCount; i++) @@ -712,8 +741,10 @@ namespace Slang } // Transpose the first block (parameter block) - auto paramTransposeInfo = - splitAndTransposeParameterBlock(builder, diffPropagateFunc, isResultDifferentiable); + auto paramTransposeInfo = splitAndTransposeParameterBlock(builder, + diffPropagateFunc, + primalFunc->sourceLoc, + isResultDifferentiable); // The insts we inserted in paramTransposeInfo.mapPrimalSpecificParamToReplacementInPropFunc // may be used by write back logic that we are going to insert later. @@ -815,6 +846,7 @@ namespace Slang ParameterBlockTransposeInfo BackwardDiffTranscriberBase::splitAndTransposeParameterBlock( IRBuilder* builder, IRFunc* diffFunc, + SourceLoc primalLoc, bool isResultDifferentiable) { // This method splits transposes the all the parameters for both the primal and propagate computation. @@ -841,6 +873,7 @@ namespace Slang auto nextBlockBuilder = *builder; nextBlockBuilder.setInsertBefore(paramPreludeBlock->getFirstOrdinaryInst()); + SourceLoc returnLoc; IRBlock* firstDiffBlock = nullptr; for (auto block : diffFunc->getBlocks()) { @@ -849,6 +882,13 @@ namespace Slang firstDiffBlock = block; break; } + + auto terminator = block->getTerminator(); + if (as(terminator)) + { + returnLoc = terminator->sourceLoc; + break; + } } SLANG_RELEASE_ASSERT(firstDiffBlock); @@ -895,6 +935,8 @@ namespace Slang // from the primal compuation logic in the future propagate function be replaced to. for (auto fwdParam : fwdParams) { + IRBuilderSourceLocRAII sourceLocationScope(builder, fwdParam->sourceLoc); + // Define the replacement insts that we are going to fill in for each case. IRInst* diffRefReplacement = nullptr; IRInst* primalRefReplacement = nullptr; @@ -942,7 +984,7 @@ namespace Slang // Initialize the var with input diff param at start. // Note that we insert the store in the primal block so it won't get transposed. auto storeInst = nextBlockBuilder.emitStore(tempVar, diffParam); - nextBlockBuilder.markInstAsDifferential(storeInst, diffPairType); + nextBlockBuilder.markInstAsDifferential(storeInst, primalType); // Since this store inst is specific to propagate function, we track it in a // set so we can remove it when we generate the primal func. result.propagateFuncSpecificPrimalInsts.add(storeInst); @@ -1186,6 +1228,7 @@ namespace Slang SLANG_ASSERT(dOutParamType); dOutParam = builder->emitParam(dOutParamType); + dOutParam->sourceLoc = returnLoc; builder->addNameHintDecoration(dOutParam, UnownedStringSlice("_s_dOut")); result.propagateFuncParams.add(dOutParam); } @@ -1196,6 +1239,10 @@ namespace Slang result.primalFuncParams.add(ctxParam); result.propagateFuncParams.add(ctxParam); result.dOutParam = dOutParam; + + diffFunc->sourceLoc = primalLoc; + ctxParam->sourceLoc = primalLoc; + return result; } diff --git a/source/slang/slang-ir-autodiff-rev.h b/source/slang/slang-ir-autodiff-rev.h index 68cb4e0c9a..b65701a7a9 100644 --- a/source/slang/slang-ir-autodiff-rev.h +++ b/source/slang/slang-ir-autodiff-rev.h @@ -105,6 +105,7 @@ struct BackwardDiffTranscriberBase : AutoDiffTranscriberBase ParameterBlockTransposeInfo splitAndTransposeParameterBlock( IRBuilder* builder, IRFunc* diffFunc, + SourceLoc primalLoc, bool isResultDifferentiable); void writeBackDerivativeToInOutParams(ParameterBlockTransposeInfo& info, IRFunc* diffFunc); diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index da69ed8aea..2141837b53 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -174,45 +174,54 @@ IRInst* AutoDiffTranscriberBase::maybeCloneForPrimalInst(IRBuilder* builder, IRI IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey); -IRInst* AutoDiffTranscriberBase::tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType) +IRInst* AutoDiffTranscriberBase::tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType, DiffConformanceKind kind) { - return differentiableTypeConformanceContext.tryGetDifferentiableWitness(builder, originalType); + if (kind == DiffConformanceKind::Any) + { + if (auto valueWitness = differentiableTypeConformanceContext.tryGetDifferentiableWitness(builder, originalType, DiffConformanceKind::Value)) + return valueWitness; + if (auto ptrWitness = differentiableTypeConformanceContext.tryGetDifferentiableWitness(builder, originalType, DiffConformanceKind::Ptr)) + return ptrWitness; + } + else + { + return differentiableTypeConformanceContext.tryGetDifferentiableWitness(builder, originalType, kind); + } + return nullptr; } IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRBuilder* builder, IRInst* primalType, IRInst* witness) { - return builder->getDifferentialPairType( - (IRType*)primalType, - witness); + auto conformanceType = differentiableTypeConformanceContext.getConformanceTypeFromWitness(witness); + if (autoDiffSharedContext->isInterfaceAvailable && + conformanceType == autoDiffSharedContext->differentiableInterfaceType) + { + return builder->getDifferentialPairType((IRType*)primalType, witness); + } + else if (autoDiffSharedContext->isPtrInterfaceAvailable && + conformanceType == autoDiffSharedContext->differentiablePtrInterfaceType) + { + return builder->getDifferentialPtrPairType((IRType*)primalType, witness); + } + else + { + SLANG_UNEXPECTED("Unexpected witness type"); + } } IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRBuilder* builder, IRInst* originalType) { - auto primalType = lookupPrimalInst(builder, originalType, nullptr); + auto primalType = lookupPrimalInst(builder, originalType, originalType); SLANG_RELEASE_ASSERT(primalType); IRInst* witness = nullptr; - if (auto lookup = as(primalType)) - { - if (lookup->getRequirementKey() == autoDiffSharedContext->differentialAssocTypeStructKey) - { - witness = builder->emitLookupInterfaceMethodInst( - lookup->getWitnessTable()->getDataType(), - lookup->getWitnessTable(), - autoDiffSharedContext->differentialAssocTypeWitnessStructKey); - } - } - - // Obtain the witness that primalType conforms to IDifferentiable. + + // Obtain the witness that primalType conforms to IDifferentiable/IDifferentiablePtrType if (!witness) - witness = tryGetDifferentiableWitness(builder, originalType); + witness = tryGetDifferentiableWitness(builder, primalType, DiffConformanceKind::Any); SLANG_RELEASE_ASSERT(witness); - auto pairType = builder->getDifferentialPairType( - (IRType*)primalType, - witness); - - return pairType; + return getOrCreateDiffPairType(builder, primalType, witness); } IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* origType) @@ -223,8 +232,10 @@ IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* o // Special-case for differentiable existential types. if (as(origType) || as(origType)) { - if (differentiableTypeConformanceContext.lookUpConformanceForType(origType)) + if (differentiableTypeConformanceContext.lookUpConformanceForType(origType, DiffConformanceKind::Value)) return autoDiffSharedContext->differentiableInterfaceType; + else if (differentiableTypeConformanceContext.lookUpConformanceForType(origType, DiffConformanceKind::Ptr)) + return autoDiffSharedContext->differentiablePtrInterfaceType; else return nullptr; } @@ -278,8 +289,9 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy } case kIROp_DifferentialPairType: + case kIROp_DifferentialPtrPairType: { - auto primalPairType = as(primalType); + auto primalPairType = as(primalType); return getOrCreateDiffPairType( builder, differentiableTypeConformanceContext.getDiffTypeFromPairType(builder, primalPairType), @@ -445,8 +457,24 @@ IRType* AutoDiffTranscriberBase::differentiateExtractExistentialType(IRBuilder* auto interfaceType = as(unwrapAttributedType(origType->getOperand(0)->getDataType())); if (!interfaceType) return nullptr; - List lookupKeyPath = differentiableTypeConformanceContext.findDifferentiableInterfaceLookupPath( + + List lookupKeyPath; + IRStructKey* diffStructKey = nullptr; + + List lookupPathValueType = differentiableTypeConformanceContext.findInterfaceLookupPath( autoDiffSharedContext->differentiableInterfaceType, interfaceType); + if (lookupPathValueType.getCount() > 0) + { + lookupKeyPath = lookupPathValueType; + diffStructKey = autoDiffSharedContext->differentialAssocTypeStructKey; + } + else + { + // Try IDifferentiablePtrType + lookupKeyPath = differentiableTypeConformanceContext.findInterfaceLookupPath( + autoDiffSharedContext->differentiablePtrInterfaceType, interfaceType); + diffStructKey = autoDiffSharedContext->differentialAssocRefTypeStructKey; + } if (lookupKeyPath.getCount()) { @@ -456,7 +484,7 @@ IRType* AutoDiffTranscriberBase::differentiateExtractExistentialType(IRBuilder* { outWitnessTable = builder->emitLookupInterfaceMethodInst((IRType*)node->getRequirementVal(), outWitnessTable, node->getRequirementKey()); } - auto diffType = builder->emitLookupInterfaceMethodInst(builder->getTypeType(), outWitnessTable, autoDiffSharedContext->differentialAssocTypeStructKey); + auto diffType = builder->emitLookupInterfaceMethodInst(builder->getTypeType(), outWitnessTable, diffStructKey); return (IRType*)diffType; } return nullptr; @@ -559,12 +587,33 @@ InstPair AutoDiffTranscriberBase::transcribeLookupInterfaceMethod(IRBuilder* bui builder->markInstAsPrimal(primalDiffType); builder->markInstAsPrimal(diffWitness); + return InstPair(primal, diffWitness); + } + else if (returnWitnessType->getConformanceType() == autoDiffSharedContext->differentiablePtrInterfaceType) + { + auto primalDiffType = builder->emitLookupInterfaceMethodInst( + builder->getTypeKind(), + primal, + autoDiffSharedContext->differentialAssocRefTypeStructKey); + auto diffWitness = builder->emitLookupInterfaceMethodInst( + (IRType*)primalDiffType, + primal, + autoDiffSharedContext->differentialAssocRefTypeWitnessStructKey); + + // Mark both as primal since we're working with types + // (which don't need transposing) + // + builder->markInstAsPrimal(primalDiffType); + builder->markInstAsPrimal(diffWitness); + return InstPair(primal, diffWitness); } } + auto decor = lookupInst->getRequirementKey()->findDecorationImpl( getInterfaceRequirementDerivativeDecorationOp()); + if (!decor) { return InstPair(primal, nullptr); @@ -589,6 +638,10 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType( { originalType = (IRType*)unwrapAttributedType(originalType); auto primalType = (IRType*)lookupPrimalInst(builder, originalType); + + // Can't generate zero for differentiable ptr types. Should never hit this case. + SLANG_ASSERT(!differentiableTypeConformanceContext.isDifferentiablePtrType(originalType)); + if (auto diffType = differentiateType(builder, originalType)) { IRInst* diffWitnessTable = nullptr; @@ -985,7 +1038,7 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst && !as(pair.differential)) { auto primalType = (IRType*)(pair.primal->getDataType()); - builder->markInstAsDifferential(pair.differential, primalType); + markDiffTypeInst(builder, pair.differential, primalType); } } else @@ -997,7 +1050,7 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst if (as(pair.differential)) break; auto mixedType = (IRType*)(pair.primal->getDataType()); - builder->markInstAsMixedDifferential(pair.primal, mixedType); + markDiffPairTypeInst(builder, pair.primal, mixedType); } } @@ -1033,8 +1086,9 @@ InstPair AutoDiffTranscriberBase::transcribeInst(IRBuilder* builder, IRInst* ori if (as(origInst->getParent()) && !as(origInst)) return InstPair(origInst, nullptr); - auto result = transcribeInstImpl(builder, origInst); + IRBuilderSourceLocRAII sourceLocationScope(builder, origInst->sourceLoc); + auto result = transcribeInstImpl(builder, origInst); if (result.primal == nullptr && result.differential == nullptr) { if (auto origType = as(origInst)) @@ -1075,4 +1129,64 @@ InstPair AutoDiffTranscriberBase::transcribeInst(IRBuilder* builder, IRInst* ori return result; } + +void AutoDiffTranscriberBase::markDiffTypeInst(IRBuilder* builder, IRInst* diffInst, IRType* primalType) +{ + // Ignore module-level insts. + if (as(diffInst->getParent())) + return; + + // Also ignore generic-container-level insts. + if (as(diffInst->getParent()) && + as(diffInst->getParent()->getParent())) + return; + + // TODO: This logic is a bit of a hack. We need to determine if the type is + // relevant to ptr-type computation or not, or more complex applications + // that use dynamic dispatch + ptr types will fail. + // + if (as(diffInst)) + { + builder->markInstAsDifferential(diffInst, nullptr); + return; + } + + SLANG_ASSERT(diffInst); + SLANG_ASSERT(primalType); + + if (differentiableTypeConformanceContext.isDifferentiableValueType(primalType)) + { + builder->markInstAsDifferential(diffInst, primalType); + } + else if (differentiableTypeConformanceContext.isDifferentiablePtrType(primalType)) + { + builder->markInstAsPrimal(diffInst); + } + else + { + // Stop-gap solution to go with differential inst for now. + builder->markInstAsDifferential(diffInst, primalType); + } +} + +void AutoDiffTranscriberBase::markDiffPairTypeInst(IRBuilder* builder, IRInst* diffPairInst, IRType* pairType) +{ + SLANG_ASSERT(diffPairInst); + SLANG_ASSERT(pairType); + SLANG_ASSERT(as(pairType)); + + if (as(pairType)) + { + builder->markInstAsMixedDifferential(diffPairInst, pairType); + } + else if (as(pairType)) + { + builder->markInstAsPrimal(diffPairInst); + } + else + { + SLANG_UNEXPECTED("unexpected differentiable type"); + } +} + } diff --git a/source/slang/slang-ir-autodiff-transcriber-base.h b/source/slang/slang-ir-autodiff-transcriber-base.h index f7f2dd6f20..9f3cfe56f0 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.h +++ b/source/slang/slang-ir-autodiff-transcriber-base.h @@ -91,7 +91,7 @@ struct AutoDiffTranscriberBase void maybeMigrateDifferentiableDictionaryFromDerivativeFunc(IRBuilder* builder, IRInst* origFunc); - IRInst* tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType); + IRInst* tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType, DiffConformanceKind kind); IRType* getOrCreateDiffPairType(IRBuilder* builder, IRInst* primalType, IRInst* witness); @@ -152,6 +152,10 @@ struct AutoDiffTranscriberBase virtual InstPair transcribeInstImpl(IRBuilder* builder, IRInst* origInst) = 0; virtual IROp getInterfaceRequirementDerivativeDecorationOp() = 0; + + void markDiffTypeInst(IRBuilder* builder, IRInst* inst, IRType* primalType); + + void markDiffPairTypeInst(IRBuilder* builder, IRInst* inst, IRType* primalType); }; } diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index d42462e1ba..8669df5a41 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -609,6 +609,8 @@ struct DiffTransposePass auto nextInst = inst->getNextInst(); if (auto varInst = as(inst)) { + IRBuilderSourceLocRAII sourceLocationScope(&builder, varInst->sourceLoc); + if (isDifferentialInst(varInst) && tryGetPrimalTypeFromDiffInst(varInst)) { if (auto ptrPrimalType = as(tryGetPrimalTypeFromDiffInst(varInst))) @@ -692,7 +694,11 @@ struct DiffTransposePass SLANG_ASSERT(lastRevBlock->getTerminator() == nullptr); builder.setInsertInto(lastRevBlock); - builder.emitReturn(); + + { + IRBuilderSourceLocRAII sourceLocationScope(&builder, revDiffFunc->sourceLoc); + builder.emitReturn(); + } // Remove fwd-mode blocks. for (auto block : workList) @@ -703,6 +709,8 @@ struct DiffTransposePass IRInst* extractAccumulatorVarGradient(IRBuilder* builder, IRInst* fwdInst) { + IRBuilderSourceLocRAII sourceLocationScope(builder, fwdInst->sourceLoc); + if (auto accVar = getOrCreateAccumulatorVar(fwdInst)) { auto gradValue = builder->emitLoad(accVar); @@ -731,6 +739,7 @@ struct DiffTransposePass return revAccumulatorVarMap[fwdInst]; IRBuilder tempVarBuilder(autodiffContext->moduleInst->getModule()); + IRBuilderSourceLocRAII sourceLocationSCope(&tempVarBuilder, fwdInst->sourceLoc); IRBlock* firstDiffBlock = firstRevDiffBlockMap[as(fwdInst->getParent()->getParent())]; @@ -785,6 +794,8 @@ struct DiffTransposePass for (UIndex ii = 0; ii < branchInst->getArgCount(); ii++) { auto arg = branchInst->getArg(ii); + + IRBuilderSourceLocRAII sourceLocationScope(&builder, arg->sourceLoc); if (isDifferentialInst(arg)) { // If the arg is a differential, emit a parameter @@ -885,6 +896,8 @@ struct DiffTransposePass List phiParamRevGradInsts; for (IRParam* param = fwdBlock->getFirstParam(); param; param = param->getNextParam()) { + IRBuilderSourceLocRAII sourceLocationScope(&builder, param->sourceLoc); + if (isDifferentialInst(param)) { // This param might be used outside this block. @@ -949,6 +962,8 @@ struct DiffTransposePass if (auto accVar = getOrCreateAccumulatorVar(externInst)) { + IRBuilderSourceLocRAII sourceLocationScope(&builder, externInst->sourceLoc); + // Accumulate all gradients, including our accumulator variable, // into one inst. // @@ -1050,6 +1065,7 @@ struct DiffTransposePass // Emit the aggregate of all the gradients here. // This will form the total derivative for this inst. + IRBuilderSourceLocRAII sourceLocationScope(builder, inst->sourceLoc); auto revValue = emitAggregateValue(builder, primalType, gradients); auto transposeResult = transposeInst(builder, inst, revValue); @@ -2100,7 +2116,8 @@ struct DiffTransposePass // If we reach this point, revValue must be a differentiable type. auto revTypeWitness = diffTypeContext.tryGetDifferentiableWitness( builder, - primalType); + primalType, + DiffConformanceKind::Value); SLANG_ASSERT(revTypeWitness); auto baseExistential = fwdInst->getOperand(0); @@ -2738,7 +2755,6 @@ struct DiffTransposePass gradient.revGradInst, gradient.fwdGradInst )); - } for (auto pair : bucketedGradients) diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 9b3e3a324a..507a2bf92d 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -75,6 +75,9 @@ struct ExtractPrimalFuncContext builder.setInsertBefore(destFunc); IRFuncType* originalFuncType = nullptr; outIntermediateType = createIntermediateType(destFunc); + + builder.addCheckpointIntermediateDecoration(outIntermediateType, originalFunc); + outIntermediateType->sourceLoc = originalFunc->sourceLoc; GenericChildrenMigrationContext migrationContext; migrationContext.init(as(findOuterGeneric(originalFunc)), as(findOuterGeneric(destFunc)), destFunc); @@ -141,7 +144,10 @@ struct ExtractPrimalFuncContext } auto structField = genTypeBuilder.createStructField(structType, structKey, (IRType*)fieldType); - if (auto witness = backwardPrimalTranscriber->tryGetDifferentiableWitness(&genTypeBuilder, (IRType*)fieldType)) + if (auto witness = backwardPrimalTranscriber->tryGetDifferentiableWitness( + &genTypeBuilder, + (IRType*)fieldType, + DiffConformanceKind::Value)) { genTypeBuilder.addIntermediateContextFieldDifferentialTypeDecoration(structField, witness); } @@ -154,6 +160,7 @@ struct ExtractPrimalFuncContext IRInst* intermediateOutput) { auto field = addIntermediateContextField(inst->getDataType(), intermediateOutput); + field->sourceLoc = inst->sourceLoc; auto key = field->getKey(); if (auto nameHint = inst->findDecoration()) cloneDecoration(nameHint, key); @@ -219,6 +226,10 @@ struct ExtractPrimalFuncContext if (inst->hasUses()) { auto field = addIntermediateContextField(cast(inst->getDataType())->getValueType(), outIntermediary); + field->sourceLoc = inst->sourceLoc; + if (inst->findDecoration()) + builder.addLoopCounterDecoration(field); + builder.setInsertBefore(inst); auto fieldAddr = builder.emitFieldAddress( inst->getFullType(), outIntermediary, field->getKey()); @@ -379,12 +390,16 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( use->set(builder.getVoidValue()); continue; } + + IRBuilderSourceLocRAII sourceLocationScope(&builder, use->getUser()->sourceLoc); + builder.setInsertBefore(use->getUser()); auto valType = cast(inst->getFullType())->getValueType(); auto val = builder.emitFieldExtract( valType, intermediateVar, structKeyDecor->getStructKey()); + if (use->getUser()->getOp() == kIROp_Load) { use->getUser()->replaceUsesWith(val); @@ -392,8 +407,7 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( } else { - auto tempVar = - builder.emitVar(valType); + auto tempVar = builder.emitVar(valType); builder.emitStore(tempVar, val); use->set(tempVar); } @@ -401,7 +415,7 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( } else { - // Orindary value. + // Ordinary value. // We insert a fieldExtract at each use site instead of before `inst`, // since at this stage of autodiff pass, `inst` does not necessarily // dominate all the use sites if `inst` is defined in partial branch @@ -417,6 +431,7 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( inst->getFullType(), intermediateVar, structKeyDecor->getStructKey()); + val->sourceLoc = user->sourceLoc; builder.replaceOperand(iuse, val); } } diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index 9f18db6e06..6ae5126f9b 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -588,7 +588,6 @@ struct DiffUnzipPass as(diffMap[targetBlock]), diffArgs.getCount(), diffArgs.getBuffer())); - } case kIROp_conditionalBranch: @@ -710,6 +709,9 @@ struct DiffUnzipPass void splitMixedInst(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRInst* inst) { + IRBuilderSourceLocRAII primalLocationScope(primalBuilder, inst->sourceLoc); + IRBuilderSourceLocRAII diffLocationScope(diffBuilder, inst->sourceLoc); + auto instPair = _splitMixedInst(primalBuilder, diffBuilder, inst); primalMap[inst] = instPair.primal; diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 0979c097c4..94a605a688 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -25,7 +25,7 @@ bool isBackwardDifferentiableFunc(IRInst* func) return false; } -IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey) +IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey, IRType* resultType = nullptr) { if (auto witnessTable = as(witness)) { @@ -53,15 +53,16 @@ IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementK } else { + SLANG_ASSERT(resultType); return builder->emitLookupInterfaceMethodInst( - builder->getTypeKind(), + resultType, witness, requirementKey); } return nullptr; } -static IRInst* _getDiffTypeFromPairType(AutoDiffSharedContext*sharedContext, IRBuilder* builder, IRDifferentialPairTypeBase* type) +static IRInst* _getDiffTypeFromPairType(AutoDiffSharedContext* sharedContext, IRBuilder* builder, IRDifferentialPairTypeBase* type) { auto witness = type->getWitness(); SLANG_RELEASE_ASSERT(witness); @@ -70,16 +71,48 @@ static IRInst* _getDiffTypeFromPairType(AutoDiffSharedContext*sharedContext, IRB if (as(type->getValueType()) || as(type->getValueType())) { // The differential type is the IDifferentiable interface type. - return sharedContext->differentiableInterfaceType; + if (as(type) || as(type)) + return sharedContext->differentiableInterfaceType; + else if (as(type)) + return sharedContext->differentiablePtrInterfaceType; + else + SLANG_UNEXPECTED("Unexpected differential pair type"); } - return _lookupWitness(builder, witness, sharedContext->differentialAssocTypeStructKey); + if (as(type) || as(type)) + return _lookupWitness( + builder, + witness, + sharedContext->differentialAssocTypeStructKey, + builder->getTypeKind()); + else if (as(type)) + return _lookupWitness( + builder, + witness, + sharedContext->differentialAssocRefTypeStructKey, + builder->getTypeKind()); + else + SLANG_UNEXPECTED("Unexpected differential pair type"); } static IRInst* _getDiffTypeWitnessFromPairType(AutoDiffSharedContext* sharedContext, IRBuilder* builder, IRDifferentialPairTypeBase* type) { auto witnessTable = type->getWitness(); - return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeWitnessStructKey); + + if (as(type) || as(type)) + return _lookupWitness( + builder, + witnessTable, + sharedContext->differentialAssocTypeWitnessStructKey, + sharedContext->differentialAssocTypeWitnessTableType); + else if (as(type)) + return _lookupWitness( + builder, + witnessTable, + sharedContext->differentialAssocRefTypeWitnessStructKey, + sharedContext->differentialAssocRefTypeWitnessTableType); + else + SLANG_UNEXPECTED("Unexpected differential pair type"); } bool isNoDiffType(IRType* paramType) @@ -320,6 +353,24 @@ IRInst* DifferentialPairTypeBuilder::lowerDiffPairType( return result; } +IRInterfaceType* findDifferentiableRefInterface(IRModuleInst* moduleInst) +{ + for (auto inst : moduleInst->getGlobalInsts()) + { + if (auto interfaceType = as(inst)) + { + if (auto decor = interfaceType->findDecoration()) + { + if (decor->getName() == "IDifferentiablePtrType") + { + return interfaceType; + } + } + } + } + return nullptr; +} + AutoDiffSharedContext::AutoDiffSharedContext(TargetProgram* target, IRModuleInst* inModuleInst) : moduleInst(inModuleInst), targetProgram(target) { @@ -328,14 +379,27 @@ AutoDiffSharedContext::AutoDiffSharedContext(TargetProgram* target, IRModuleInst { differentialAssocTypeStructKey = findDifferentialTypeStructKey(); differentialAssocTypeWitnessStructKey = findDifferentialTypeWitnessStructKey(); + differentialAssocTypeWitnessTableType = findDifferentialTypeWitnessTableType(); zeroMethodStructKey = findZeroMethodStructKey(); + zeroMethodType = cast(getInterfaceEntryAtIndex(differentiableInterfaceType, 2)->getRequirementVal()); addMethodStructKey = findAddMethodStructKey(); + addMethodType = cast(getInterfaceEntryAtIndex(differentiableInterfaceType, 3)->getRequirementVal()); mulMethodStructKey = findMulMethodStructKey(); nullDifferentialStructType = findNullDifferentialStructType(); nullDifferentialWitness = findNullDifferentialWitness(); - if (differentialAssocTypeStructKey) - isInterfaceAvailable = true; + isInterfaceAvailable = true; + } + + differentiablePtrInterfaceType = as(findDifferentiableRefInterface(inModuleInst)); + + if (differentiablePtrInterfaceType) + { + differentialAssocRefTypeStructKey = findDifferentialPtrTypeStructKey(); + differentialAssocRefTypeWitnessStructKey = findDifferentialPtrTypeWitnessStructKey(); + differentialAssocRefTypeWitnessTableType = findDifferentialPtrTypeWitnessTableType(); + + isPtrInterfaceAvailable = true; } } @@ -404,14 +468,14 @@ IRInst* AutoDiffSharedContext::findNullDifferentialWitness() } -IRStructKey* AutoDiffSharedContext::getIDifferentiableStructKeyAtIndex(UInt index) +IRInterfaceRequirementEntry* AutoDiffSharedContext::getInterfaceEntryAtIndex(IRInterfaceType* interface, UInt index) { - if (as(moduleInst) && differentiableInterfaceType) + if (as(moduleInst) && interface) { // Assume for now that IDifferentiable has exactly five fields. - SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 5); - if (auto entry = as(differentiableInterfaceType->getOperand(index))) - return as(entry->getRequirementKey()); + // SLANG_ASSERT(interface->getOperandCount() == 5); + if (auto entry = as(interface->getOperand(index))) + return entry; else { SLANG_UNEXPECTED("IDifferentiable interface entry unexpected type"); @@ -421,6 +485,50 @@ IRStructKey* AutoDiffSharedContext::getIDifferentiableStructKeyAtIndex(UInt inde return nullptr; } +// Extracts conformance interface from a witness inst while accounting for some +// quirks in the type system around interfaces that conform to other interfaces. +// +IRInterfaceType* DifferentiableTypeConformanceContext::getConformanceTypeFromWitness(IRInst* witness) +{ + IRInterfaceType* diffInterfaceType = nullptr; + if (auto witnessTableType = as(witness->getDataType())) + { + diffInterfaceType = cast(witnessTableType->getConformanceType()); + } + else if (auto structKey = as(witness)) + { + // We currently assume that a struct key is used uniquely for a single interface-requirement-entry. + // Find that entry + for (IRUse* use = structKey->firstUse; use; use = use->nextUse) + { + if (auto entry = as(use->getUser())) + { + auto innerWitnessTableType = cast(entry->getRequirementVal()); + diffInterfaceType = cast(innerWitnessTableType->getConformanceType()); + break; + } + } + } + else if (auto interfaceRequirementEntry = as(witness)) + { + auto innerWitnessTableType = cast(interfaceRequirementEntry->getRequirementVal()); + diffInterfaceType = cast(innerWitnessTableType->getConformanceType()); + } + else if (auto tupleType = as(witness->getDataType())) + { + SLANG_ASSERT(tupleType->getOperandCount() >= 1); + auto operand = tupleType->getOperand(0); + auto innerWitnessTableType = cast(operand); + return cast(innerWitnessTableType->getConformanceType()); + } + else + { + SLANG_UNEXPECTED("Unexpected witness type"); + } + + return diffInterfaceType; +} + void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func) { parentFunc = func; @@ -434,7 +542,13 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func) { if (auto item = as(child)) { - auto existingItem = differentiableWitnessDictionary.tryGetValue(item->getConcreteType()); + IRInterfaceType* diffInterfaceType = getConformanceTypeFromWitness(item->getWitness()); + + SLANG_ASSERT( + diffInterfaceType == sharedContext->differentiableInterfaceType + || diffInterfaceType == sharedContext->differentiablePtrInterfaceType); + + auto existingItem = differentiableTypeWitnessDictionary.tryGetValue(item->getConcreteType()); if (existingItem) { *existingItem = item->getWitness(); @@ -458,20 +572,26 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func) { auto element = concreteType->getOperand(i); auto elementWitness = witnessPack->getOperand(i); - differentiableWitnessDictionary.addIfNotExists( - (IRType*)element, - _lookupWitness(&subBuilder, elementWitness, sharedContext->differentialAssocTypeStructKey)); + + if (diffInterfaceType == sharedContext->differentiableInterfaceType) + addTypeToDictionary( + (IRType*)element, + elementWitness); + else if (diffInterfaceType == sharedContext->differentiablePtrInterfaceType) + addTypeToDictionary( + (IRType*)element, + elementWitness); } return; } } - differentiableWitnessDictionary.add((IRType*)item->getConcreteType(), item->getWitness()); + addTypeToDictionary((IRType*)item->getConcreteType(), item->getWitness()); if (!as(item->getConcreteType())) { - differentiableWitnessDictionary.addIfNotExists( - (IRType*)_lookupWitness(&subBuilder, item->getWitness(), sharedContext->differentialAssocTypeStructKey), + addTypeToDictionary( + (IRType*)_lookupWitness(&subBuilder, item->getWitness(), sharedContext->differentialAssocTypeStructKey, subBuilder.getTypeKind()), item->getWitness()); } @@ -480,29 +600,55 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func) // For differential pair types, register the differential type as well. IRBuilder builder(diffPairType); builder.setInsertAfter(diffPairType->getWitness()); - auto diffType = _lookupWitness(&builder, diffPairType->getWitness(), sharedContext->differentialAssocTypeStructKey); - auto diffWitness = _lookupWitness(&builder, diffPairType->getWitness(), sharedContext->differentialAssocTypeWitnessStructKey); - if (diffType && diffWitness) - { - differentiableWitnessDictionary.addIfNotExists((IRType*)diffType, diffWitness); - } + + // TODO(sai): lot of this logic is duplicated. need to refactor. + auto diffType = (diffInterfaceType == sharedContext->differentiableInterfaceType) ? + _lookupWitness(&builder, diffPairType->getWitness(), sharedContext->differentialAssocTypeStructKey, builder.getTypeKind()) : + _lookupWitness(&builder, diffPairType->getWitness(), sharedContext->differentialAssocRefTypeStructKey, builder.getTypeKind()); + auto diffWitness = (diffInterfaceType == sharedContext->differentiableInterfaceType) ? + _lookupWitness( + &builder, + diffPairType->getWitness(), + sharedContext->differentialAssocTypeWitnessStructKey, + sharedContext->differentialAssocTypeWitnessTableType) : + _lookupWitness( + &builder, + diffPairType->getWitness(), + sharedContext->differentialAssocRefTypeWitnessStructKey, + sharedContext->differentialAssocRefTypeWitnessTableType); + + addTypeToDictionary((IRType*)diffType, diffWitness); } } } } } -IRInst* DifferentiableTypeConformanceContext::lookUpConformanceForType(IRInst* type) +IRInst* DifferentiableTypeConformanceContext::lookUpConformanceForType(IRInst* type, DiffConformanceKind kind) { IRInst* foundResult = nullptr; - differentiableWitnessDictionary.tryGetValue(type, foundResult); - return foundResult; + differentiableTypeWitnessDictionary.tryGetValue(type, foundResult); + if (!foundResult) + return nullptr; + + if (kind == DiffConformanceKind::Any) + return foundResult; + + if (auto baseType = getConformanceTypeFromWitness(foundResult)) + { + if (baseType == sharedContext->differentiableInterfaceType && kind == DiffConformanceKind::Value) + return foundResult; + else if (baseType == sharedContext->differentiablePtrInterfaceType && kind == DiffConformanceKind::Ptr) + return foundResult; + } + + return nullptr; } -IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key) +IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key, IRType* resultType) { - if (auto conformance = tryGetDifferentiableWitness(builder, origType)) - return _lookupWitness(builder, conformance, key); + if (auto conformance = tryGetDifferentiableWitness(builder, origType, DiffConformanceKind::Any)) + return _lookupWitness(builder, conformance, key, resultType); return nullptr; } @@ -514,7 +660,7 @@ IRInst* DifferentiableTypeConformanceContext::getDifferentialTypeFromDiffPairTyp IRInst* DifferentiableTypeConformanceContext::getDiffTypeFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type) { - return _getDiffTypeFromPairType(sharedContext, builder, type); + return this->differentiateType(builder, type->getValueType()); } IRInst* DifferentiableTypeConformanceContext::getDiffTypeWitnessFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type) @@ -525,20 +671,34 @@ IRInst* DifferentiableTypeConformanceContext::getDiffTypeWitnessFromPairType(IRB IRInst* DifferentiableTypeConformanceContext::getDiffZeroMethodFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type) { auto witnessTable = type->getWitness(); - return _lookupWitness(builder, witnessTable, sharedContext->zeroMethodStructKey); + return _lookupWitness(builder, witnessTable, sharedContext->zeroMethodStructKey, sharedContext->zeroMethodType); } IRInst* DifferentiableTypeConformanceContext::getDiffAddMethodFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type) { auto witnessTable = type->getWitness(); - return _lookupWitness(builder, witnessTable, sharedContext->addMethodStructKey); + return _lookupWitness(builder, witnessTable, sharedContext->addMethodStructKey, sharedContext->addMethodType); +} + +void DifferentiableTypeConformanceContext::addTypeToDictionary(IRType* type, IRInst* witness) +{ + auto conformanceType = getConformanceTypeFromWitness(witness); + + if (!sharedContext->isInterfaceAvailable && !sharedContext->isPtrInterfaceAvailable) + return; + + SLANG_ASSERT( + conformanceType == sharedContext->differentiableInterfaceType || + conformanceType == sharedContext->differentiablePtrInterfaceType); + + differentiableTypeWitnessDictionary.addIfNotExists(type, witness); } IRInst *DifferentiableTypeConformanceContext::tryExtractConformanceFromInterfaceType(IRBuilder *builder, IRInterfaceType *interfaceType, IRWitnessTable *witnessTable) { SLANG_RELEASE_ASSERT(interfaceType); - List lookupKeyPath = findDifferentiableInterfaceLookupPath( + List lookupKeyPath = findInterfaceLookupPath( sharedContext->differentiableInterfaceType, interfaceType); IRInst* differentialTypeWitness = witnessTable; @@ -549,6 +709,7 @@ IRInst *DifferentiableTypeConformanceContext::tryExtractConformanceFromInterface { differentialTypeWitness = builder->emitLookupInterfaceMethodInst((IRType*)node->getRequirementVal(), differentialTypeWitness, node->getRequirementKey()); // Lookup insts are always primal values. + builder->markInstAsPrimal(differentialTypeWitness); } return differentialTypeWitness; @@ -557,10 +718,10 @@ IRInst *DifferentiableTypeConformanceContext::tryExtractConformanceFromInterface return nullptr; } -// Given an interface type, return the lookup path from a witness table of `type` to a witness table of `IDifferentiable`. -static bool _findDifferentiableInterfaceLookupPathImpl( +// Given an interface type, return the lookup path from a witness table of `type` to a witness table of `supType`. +static bool _findInterfaceLookupPathImpl( HashSet& processedTypes, - IRInterfaceType* idiffType, + IRInterfaceType* supType, IRInterfaceType* type, List& currentPath) { @@ -576,13 +737,13 @@ static bool _findDifferentiableInterfaceLookupPathImpl( if (auto wt = as(entry->getRequirementVal())) { currentPath.add(entry); - if (wt->getConformanceType() == idiffType) + if (wt->getConformanceType() == supType) { return true; } else if (auto subInterfaceType = as(wt->getConformanceType())) { - if (_findDifferentiableInterfaceLookupPathImpl(processedTypes, idiffType, subInterfaceType, currentPath)) + if (_findInterfaceLookupPathImpl(processedTypes, supType, subInterfaceType, currentPath)) return true; } currentPath.removeLast(); @@ -591,11 +752,11 @@ static bool _findDifferentiableInterfaceLookupPathImpl( return false; } -List DifferentiableTypeConformanceContext::findDifferentiableInterfaceLookupPath(IRInterfaceType *idiffType, IRInterfaceType *type) +List DifferentiableTypeConformanceContext::findInterfaceLookupPath(IRInterfaceType *supType, IRInterfaceType *type) { List currentPath; HashSet processedTypes; - _findDifferentiableInterfaceLookupPathImpl(processedTypes, idiffType, type, currentPath); + _findInterfaceLookupPathImpl(processedTypes, supType, type, currentPath); return currentPath; } @@ -722,7 +883,7 @@ void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary() { if (auto pairType = as(globalInst)) { - differentiableWitnessDictionary.addIfNotExists(pairType->getValueType(), pairType->getWitness()); + addTypeToDictionary(pairType->getValueType(), pairType->getWitness()); } } } @@ -762,9 +923,8 @@ IRType* DifferentiableTypeConformanceContext::differentiateType(IRBuilder* build case kIROp_DifferentialPairType: { auto primalPairType = as(primalType); - return getOrCreateDiffPairType( - builder, - getDiffTypeFromPairType(builder, primalPairType), + return builder->getDifferentialPairType( + (IRType*)getDiffTypeFromPairType(builder, primalPairType), getDiffTypeWitnessFromPairType(builder, primalPairType)); } @@ -776,6 +936,14 @@ IRType* DifferentiableTypeConformanceContext::differentiateType(IRBuilder* build getDiffTypeWitnessFromPairType(builder, primalPairType)); } + case kIROp_DifferentialPtrPairType: + { + auto primalPairType = as(primalType); + return builder->getDifferentialPtrPairType( + (IRType*)getDiffTypeFromPairType(builder, primalPairType), + getDiffTypeWitnessFromPairType(builder, primalPairType)); + } + case kIROp_FuncType: { SLANG_UNIMPLEMENTED_X("Impl"); @@ -817,12 +985,12 @@ IRType* DifferentiableTypeConformanceContext::differentiateType(IRBuilder* build } } -IRInst* DifferentiableTypeConformanceContext::tryGetDifferentiableWitness(IRBuilder* builder, IRInst* primalType) +IRInst* DifferentiableTypeConformanceContext::tryGetDifferentiableWitness(IRBuilder* builder, IRInst* primalType, DiffConformanceKind kind) { if (isNoDiffType((IRType*)primalType)) return nullptr; - - IRInst* witness = lookUpConformanceForType((IRType*)primalType); + + IRInst* witness = lookUpConformanceForType((IRType*)primalType, kind); if (witness) { SLANG_RELEASE_ASSERT(witness || as(primalType)); @@ -834,31 +1002,60 @@ IRInst* DifferentiableTypeConformanceContext::tryGetDifferentiableWitness(IRBuil witness = nullptr; } - if (!witness) + if (witness) + return witness; + + // If a witness is not already mapped, build one if possible. + SLANG_RELEASE_ASSERT(primalType); + if (auto primalPairType = as(primalType)) { - SLANG_RELEASE_ASSERT(primalType); - if (auto primalPairType = as(primalType)) - { - witness = getOrCreateDifferentiablePairWitness(builder, primalPairType); - } - else if (auto arrayType = as(primalType)) - { - witness = getArrayWitness(builder, arrayType); - } - else if (auto extractExistential = as(primalType)) - { - witness = getExtractExistensialTypeWitness(builder, extractExistential); - } - else if (auto typePack = as(primalType)) + witness = buildDifferentiablePairWitness(builder, primalPairType, kind); + } + else if (auto arrayType = as(primalType)) + { + witness = buildArrayWitness(builder, arrayType, kind); + } + else if (auto extractExistential = as(primalType)) + { + witness = buildExtractExistensialTypeWitness(builder, extractExistential, kind); + } + else if (auto typePack = as(primalType)) + { + witness = buildTupleWitness(builder, typePack, kind); + } + else if (auto tupleType = as(primalType)) + { + witness = buildTupleWitness(builder, tupleType, kind); + } + else if (auto lookup = as(primalType)) + { + // For types that are lookups from a table, we can simply lookup the witness from the same table + if (lookup->getRequirementKey() == sharedContext->differentialAssocTypeStructKey) { - witness = getTupleWitness(builder, typePack); + witness = builder->emitLookupInterfaceMethodInst( + lookup->getWitnessTable()->getDataType(), + lookup->getWitnessTable(), + sharedContext->differentialAssocTypeWitnessStructKey); } - else if (auto tupleType = as(primalType)) + + if (lookup->getRequirementKey() == sharedContext->differentialAssocRefTypeStructKey) { - witness = getTupleWitness(builder, tupleType); + witness = builder->emitLookupInterfaceMethodInst( + lookup->getWitnessTable()->getDataType(), + lookup->getWitnessTable(), + sharedContext->differentialAssocRefTypeWitnessStructKey); } } - return witness; + + // If we created a witness, register it. + if (witness) + { + addTypeToDictionary((IRType*)primalType, witness); + return witness; + } + + // Failed. Type is either non-differentiable, or unhandled. + return nullptr; } IRType* DifferentiableTypeConformanceContext::getOrCreateDiffPairType(IRBuilder* builder, IRInst* primalType, IRInst* witness) @@ -868,77 +1065,97 @@ IRType* DifferentiableTypeConformanceContext::getOrCreateDiffPairType(IRBuilder* witness); } -IRInst* DifferentiableTypeConformanceContext::getOrCreateDifferentiablePairWitness(IRBuilder* builder, IRDifferentialPairTypeBase* pairType) +IRInst* DifferentiableTypeConformanceContext::buildDifferentiablePairWitness( + IRBuilder* builder, + IRDifferentialPairTypeBase* pairType, + DiffConformanceKind target) { - // Differentiate the pair type to get it's differential (which is itself a pair) - auto diffDiffPairType = (IRType*)differentiateType(builder, (IRType*)pairType); - - auto addMethod = builder->createFunc(); - auto zeroMethod = builder->createFunc(); - - auto table = builder->createWitnessTable(this->sharedContext->differentiableInterfaceType, (IRType*)pairType); - - // And place it in the synthesized witness table. - builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffDiffPairType); - builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table); - builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod); - builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod); - - bool isUserCodeType = as(pairType) ? true : false; - - // Fill in differential method implementations. - auto elementType = as(pairType)->getValueType(); - auto innerWitness = as(pairType)->getWitness(); - - { - // Add method. - IRBuilder b = *builder; - b.setInsertInto(addMethod); - b.addBackwardDifferentiableDecoration(addMethod); - IRType* paramTypes[2] = { diffDiffPairType, diffDiffPairType }; - addMethod->setFullType(b.getFuncType(2, paramTypes, diffDiffPairType)); - b.emitBlock(); - auto p0 = b.emitParam(diffDiffPairType); - auto p1 = b.emitParam(diffDiffPairType); - - // Since we are already dealing with a DiffPair.Differnetial type, we know that value type == diff type. - auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey); - IRInst* argsPrimal[2] = { - isUserCodeType ? b.emitDifferentialPairGetPrimalUserCode(p0) : b.emitDifferentialPairGetPrimal(p0), - isUserCodeType ? b.emitDifferentialPairGetPrimalUserCode(p1) : b.emitDifferentialPairGetPrimal(p1) }; - auto primalPart = b.emitCallInst(elementType, innerAdd, 2, argsPrimal); - IRInst* argsDiff[2] = { - isUserCodeType ? b.emitDifferentialPairGetDifferentialUserCode(elementType, p0) : b.emitDifferentialPairGetDifferential(elementType, p0), - isUserCodeType ? b.emitDifferentialPairGetDifferentialUserCode(elementType, p1) : b.emitDifferentialPairGetDifferential(elementType, p1)}; - auto diffPart = b.emitCallInst(elementType, innerAdd, 2, argsDiff); - auto retVal = - isUserCodeType - ? b.emitMakeDifferentialPairUserCode(diffDiffPairType, primalPart, diffPart) - : b.emitMakeDifferentialPair(diffDiffPairType, primalPart, diffPart); - b.emitReturn(retVal); - } - { - // Zero method. - IRBuilder b = *builder; - b.setInsertInto(zeroMethod); - zeroMethod->setFullType(b.getFuncType(0, nullptr, diffDiffPairType)); - b.emitBlock(); - auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey); - auto zeroVal = b.emitCallInst(elementType, innerZero, 0, nullptr); - auto retVal = - isUserCodeType - ? b.emitMakeDifferentialPairUserCode(diffDiffPairType, zeroVal, zeroVal) - : b.emitMakeDifferentialPair(diffDiffPairType, zeroVal, zeroVal); - b.emitReturn(retVal); + IRWitnessTable* table = nullptr; + if (target == DiffConformanceKind::Value) + { + // Differentiate the pair type to get it's differential (which is itself a pair) + auto diffDiffPairType = (IRType*)differentiateType(builder, (IRType*)pairType); + + auto addMethod = builder->createFunc(); + auto zeroMethod = builder->createFunc(); + + table = builder->createWitnessTable(sharedContext->differentiableInterfaceType, (IRType*)pairType); + + // And place it in the synthesized witness table. + builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffDiffPairType); + builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table); + builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod); + builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod); + + bool isUserCodeType = as(pairType) ? true : false; + + // Fill in differential method implementations. + auto elementType = as(pairType)->getValueType(); + auto innerWitness = as(pairType)->getWitness(); + + { + // Add method. + IRBuilder b = *builder; + b.setInsertInto(addMethod); + b.addBackwardDifferentiableDecoration(addMethod); + IRType* paramTypes[2] = { diffDiffPairType, diffDiffPairType }; + addMethod->setFullType(b.getFuncType(2, paramTypes, diffDiffPairType)); + b.emitBlock(); + auto p0 = b.emitParam(diffDiffPairType); + auto p1 = b.emitParam(diffDiffPairType); + + // Since we are already dealing with a DiffPair.Differnetial type, we know that value type == diff type. + auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey, sharedContext->addMethodType); + IRInst* argsPrimal[2] = { + isUserCodeType ? b.emitDifferentialPairGetPrimalUserCode(p0) : b.emitDifferentialPairGetPrimal(p0), + isUserCodeType ? b.emitDifferentialPairGetPrimalUserCode(p1) : b.emitDifferentialPairGetPrimal(p1) }; + auto primalPart = b.emitCallInst(elementType, innerAdd, 2, argsPrimal); + IRInst* argsDiff[2] = { + isUserCodeType ? b.emitDifferentialPairGetDifferentialUserCode(elementType, p0) : b.emitDifferentialPairGetDifferential(elementType, p0), + isUserCodeType ? b.emitDifferentialPairGetDifferentialUserCode(elementType, p1) : b.emitDifferentialPairGetDifferential(elementType, p1)}; + auto diffPart = b.emitCallInst(elementType, innerAdd, 2, argsDiff); + auto retVal = + isUserCodeType + ? b.emitMakeDifferentialPairUserCode(diffDiffPairType, primalPart, diffPart) + : b.emitMakeDifferentialPair(diffDiffPairType, primalPart, diffPart); + b.emitReturn(retVal); + } + { + // Zero method. + IRBuilder b = *builder; + b.setInsertInto(zeroMethod); + zeroMethod->setFullType(b.getFuncType(0, nullptr, diffDiffPairType)); + b.emitBlock(); + auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey, sharedContext->zeroMethodType); + auto zeroVal = b.emitCallInst(elementType, innerZero, 0, nullptr); + auto retVal = + isUserCodeType + ? b.emitMakeDifferentialPairUserCode(diffDiffPairType, zeroVal, zeroVal) + : b.emitMakeDifferentialPair(diffDiffPairType, zeroVal, zeroVal); + b.emitReturn(retVal); + } + } + else if (target == DiffConformanceKind::Ptr) + { + // Differentiate the pair type to get it's differential (which is itself a pair) + auto diffDiffPairType = (IRType*)differentiateType(builder, (IRType*)pairType); + + table = builder->createWitnessTable( + sharedContext->differentiablePtrInterfaceType, + (IRType*)pairType); + + // And place it in the synthesized witness table. + builder->createWitnessTableEntry(table, sharedContext->differentialAssocRefTypeStructKey, diffDiffPairType); + builder->createWitnessTableEntry(table, sharedContext->differentialAssocRefTypeWitnessStructKey, table); } - - // Record this in the context for future lookups - differentiableWitnessDictionary[(IRType*)pairType] = table; return table; } -IRInst* DifferentiableTypeConformanceContext::getArrayWitness(IRBuilder* builder, IRArrayType* arrayType) +IRInst* DifferentiableTypeConformanceContext::buildArrayWitness( + IRBuilder* builder, + IRArrayType* arrayType, + DiffConformanceKind target) { // Differentiate the pair type to get it's differential (which is itself a pair) auto diffArrayType = (IRType*)differentiateType(builder, (IRType*)arrayType); @@ -946,70 +1163,89 @@ IRInst* DifferentiableTypeConformanceContext::getArrayWitness(IRBuilder* builder if (!diffArrayType) return nullptr; - auto innerWitness = tryGetDifferentiableWitness(builder, as(arrayType)->getElementType()); + IRWitnessTable* table = nullptr; + if (target == DiffConformanceKind::Value) + { + SLANG_ASSERT(isDifferentiableValueType((IRType*)arrayType)); + auto innerWitness = tryGetDifferentiableWitness(builder, as(arrayType)->getElementType(), DiffConformanceKind::Value); - auto addMethod = builder->createFunc(); - auto zeroMethod = builder->createFunc(); + auto addMethod = builder->createFunc(); + auto zeroMethod = builder->createFunc(); - auto table = builder->createWitnessTable(sharedContext->differentiableInterfaceType, (IRType*)arrayType); + table = builder->createWitnessTable(sharedContext->differentiableInterfaceType, (IRType*)arrayType); - // And place it in the synthesized witness table. - builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffArrayType); - builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table); - builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod); - builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod); + // And place it in the synthesized witness table. + builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffArrayType); + builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table); + builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod); + builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod); - auto elementType = as(diffArrayType)->getElementType(); + auto elementType = as(diffArrayType)->getElementType(); - // Fill in differential method implementations. + // Fill in differential method implementations. + { + // Add method. + IRBuilder b = *builder; + b.setInsertInto(addMethod); + b.addBackwardDifferentiableDecoration(addMethod); + IRType* paramTypes[2] = { diffArrayType, diffArrayType }; + addMethod->setFullType(b.getFuncType(2, paramTypes, diffArrayType)); + b.emitBlock(); + auto p0 = b.emitParam(diffArrayType); + auto p1 = b.emitParam(diffArrayType); + + // Since we are already dealing with a DiffPair.Differnetial type, we know that value type == diff type. + auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey, sharedContext->addMethodType); + auto resultVar = b.emitVar(diffArrayType); + IRBlock* loopBodyBlock = nullptr; + IRBlock* loopBreakBlock = nullptr; + auto loopCounter = emitLoopBlocks(&b, b.getIntValue(b.getIntType(), 0), as(diffArrayType)->getElementCount(), loopBodyBlock, loopBreakBlock); + b.setInsertBefore(loopBodyBlock->getTerminator()); + + IRInst* args[2] = { + b.emitElementExtract(p0, loopCounter), + b.emitElementExtract(p1, loopCounter) }; + auto elementResult = b.emitCallInst(elementType, innerAdd, 2, args); + auto addr = b.emitElementAddress(resultVar, loopCounter); + b.emitStore(addr, elementResult); + b.setInsertInto(loopBreakBlock); + b.emitReturn(b.emitLoad(resultVar)); + } + { + // Zero method. + IRBuilder b = *builder; + b.setInsertInto(zeroMethod); + zeroMethod->setFullType(b.getFuncType(0, nullptr, diffArrayType)); + b.emitBlock(); + + auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey, sharedContext->zeroMethodType); + auto zeroVal = b.emitCallInst(elementType, innerZero, 0, nullptr); + auto retVal = b.emitMakeArrayFromElement(diffArrayType, zeroVal); + b.emitReturn(retVal); + } + } + else if (target == DiffConformanceKind::Ptr) { - // Add method. - IRBuilder b = *builder; - b.setInsertInto(addMethod); - b.addBackwardDifferentiableDecoration(addMethod); - IRType* paramTypes[2] = { diffArrayType, diffArrayType }; - addMethod->setFullType(b.getFuncType(2, paramTypes, diffArrayType)); - b.emitBlock(); - auto p0 = b.emitParam(diffArrayType); - auto p1 = b.emitParam(diffArrayType); + SLANG_ASSERT(isDifferentiablePtrType((IRType*)arrayType)); - // Since we are already dealing with a DiffPair.Differnetial type, we know that value type == diff type. - auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey); - auto resultVar = b.emitVar(diffArrayType); - IRBlock* loopBodyBlock = nullptr; - IRBlock* loopBreakBlock = nullptr; - auto loopCounter = emitLoopBlocks(&b, b.getIntValue(b.getIntType(), 0), as(diffArrayType)->getElementCount(), loopBodyBlock, loopBreakBlock); - b.setInsertBefore(loopBodyBlock->getTerminator()); + table = builder->createWitnessTable(sharedContext->differentiablePtrInterfaceType, (IRType*)arrayType); - IRInst* args[2] = { - b.emitElementExtract(p0, loopCounter), - b.emitElementExtract(p1, loopCounter) }; - auto elementResult = b.emitCallInst(elementType, innerAdd, 2, args); - auto addr = b.emitElementAddress(resultVar, loopCounter); - b.emitStore(addr, elementResult); - b.setInsertInto(loopBreakBlock); - b.emitReturn(b.emitLoad(resultVar)); + // And place it in the synthesized witness table. + builder->createWitnessTableEntry(table, sharedContext->differentialAssocRefTypeStructKey, diffArrayType); + builder->createWitnessTableEntry(table, sharedContext->differentialAssocRefTypeWitnessStructKey, table); } + else { - // Zero method. - IRBuilder b = *builder; - b.setInsertInto(zeroMethod); - zeroMethod->setFullType(b.getFuncType(0, nullptr, diffArrayType)); - b.emitBlock(); - - auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey); - auto zeroVal = b.emitCallInst(elementType, innerZero, 0, nullptr); - auto retVal = b.emitMakeArrayFromElement(diffArrayType, zeroVal); - b.emitReturn(retVal); + SLANG_UNEXPECTED("Invalid conformance kind for synthesis"); } - // Record this in the context for future lookups - differentiableWitnessDictionary[(IRType*)arrayType] = table; - return table; } -IRInst* DifferentiableTypeConformanceContext::getTupleWitness(IRBuilder* builder, IRInst* inTupleType) +IRInst* DifferentiableTypeConformanceContext::buildTupleWitness( + IRBuilder* builder, + IRInst* inTupleType, + DiffConformanceKind target) { // Differentiate the pair type to get it's differential (which is itself a pair) auto diffTupleType = (IRType*)differentiateType(builder, (IRType*)inTupleType); @@ -1017,100 +1253,116 @@ IRInst* DifferentiableTypeConformanceContext::getTupleWitness(IRBuilder* builder if (!diffTupleType) return nullptr; - auto addMethod = builder->createFunc(); - auto zeroMethod = builder->createFunc(); - - auto table = builder->createWitnessTable(sharedContext->differentiableInterfaceType, (IRType*)inTupleType); - - // And place it in the synthesized witness table. - builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffTupleType); - builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table); - builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod); - builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod); - - // Fill in differential method implementations. - { - // Add method. - IRBuilder b = *builder; - b.setInsertInto(addMethod); - b.addBackwardDifferentiableDecoration(addMethod); - IRType* paramTypes[2] = { diffTupleType, diffTupleType }; - addMethod->setFullType(b.getFuncType(2, paramTypes, diffTupleType)); - b.emitBlock(); - auto p0 = b.emitParam(diffTupleType); - auto p1 = b.emitParam(diffTupleType); - List results; - for (UInt i = 0; i < inTupleType->getOperandCount(); i++) - { - auto elementType = inTupleType->getOperand(i); - auto diffElementType = (IRType*)diffTupleType->getOperand(i); - auto innerWitness = tryGetDifferentiableWitness(&b, (IRType*)elementType); - IRInst* elementResult = nullptr; - if (!innerWitness) + IRWitnessTable* table = nullptr; + if (target == DiffConformanceKind::Value) + { + SLANG_ASSERT(isDifferentiableValueType((IRType*)inTupleType)); + + auto addMethod = builder->createFunc(); + auto zeroMethod = builder->createFunc(); + + table = builder->createWitnessTable(sharedContext->differentiableInterfaceType, (IRType*)inTupleType); + + // And place it in the synthesized witness table. + builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffTupleType); + builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table); + builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod); + builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod); + + // Fill in differential method implementations. + { + // Add method. + IRBuilder b = *builder; + b.setInsertInto(addMethod); + b.addBackwardDifferentiableDecoration(addMethod); + IRType* paramTypes[2] = { diffTupleType, diffTupleType }; + addMethod->setFullType(b.getFuncType(2, paramTypes, diffTupleType)); + b.emitBlock(); + auto p0 = b.emitParam(diffTupleType); + auto p1 = b.emitParam(diffTupleType); + List results; + for (UInt i = 0; i < inTupleType->getOperandCount(); i++) { - elementResult = b.getVoidValue(); + auto elementType = inTupleType->getOperand(i); + auto diffElementType = (IRType*)diffTupleType->getOperand(i); + auto innerWitness = tryGetDifferentiableWitness(&b, (IRType*)elementType, DiffConformanceKind::Value); + IRInst* elementResult = nullptr; + if (!innerWitness) + { + elementResult = b.getVoidValue(); + } + else + { + auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey, sharedContext->addMethodType); + auto iVal = b.getIntValue(b.getIntType(), i); + IRInst* args[2] = { + b.emitGetTupleElement(diffElementType, p0, iVal), + b.emitGetTupleElement(diffElementType, p1, iVal) }; + elementResult = b.emitCallInst(diffElementType, innerAdd, 2, args); + } + results.add(elementResult); } + IRInst* resultVal = nullptr; + if (diffTupleType->getOp() == kIROp_TupleType) + resultVal = b.emitMakeTuple(diffTupleType, results); else - { - auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey); - auto iVal = b.getIntValue(b.getIntType(), i); - IRInst* args[2] = { - b.emitGetTupleElement(diffElementType, p0, iVal), - b.emitGetTupleElement(diffElementType, p1, iVal) }; - elementResult = b.emitCallInst(diffElementType, innerAdd, 2, args); - } - results.add(elementResult); + resultVal = b.emitMakeValuePack(diffTupleType, (UInt)results.getCount(), results.getBuffer()); + b.emitReturn(resultVal); } - IRInst* resultVal = nullptr; - if (diffTupleType->getOp() == kIROp_TupleType) - resultVal = b.emitMakeTuple(diffTupleType, results); - else - resultVal = b.emitMakeValuePack(diffTupleType, (UInt)results.getCount(), results.getBuffer()); - b.emitReturn(resultVal); - } - { - // Zero method. - IRBuilder b = *builder; - b.setInsertInto(addMethod); - b.addBackwardDifferentiableDecoration(addMethod); - addMethod->setFullType(b.getFuncType(0, nullptr, diffTupleType)); - b.emitBlock(); - List results; - for (UInt i = 0; i < inTupleType->getOperandCount(); i++) - { - auto elementType = inTupleType->getOperand(i); - auto diffElementType = (IRType*)diffTupleType->getOperand(i); - auto innerWitness = tryGetDifferentiableWitness(&b, (IRType*)elementType); - IRInst* elementResult = nullptr; - if (!innerWitness) + { + // Zero method. + IRBuilder b = *builder; + b.setInsertInto(addMethod); + b.addBackwardDifferentiableDecoration(addMethod); + addMethod->setFullType(b.getFuncType(0, nullptr, diffTupleType)); + b.emitBlock(); + List results; + for (UInt i = 0; i < inTupleType->getOperandCount(); i++) { - elementResult = b.getVoidValue(); + auto elementType = inTupleType->getOperand(i); + auto diffElementType = (IRType*)diffTupleType->getOperand(i); + auto innerWitness = tryGetDifferentiableWitness(&b, (IRType*)elementType, DiffConformanceKind::Value); + IRInst* elementResult = nullptr; + if (!innerWitness) + { + elementResult = b.getVoidValue(); + } + else + { + auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey, sharedContext->zeroMethodType); + elementResult = b.emitCallInst(diffElementType, innerZero, 0, nullptr); + } + results.add(elementResult); } + IRInst* resultVal = nullptr; + if (diffTupleType->getOp() == kIROp_TupleType) + resultVal = b.emitMakeTuple(diffTupleType, results); else - { - auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey); - elementResult = b.emitCallInst(diffElementType, innerZero, 0, nullptr); - } - results.add(elementResult); + resultVal = b.emitMakeValuePack(diffTupleType, (UInt)results.getCount(), results.getBuffer()); + b.emitReturn(resultVal); } - IRInst* resultVal = nullptr; - if (diffTupleType->getOp() == kIROp_TupleType) - resultVal = b.emitMakeTuple(diffTupleType, results); - else - resultVal = b.emitMakeValuePack(diffTupleType, (UInt)results.getCount(), results.getBuffer()); - b.emitReturn(resultVal); } + else if (target == DiffConformanceKind::Ptr) + { + SLANG_ASSERT(isDifferentiablePtrType((IRType*)inTupleType)); - // Record this in the context for future lookups - differentiableWitnessDictionary[(IRType*)inTupleType] = table; + table = builder->createWitnessTable(sharedContext->differentiablePtrInterfaceType, (IRType*)inTupleType); + + // And place it in the synthesized witness table. + builder->createWitnessTableEntry(table, sharedContext->differentialAssocRefTypeStructKey, diffTupleType); + builder->createWitnessTableEntry(table, sharedContext->differentialAssocRefTypeWitnessStructKey, table); + } return table; } -IRInst* DifferentiableTypeConformanceContext::getExtractExistensialTypeWitness( +IRInst* DifferentiableTypeConformanceContext::buildExtractExistensialTypeWitness( IRBuilder* builder, - IRExtractExistentialType* extractExistentialType) + IRExtractExistentialType* extractExistentialType, + DiffConformanceKind target) { + SLANG_UNUSED(target); // logic is the same for both value and ptr + // Check that the type's base is differentiable if (differentiateType(builder, extractExistentialType->getOperand(0)->getDataType())) { @@ -1203,6 +1455,7 @@ void stripAutoDiffDecorationsFromChildren(IRInst* parent) case kIROp_AutoDiffOriginalValueDecoration: case kIROp_UserDefinedBackwardDerivativeDecoration: case kIROp_IntermediateContextFieldDifferentialTypeDecoration: + case kIROp_CheckpointIntermediateDecoration: decor->removeAndDeallocate(); break; case kIROp_AutoDiffBuiltinDecoration: @@ -1309,12 +1562,13 @@ bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* if (context.isDifferentiableType((IRType*)typeInst)) return true; + // Look for equivalent types. - for (auto type : context.differentiableWitnessDictionary) + for (auto type : context.differentiableTypeWitnessDictionary) { if (isTypeEqual(type.key, (IRType*)typeInst)) { - context.differentiableWitnessDictionary[(IRType*)typeInst] = type.value; + context.differentiableTypeWitnessDictionary[(IRType*)typeInst] = type.value; return true; } } @@ -1671,7 +1925,7 @@ struct AutoDiffPass : public InstPassBase IRBuilder keyBuilder = builder; keyBuilder.setInsertBefore(maybeFindOuterGeneric(originalType)); auto diffKey = keyBuilder.createStructKey(); - auto diffFieldType = _lookupWitness(&keyBuilder, diffFieldWitness, autodiffContext->differentialAssocTypeStructKey); + auto diffFieldType = _lookupWitness(&keyBuilder, diffFieldWitness, autodiffContext->differentialAssocTypeStructKey, builder.getTypeKind()); info.field = builder.createStructField(diffType, diffKey, (IRType*)diffFieldType); info.witness = diffFieldWitness; builder.addDecoration(field->getKey(), kIROp_DerivativeMemberDecoration, diffKey); @@ -1694,7 +1948,11 @@ struct AutoDiffPass : public InstPassBase List fieldVals; for (auto info : diffFields) { - auto innerZeroMethod = _lookupWitness(&builder, info.witness, autodiffContext->zeroMethodStructKey); + auto innerZeroMethod = _lookupWitness( + &builder, + info.witness, + autodiffContext->zeroMethodStructKey, + autodiffContext->zeroMethodType); IRInst* val = builder.emitCallInst(info.field->getFieldType(), innerZeroMethod, 0, nullptr); fieldVals.add(val); } @@ -1718,7 +1976,11 @@ struct AutoDiffPass : public InstPassBase List fieldVals; for (auto info : diffFields) { - auto innerAddMethod = _lookupWitness(&builder, info.witness, autodiffContext->addMethodStructKey); + auto innerAddMethod = _lookupWitness( + &builder, + info.witness, + autodiffContext->addMethodStructKey, + autodiffContext->addMethodType); IRInst* args[2] = { builder.emitFieldExtract(info.field->getFieldType(), param1, info.field->getKey()), builder.emitFieldExtract(info.field->getFieldType(), param2, info.field->getKey()), diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index 812471fe3d..ad2486aad4 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -57,6 +57,14 @@ struct DiffTranscriberSet AutoDiffTranscriberBase* backwardTranscriber = nullptr; }; + +enum class DiffConformanceKind +{ + Any = 0, // Perform actions for any conformance (infer from context) + Ptr = 1, // Perform actions for IDifferentiablePtrType + Value = 2 // Perform actions for IDifferentiable +}; + struct AutoDiffSharedContext { TargetProgram* targetProgram = nullptr; @@ -78,6 +86,7 @@ struct AutoDiffSharedContext // The struct key for the witness that `Differential` associated type conforms to // `IDifferential`. IRStructKey* differentialAssocTypeWitnessStructKey = nullptr; + IRWitnessTableType* differentialAssocTypeWitnessTableType = nullptr; // The struct key for the 'zero()' associated type @@ -85,12 +94,14 @@ struct AutoDiffSharedContext // implementation of zero() for a given type. // IRStructKey* zeroMethodStructKey = nullptr; + IRFuncType* zeroMethodType = nullptr; // The struct key for the 'add()' associated type // defined inside IDifferential. We use this to lookup the // implementation of add() for a given type. // IRStructKey* addMethodStructKey = nullptr; + IRFuncType* addMethodType = nullptr; IRStructKey* mulMethodStructKey = nullptr; @@ -104,12 +115,27 @@ struct AutoDiffSharedContext // IRInst* nullDifferentialWitness = nullptr; + + // A reference to the builtin IDifferentiablePtrType interface type. + IRInterfaceType* differentiablePtrInterfaceType = nullptr; + + // The struct key for the 'Differential' associated type + // defined inside IDifferentialPtrType. We use this to lookup the differential + // type in the conformance table associated with the concrete type. + // + IRStructKey* differentialAssocRefTypeStructKey = nullptr; + + // The struct key for the witness that `Differential` associated type conforms to + // `IDifferentialPtrType`. + IRStructKey* differentialAssocRefTypeWitnessStructKey = nullptr; + IRWitnessTableType* differentialAssocRefTypeWitnessTableType = nullptr; // Modules that don't use differentiable types // won't have the IDifferentiable interface type available. // Set to false to indicate that we are uninitialized. // bool isInterfaceAvailable = false; + bool isPtrInterfaceAvailable = false; List followUpFunctionsToTranscribe; @@ -127,38 +153,70 @@ struct AutoDiffSharedContext IRStructKey* findDifferentialTypeStructKey() { - return getIDifferentiableStructKeyAtIndex(0); + return cast( + getInterfaceEntryAtIndex(differentiableInterfaceType, 0)->getRequirementKey()); } IRStructKey* findDifferentialTypeWitnessStructKey() { - return getIDifferentiableStructKeyAtIndex(1); + return cast( + getInterfaceEntryAtIndex(differentiableInterfaceType, 1)->getRequirementKey()); + } + + IRWitnessTableType* findDifferentialTypeWitnessTableType() + { + return cast( + getInterfaceEntryAtIndex(differentiableInterfaceType, 1)->getRequirementVal()); } IRStructKey* findZeroMethodStructKey() { - return getIDifferentiableStructKeyAtIndex(2); + return cast( + getInterfaceEntryAtIndex(differentiableInterfaceType, 2)->getRequirementKey()); } IRStructKey* findAddMethodStructKey() { - return getIDifferentiableStructKeyAtIndex(3); + return cast( + getInterfaceEntryAtIndex(differentiableInterfaceType, 3)->getRequirementKey()); } IRStructKey* findMulMethodStructKey() { - return getIDifferentiableStructKeyAtIndex(4); + return cast( + getInterfaceEntryAtIndex(differentiableInterfaceType, 4)->getRequirementKey()); + } + + + IRStructKey* findDifferentialPtrTypeStructKey() + { + return cast( + getInterfaceEntryAtIndex(differentiablePtrInterfaceType, 0)->getRequirementKey()); } - IRStructKey* getIDifferentiableStructKeyAtIndex(UInt index); + IRStructKey* findDifferentialPtrTypeWitnessStructKey() + { + return cast( + getInterfaceEntryAtIndex(differentiablePtrInterfaceType, 1)->getRequirementKey()); + } + + IRWitnessTableType* findDifferentialPtrTypeWitnessTableType() + { + return cast( + getInterfaceEntryAtIndex(differentiablePtrInterfaceType, 1)->getRequirementVal()); + } + + //IRStructKey* getIDifferentiableStructKeyAtIndex(UInt index); + IRInterfaceRequirementEntry* getInterfaceEntryAtIndex(IRInterfaceType* interface, UInt index); }; + struct DifferentiableTypeConformanceContext { AutoDiffSharedContext* sharedContext; IRGlobalValueWithCode* parentFunc = nullptr; - OrderedDictionary differentiableWitnessDictionary; + OrderedDictionary differentiableTypeWitnessDictionary; IRFunc* existentialDAddFunc = nullptr; @@ -167,7 +225,7 @@ struct DifferentiableTypeConformanceContext { // Populate dictionary with null differential type. if (sharedContext->nullDifferentialStructType) - differentiableWitnessDictionary.add( + differentiableTypeWitnessDictionary.add( sharedContext->nullDifferentialStructType, sharedContext->nullDifferentialWitness); } @@ -179,21 +237,13 @@ struct DifferentiableTypeConformanceContext // Lookup a witness table for the concreteType. One should exist if concreteType // inherits (successfully) from IDifferentiable. // - IRInst* lookUpConformanceForType(IRInst* type); + IRInst* lookUpConformanceForType(IRInst* type, DiffConformanceKind kind); - IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key); + IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key, IRType* resultType = nullptr); IRType* differentiateType(IRBuilder* builder, IRInst* primalType); - IRInst* tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType); - - IRInst* getOrCreateDifferentiablePairWitness(IRBuilder* builder, IRDifferentialPairTypeBase* pairType); - - IRInst* getArrayWitness(IRBuilder* builder, IRArrayType* pairType); - - IRInst* getTupleWitness(IRBuilder* builder, IRInst* tupleType); - - IRInst* getExtractExistensialTypeWitness(IRBuilder* builder, IRExtractExistentialType* extractExistentialType); + IRInst* tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType, DiffConformanceKind kind); IRType* getOrCreateDiffPairType(IRBuilder* builder, IRInst* primalType, IRInst* witness); @@ -207,17 +257,21 @@ struct DifferentiableTypeConformanceContext IRInst* getDiffAddMethodFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type); + void addTypeToDictionary(IRType* type, IRInst* witness); + + IRInterfaceType* getConformanceTypeFromWitness(IRInst* witness); + IRInst* tryExtractConformanceFromInterfaceType( IRBuilder* builder, IRInterfaceType* interfaceType, IRWitnessTable* witnessTable); - List findDifferentiableInterfaceLookupPath( - IRInterfaceType* idiffType, + List findInterfaceLookupPath( + IRInterfaceType* supType, IRInterfaceType* type); // Lookup and return the 'Differential' type declared in the concrete type - // in order to conform to the IDifferentiable interface. + // in order to conform to the IDifferentiable/IDifferentiablePtrType interfaces // Note that inside a generic block, this will be a witness table lookup instruction // that gets resolved during the specialization pass. // @@ -227,8 +281,10 @@ struct DifferentiableTypeConformanceContext { case kIROp_InterfaceType: { - if (isDifferentiableType(origType)) + if (isDifferentiableValueType(origType)) return this->sharedContext->differentiableInterfaceType; + else if (isDifferentiablePtrType(origType)) + return this->sharedContext->differentiablePtrInterfaceType; else return nullptr; } @@ -254,12 +310,29 @@ struct DifferentiableTypeConformanceContext auto diffWitness = getDiffTypeWitnessFromPairType(builder, diffPairType); return builder->getDifferentialPairUserCodeType((IRType*)diffType, diffWitness); } + case kIROp_DifferentialPtrPairType: + { + auto diffPairType = as(origType); + auto diffType = getDiffTypeFromPairType(builder, diffPairType); + auto diffWitness = getDiffTypeWitnessFromPairType(builder, diffPairType); + return builder->getDifferentialPtrPairType((IRType*)diffType, diffWitness); + } default: - return lookUpInterfaceMethod(builder, origType, sharedContext->differentialAssocTypeStructKey); + if (isDifferentiableValueType(origType)) + return lookUpInterfaceMethod(builder, origType, sharedContext->differentialAssocTypeStructKey, builder->getTypeKind()); + else if (isDifferentiablePtrType(origType)) + return lookUpInterfaceMethod(builder, origType, sharedContext->differentialAssocRefTypeStructKey, builder->getTypeKind()); + else + return nullptr; } } bool isDifferentiableType(IRType* origType) + { + return isDifferentiableValueType(origType) || isDifferentiablePtrType(origType); + } + + bool isDifferentiableValueType(IRType* origType) { for (; origType;) { @@ -279,7 +352,27 @@ struct DifferentiableTypeConformanceContext origType = (IRType*)origType->getOperand(0); continue; default: - return lookUpConformanceForType(origType) != nullptr; + return lookUpConformanceForType(origType, DiffConformanceKind::Value) != nullptr; + } + } + return false; + } + + bool isDifferentiablePtrType(IRType* origType) + { + for (; origType;) + { + switch (origType->getOp()) + { + case kIROp_VectorType: + case kIROp_ArrayType: + case kIROp_PtrType: + case kIROp_OutType: + case kIROp_InOutType: + origType = (IRType*)origType->getOperand(0); + continue; + default: + return lookUpConformanceForType(origType, DiffConformanceKind::Ptr) != nullptr; } } return false; @@ -287,13 +380,13 @@ struct DifferentiableTypeConformanceContext IRInst* getZeroMethodForType(IRBuilder* builder, IRType* origType) { - auto result = lookUpInterfaceMethod(builder, origType, sharedContext->zeroMethodStructKey); + auto result = lookUpInterfaceMethod(builder, origType, sharedContext->zeroMethodStructKey, sharedContext->zeroMethodType); return result; } IRInst* getAddMethodForType(IRBuilder* builder, IRType* origType) { - auto result = lookUpInterfaceMethod(builder, origType, sharedContext->addMethodStructKey); + auto result = lookUpInterfaceMethod(builder, origType, sharedContext->addMethodStructKey, sharedContext->addMethodType); return result; } @@ -307,8 +400,28 @@ struct DifferentiableTypeConformanceContext IRFunc* getOrCreateExistentialDAddMethod(); + IRInst* buildDifferentiablePairWitness( + IRBuilder* builder, + IRDifferentialPairTypeBase* pairType, + DiffConformanceKind target); + + IRInst* buildArrayWitness( + IRBuilder* builder, + IRArrayType* pairType, + DiffConformanceKind target); + + IRInst* buildTupleWitness( + IRBuilder* builder, + IRInst* tupleType, + DiffConformanceKind target); + + IRInst* buildExtractExistensialTypeWitness( + IRBuilder* builder, + IRExtractExistentialType* extractExistentialType, + DiffConformanceKind target); }; + struct DifferentialPairTypeBuilder { DifferentialPairTypeBuilder() = default; diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index 8b4886a2cf..cae47fffde 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -625,7 +625,7 @@ struct CheckDifferentiabilityPassContext : public InstPassBase } } - if (!sharedContext.isInterfaceAvailable) + if (!sharedContext.isInterfaceAvailable && !sharedContext.isPtrInterfaceAvailable) return; for (auto inst : module->getGlobalInsts()) diff --git a/source/slang/slang-ir-clone.cpp b/source/slang/slang-ir-clone.cpp index e2297bcb2c..a8b9b548e0 100644 --- a/source/slang/slang-ir-clone.cpp +++ b/source/slang/slang-ir-clone.cpp @@ -220,6 +220,7 @@ static void _cloneInstDecorationsAndChildren( auto oldType = oldParam->getFullType(); auto newType = (IRType*)findCloneForOperand(env, oldType); newParam->setFullType(newType); + newParam->sourceLoc = oldParam->sourceLoc; } } diff --git a/source/slang/slang-ir-eliminate-phis.cpp b/source/slang/slang-ir-eliminate-phis.cpp index b17fad6ec0..0db2fc765c 100644 --- a/source/slang/slang-ir-eliminate-phis.cpp +++ b/source/slang/slang-ir-eliminate-phis.cpp @@ -462,6 +462,7 @@ struct PhiEliminationContext // to the temporary that will replace it. // param->transferDecorationsTo(temp); + temp->sourceLoc = param->sourceLoc; } // The other main auxilliary sxtructure is used to track @@ -550,6 +551,7 @@ struct PhiEliminationContext auto user = use->getUser(); m_builder.setInsertBefore(user); auto newVal = m_builder.emitLoad(temp); + newVal->sourceLoc = param->sourceLoc; m_builder.replaceOperand(use, newVal); } @@ -938,6 +940,7 @@ struct PhiEliminationContext newOperands.getCount(), newOperands.getArrayView().getBuffer()); oldBranch->transferDecorationsTo(newBranch); + newBranch->sourceLoc = oldBranch->sourceLoc; // TODO: We could consider just modifying `branch` in-place by clearing // the relevant operands for the phi arguments and setting its operand diff --git a/source/slang/slang-ir-init-local-var.cpp b/source/slang/slang-ir-init-local-var.cpp index 34a0e5ff4c..fa556bc58e 100644 --- a/source/slang/slang-ir-init-local-var.cpp +++ b/source/slang/slang-ir-init-local-var.cpp @@ -47,6 +47,9 @@ void initializeLocalVariables(IRModule* module, IRGlobalValueWithCode* func) breakLabel:; if (initialized) continue; + + IRBuilderSourceLocRAII sourceLocationScope(&builder, inst->sourceLoc); + builder.setInsertAfter(inst); builder.emitStore( inst, diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index b526df3a92..0d689660e9 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -61,7 +61,8 @@ INST(Nop, nop, 0, 0) INST(DifferentialPairType, DiffPair, 1, HOISTABLE) INST(DifferentialPairUserCodeType, DiffPairUserCode, 1, HOISTABLE) - INST_RANGE(DifferentialPairTypeBase, DifferentialPairType, DifferentialPairUserCodeType) + INST(DifferentialPtrPairType, DiffRefPair, 1, HOISTABLE) + INST_RANGE(DifferentialPairTypeBase, DifferentialPairType, DifferentialPtrPairType) INST(BackwardDiffIntermediateContextType, BwdDiffIntermediateCtxType, 1, HOISTABLE) @@ -325,15 +326,18 @@ INST(DefaultConstruct, defaultConstruct, 0, 0) INST(MakeDifferentialPair, MakeDiffPair, 2, 0) INST(MakeDifferentialPairUserCode, MakeDiffPairUserCode, 2, 0) -INST_RANGE(MakeDifferentialPairBase, MakeDifferentialPair, MakeDifferentialPairUserCode) +INST(MakeDifferentialPtrPair, MakeDiffRefPair, 2, 0) +INST_RANGE(MakeDifferentialPairBase, MakeDifferentialPair, MakeDifferentialPtrPair) INST(DifferentialPairGetDifferential, GetDifferential, 1, 0) INST(DifferentialPairGetDifferentialUserCode, GetDifferentialUserCode, 1, 0) -INST_RANGE(DifferentialPairGetDifferentialBase, DifferentialPairGetDifferential, DifferentialPairGetDifferentialUserCode) +INST(DifferentialPtrPairGetDifferential, GetDifferentialPtr, 1, 0) +INST_RANGE(DifferentialPairGetDifferentialBase, DifferentialPairGetDifferential, DifferentialPtrPairGetDifferential) INST(DifferentialPairGetPrimal, GetPrimal, 1, 0) INST(DifferentialPairGetPrimalUserCode, GetPrimalUserCode, 1, 0) -INST_RANGE(DifferentialPairGetPrimalBase, DifferentialPairGetPrimal, DifferentialPairGetPrimalUserCode) +INST(DifferentialPtrPairGetPrimal, GetPrimalRef, 1, 0) +INST_RANGE(DifferentialPairGetPrimalBase, DifferentialPairGetPrimal, DifferentialPtrPairGetPrimal) INST(Specialize, specialize, 2, HOISTABLE) INST(LookupWitness, lookupWitness, 2, HOISTABLE) @@ -1056,6 +1060,9 @@ INST_RANGE(BindingQuery, GetRegisterIndex, GetRegisterSpace) /// Hint that the result from a call to the decorated function should be recomputed in backward prop function. INST(PreferRecomputeDecoration, PreferRecomputeDecoration, 0, 0) + /// Hint that a struct is used for reverse mode checkpointing + INST(CheckpointIntermediateDecoration, CheckpointIntermediateDecoration, 1, 0) + INST_RANGE(CheckpointHintDecoration, PreferCheckpointDecoration, PreferRecomputeDecoration) /// Marks a function whose return value is never dynamic uniform. diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index f240e9ad8c..19386d4dde 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -947,6 +947,16 @@ struct IRPreferCheckpointDecoration : IRCheckpointHintDecoration IR_LEAF_ISA(PreferCheckpointDecoration) }; +struct IRCheckpointIntermediateDecoration : IRCheckpointHintDecoration +{ + enum + { + kOp = kIROp_CheckpointIntermediateDecoration + }; + IR_LEAF_ISA(CheckpointIntermediateDecoration) + + IRInst* getSourceFunction() { return getOperand(0); } +}; struct IRLoopCounterDecoration : IRDecoration { @@ -2949,6 +2959,10 @@ struct IRMakeDifferentialPairUserCode : IRMakeDifferentialPairBase { IR_LEAF_ISA(MakeDifferentialPairUserCode) }; +struct IRMakeDifferentialPtrPair : IRMakeDifferentialPairBase +{ + IR_LEAF_ISA(MakeDifferentialPtrPair) +}; struct IRDifferentialPairGetDifferentialBase : IRInst { @@ -2963,6 +2977,10 @@ struct IRDifferentialPairGetDifferentialUserCode : IRDifferentialPairGetDifferen { IR_LEAF_ISA(DifferentialPairGetDifferentialUserCode) }; +struct IRDifferentialPtrPairGetDifferential : IRDifferentialPairGetDifferentialBase +{ + IR_LEAF_ISA(DifferentialPtrPairGetDifferential) +}; struct IRDifferentialPairGetPrimalBase : IRInst { @@ -2977,6 +2995,10 @@ struct IRDifferentialPairGetPrimalUserCode : IRDifferentialPairGetPrimalBase { IR_LEAF_ISA(DifferentialPairGetPrimalUserCode) }; +struct IRDifferentialPtrPairGetPrimal : IRDifferentialPairGetPrimalBase +{ + IR_LEAF_ISA(DifferentialPtrPairGetPrimal) +}; struct IRDetachDerivative : IRInst { @@ -3647,6 +3669,10 @@ struct IRBuilder IRDifferentialPairType* getDifferentialPairType( IRType* valueType, IRInst* witnessTable); + + IRDifferentialPtrPairType* getDifferentialPtrPairType( + IRType* valueType, + IRInst* witnessTable); IRDifferentialPairUserCodeType* getDifferentialPairUserCodeType( IRType* valueType, @@ -3787,6 +3813,8 @@ struct IRBuilder IRInst* emitGetTorchCudaStream(); IRInst* emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential); + IRInst* emitMakeDifferentialValuePair(IRType* type, IRInst* primal, IRInst* differential); + IRInst* emitMakeDifferentialPtrPair(IRType* type, IRInst* primal, IRInst* differential); IRInst* emitMakeDifferentialPairUserCode(IRType* type, IRInst* primal, IRInst* differential); IRInst* addDifferentiableTypeDictionaryDecoration(IRInst* target); @@ -3970,9 +3998,19 @@ struct IRBuilder IRInst* emitGetOptionalValue(IRInst* optValue); IRInst* emitMakeOptionalValue(IRInst* optType, IRInst* value); IRInst* emitMakeOptionalNone(IRInst* optType, IRInst* defaultValue); + IRInst* emitDifferentialPairGetDifferential(IRType* diffType, IRInst* diffPair); + IRInst* emitDifferentialValuePairGetDifferential(IRType* diffType, IRInst* diffPair); + IRInst* emitDifferentialPtrPairGetDifferential(IRType* diffType, IRInst* diffPair); + IRInst* emitDifferentialPairGetPrimal(IRInst* diffPair); + IRInst* emitDifferentialValuePairGetPrimal(IRInst* diffPair); + IRInst* emitDifferentialPtrPairGetPrimal(IRInst* diffPair); + IRInst* emitDifferentialPairGetPrimal(IRType* primalType, IRInst* diffPair); + IRInst* emitDifferentialValuePairGetPrimal(IRType* primalType, IRInst* diffPair); + IRInst* emitDifferentialPtrPairGetPrimal(IRType* primalType, IRInst* diffPair); + IRInst* emitDifferentialPairGetDifferentialUserCode(IRType* diffType, IRInst* diffPair); IRInst* emitDifferentialPairGetPrimalUserCode(IRInst* diffPair); IRInst* emitMakeVector( @@ -5153,6 +5191,11 @@ struct IRBuilder { addDecoration(inst, kIROp_MemoryQualifierSetDecoration, getIntValue(getIntType(), flags)); } + + void addCheckpointIntermediateDecoration(IRInst* inst, IRGlobalValueWithCode *func) + { + addDecoration(inst, kIROp_CheckpointIntermediateDecoration, func); + } }; // Helper to establish the source location that will be used diff --git a/source/slang/slang-ir-loop-unroll.cpp b/source/slang/slang-ir-loop-unroll.cpp index 753c930a86..ef05511612 100644 --- a/source/slang/slang-ir-loop-unroll.cpp +++ b/source/slang/slang-ir-loop-unroll.cpp @@ -526,6 +526,7 @@ void eliminateContinueBlocks(IRModule* module, IRLoop* loopInst) // we will now introduce a breakable region for each iteration. IRBuilder builder(module); + IRBuilderSourceLocRAII sourceLocationScope(&builder, loopInst->sourceLoc); auto targetBlock = loopInst->getTargetBlock(); diff --git a/source/slang/slang-ir-ssa.cpp b/source/slang/slang-ir-ssa.cpp index e44c4079b4..506e6a3350 100644 --- a/source/slang/slang-ir-ssa.cpp +++ b/source/slang/slang-ir-ssa.cpp @@ -431,6 +431,7 @@ PhiInfo* addPhi( RefPtr phiInfo = new PhiInfo(); context->phiInfos.add(phi, phiInfo); + phi->sourceLoc = var->sourceLoc; phiInfo->phi = phi; phiInfo->var = var; diff --git a/source/slang/slang-ir-use-uninitialized-values.cpp b/source/slang/slang-ir-use-uninitialized-values.cpp index b5ce05895f..744bedf615 100644 --- a/source/slang/slang-ir-use-uninitialized-values.cpp +++ b/source/slang/slang-ir-use-uninitialized-values.cpp @@ -452,21 +452,6 @@ namespace Slang return false; } - static bool isWrittenTo(IRInst* inst) - { - for (auto alias : getAliasableInstructions(inst)) - { - for (auto use = alias->firstUse; use; use = use->nextUse) - { - InstructionUsageType usage = getInstructionUsageType(use->getUser(), alias); - if (usage == Store || usage == StoreParent) - return true; - } - } - - return false; - } - static bool isDirectlyWrittenTo(IRInst* inst) { for (auto use = inst->firstUse; use; use = use->nextUse) @@ -584,36 +569,6 @@ namespace Slang } } - static void checkParameterAsInOut(IRParam* param, IRFunc* func, bool isThis, DiagnosticSink* sink) - { - // If the inout is used for the sake of interface conformance, let it be - for (auto use = func->firstUse; use; use = use->nextUse) - { - if (as(use->getUser())) - return; - } - - // If there is at least one write... - if (isWrittenTo(param)) - return; - - // ...or if there is an intrinsic_asm instruction - for (const auto& b : func->getBlocks()) - { - for (auto inst = b->getFirstInst(); inst; inst = inst->next) - { - if (as(inst)) - return; - } - } - - sink->diagnose(param, - isThis - ? Diagnostics::methodNeverMutates - : Diagnostics::inOutNeverStoredInto, - param); - } - static void checkUninitializedValues(IRFunc* func, DiagnosticSink* sink) { // Differentiable functions will generate undefined values @@ -635,22 +590,15 @@ namespace Slang if (auto entry = func->findDecoration()) stage = entry->getProfile().getStage(); - bool structMethod = func->findDecoration(); - // Check out parameters if (!isUnmodifying(func)) { int index = 0; for (auto param : firstBlock->getParams()) { - bool isThis = structMethod && (index == 0); - ParameterCheckType checkType = isPotentiallyUnintended(param, stage, index); if (checkType == AsOut) checkParameterAsOut(reachability, func, param, sink); - else if (checkType == AsInOut) - checkParameterAsInOut(param, func, isThis, sink); - index++; } } diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 633451e1e3..7cc8e0697a 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3022,6 +3022,17 @@ namespace Slang operands); } + IRDifferentialPtrPairType* IRBuilder::getDifferentialPtrPairType( + IRType* valueType, + IRInst* witnessTable) + { + IRInst* operands[] = { valueType, witnessTable }; + return (IRDifferentialPtrPairType*)getType( + kIROp_DifferentialPtrPairType, + sizeof(operands) / sizeof(operands[0]), + operands); + } + IRDifferentialPairUserCodeType* IRBuilder::getDifferentialPairUserCodeType( IRType* valueType, IRInst* witnessTable) @@ -3503,7 +3514,7 @@ namespace Slang return inst; } - IRInst* IRBuilder::emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential) + IRInst* IRBuilder::emitMakeDifferentialValuePair(IRType* type, IRInst* primal, IRInst* differential) { SLANG_RELEASE_ASSERT(as(type)); SLANG_RELEASE_ASSERT(as(type)->getValueType() != nullptr); @@ -3512,8 +3523,101 @@ namespace Slang auto inst = createInstWithTrailingArgs( this, kIROp_MakeDifferentialPair, type, 2, args); addInst(inst); + inst->sourceLoc = primal->sourceLoc; + return inst; + } + + IRInst* IRBuilder::emitMakeDifferentialPtrPair(IRType* type, IRInst* primal, IRInst* differential) + { + SLANG_RELEASE_ASSERT(as(type)); + SLANG_RELEASE_ASSERT(as(type)->getValueType() != nullptr); + + IRInst* args[] = {primal, differential}; + auto inst = createInstWithTrailingArgs( + this, kIROp_MakeDifferentialPtrPair, type, 2, args); + addInst(inst); return inst; } + + IRInst* IRBuilder::emitMakeDifferentialPair(IRType* pairType, IRInst* primalVal, IRInst* diffVal) + { + if (as(pairType)) + { + return emitMakeDifferentialValuePair(pairType, primalVal, diffVal); + } + else if (as(pairType)) + { + // Quick optimization: + // If primalVal and diffVal are extracted from the same pointer-pair, + // we can just use the pointer-pair directly. + // + if (auto primalPtrVal = as(primalVal)) + { + if (auto diffPtrVal = as(diffVal)) + { + if (primalPtrVal->getBase() == diffPtrVal->getBase()) + return primalPtrVal->getBase(); + } + } + return emitMakeDifferentialPtrPair(pairType, primalVal, diffVal); + } + else + { + SLANG_ASSERT(!"unreachable"); + return nullptr; + } + } + + IRInst* IRBuilder::emitDifferentialPairGetDifferential(IRType* diffType, IRInst* pairVal) + { + if (as(pairVal->getDataType())) + { + return emitDifferentialValuePairGetDifferential(diffType, pairVal); + } + else if (as(pairVal->getDataType())) + { + return emitDifferentialPtrPairGetDifferential(diffType, pairVal); + } + else + { + SLANG_ASSERT(!"unreachable"); + return nullptr; + } + } + + IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRInst* pairVal) + { + if (as(pairVal->getDataType())) + { + return emitDifferentialValuePairGetPrimal(pairVal); + } + else if (as(pairVal->getDataType())) + { + return emitDifferentialPtrPairGetPrimal(pairVal); + } + else + { + SLANG_ASSERT(!"unreachable"); + return nullptr; + } + } + + IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRType* primalType, IRInst* pairVal) + { + if (as(pairVal->getDataType())) + { + return emitDifferentialValuePairGetPrimal(primalType, pairVal); + } + else if (as(pairVal->getDataType())) + { + return emitDifferentialPtrPairGetPrimal(primalType, pairVal); + } + else + { + SLANG_ASSERT(!"unreachable"); + return nullptr; + } + } IRInst* IRBuilder::emitMakeDifferentialPairUserCode(IRType* type, IRInst* primal, IRInst* differential) { @@ -3524,6 +3628,7 @@ namespace Slang auto inst = createInstWithTrailingArgs( this, kIROp_MakeDifferentialPairUserCode, type, 2, args); addInst(inst); + inst->sourceLoc = primal->sourceLoc; return inst; } @@ -4237,7 +4342,7 @@ namespace Slang return emitIntrinsicInst(type, kIROp_MakeVector, argCount, args); } - IRInst* IRBuilder::emitDifferentialPairGetDifferential(IRType* diffType, IRInst* diffPair) + IRInst* IRBuilder::emitDifferentialValuePairGetDifferential(IRType* diffType, IRInst* diffPair) { SLANG_ASSERT(as(diffPair->getDataType())); return emitIntrinsicInst( @@ -4247,7 +4352,18 @@ namespace Slang &diffPair); } - IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRInst* diffPair) + + IRInst* IRBuilder::emitDifferentialPtrPairGetDifferential(IRType* diffType, IRInst* diffPair) + { + SLANG_ASSERT(as(diffPair->getDataType())); + return emitIntrinsicInst( + diffType, + kIROp_DifferentialPtrPairGetDifferential, + 1, + &diffPair); + } + + IRInst* IRBuilder::emitDifferentialValuePairGetPrimal(IRInst* diffPair) { auto valueType = cast(diffPair->getDataType())->getValueType(); return emitIntrinsicInst( @@ -4257,7 +4373,7 @@ namespace Slang &diffPair); } - IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRType* primalType, IRInst* diffPair) + IRInst* IRBuilder::emitDifferentialValuePairGetPrimal(IRType* primalType, IRInst* diffPair) { return emitIntrinsicInst( primalType, @@ -4266,6 +4382,25 @@ namespace Slang &diffPair); } + IRInst* IRBuilder::emitDifferentialPtrPairGetPrimal(IRInst* diffPair) + { + auto valueType = cast(diffPair->getDataType())->getValueType(); + return emitIntrinsicInst( + valueType, + kIROp_DifferentialPtrPairGetPrimal, + 1, + &diffPair); + } + + IRInst* IRBuilder::emitDifferentialPtrPairGetPrimal(IRType* primalType, IRInst* diffPair) + { + return emitIntrinsicInst( + primalType, + kIROp_DifferentialPtrPairGetPrimal, + 1, + &diffPair); + } + IRInst* IRBuilder::emitDifferentialPairGetDifferentialUserCode(IRType* diffType, IRInst* diffPair) { SLANG_ASSERT(as(diffPair->getDataType())); diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 375107d1d4..14dde200f4 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -1660,6 +1660,11 @@ struct IRDifferentialPairType : IRDifferentialPairTypeBase IR_LEAF_ISA(DifferentialPairType) }; +struct IRDifferentialPtrPairType : IRDifferentialPairTypeBase +{ + IR_LEAF_ISA(DifferentialPtrPairType) +}; + struct IRDifferentialPairUserCodeType : IRDifferentialPairTypeBase { IR_LEAF_ISA(DifferentialPairUserCodeType) diff --git a/source/slang/slang-options.cpp b/source/slang/slang-options.cpp index c02a009570..b9a12f971e 100644 --- a/source/slang/slang-options.cpp +++ b/source/slang/slang-options.cpp @@ -339,6 +339,7 @@ void initCommandOptions(CommandOptions& options) { OptionKind::InputFilesRemain, "--", nullptr, "Treat the rest of the command line as input files."}, { OptionKind::ReportDownstreamTime, "-report-downstream-time", nullptr, "Reports the time spent in the downstream compiler." }, { OptionKind::ReportPerfBenchmark, "-report-perf-benchmark", nullptr, "Reports compiler performance benchmark results." }, + { OptionKind::ReportCheckpointIntermediates, "-report-checkpoint-intermediates", nullptr, "Reports information about checkpoint contexts used for reverse-mode automatic differentiation." }, { OptionKind::SkipSPIRVValidation, "-skip-spirv-validation", nullptr, "Skips spirv validation." }, { OptionKind::SourceEmbedStyle, "-source-embed-style", "-source-embed-style ", "If source embedding is enabled, defines the style used. When enabled (with any style other than `none`), " @@ -1703,6 +1704,7 @@ SlangResult OptionsParser::_parse( case OptionKind::DumpReproOnError: case OptionKind::ReportDownstreamTime: case OptionKind::ReportPerfBenchmark: + case OptionKind::ReportCheckpointIntermediates: case OptionKind::SkipSPIRVValidation: case OptionKind::DisableSpecialization: case OptionKind::DisableDynamicDispatch: diff --git a/tests/autodiff/diff-ptr-type-call.slang b/tests/autodiff/diff-ptr-type-call.slang new file mode 100644 index 0000000000..258a4477b5 --- /dev/null +++ b/tests/autodiff/diff-ptr-type-call.slang @@ -0,0 +1,57 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +// ----- MyPtrType definition ----- +struct MyPtrType : IDifferentiablePtrType +{ + typealias Differential = MyPtrType; + + RWStructuredBuffer buffer; + uint offset; + + float load(uint idx) { return buffer[offset + idx]; } + void accumulate(uint idx, float value) { buffer[offset + idx] += value; } +} + +[BackwardDerivative(load_bwd)] +float load(MyPtrType b, uint idx) +{ + return b.load(idx); +} + +void load_bwd(DifferentialPtrPair b, uint idx, float grad) +{ + b.d.accumulate(idx, grad); +} + +// ------ +[Differentiable] +float reduce(MyPtrType a) +{ + return load(a, 0) + load(a, 1); +} + +[Differentiable] +float test(MyPtrType b) +{ + return reduce(b); +} + +[numthreads(1, 1, 1)] +void computeMain(uint id : SV_DispatchThreadID) +{ + outputBuffer[0] = 1; // CHECK: 1 + outputBuffer[1] = 2; // CHECK: 2 + + // Denote the first two elements in the buffer as the primal buffer and the last two elements for the derivative. + var b = DifferentialPtrPair( { outputBuffer, 0 }, { outputBuffer, 2 } ); + + bwd_diff(test)(b, 1.5f); + + // Check locations [2] and [3] in the buffer + // CHECK: 1.5 + // CHECK: 1.5 +} \ No newline at end of file diff --git a/tests/autodiff/diff-ptr-type-loop.slang b/tests/autodiff/diff-ptr-type-loop.slang new file mode 100644 index 0000000000..a57c69b760 --- /dev/null +++ b/tests/autodiff/diff-ptr-type-loop.slang @@ -0,0 +1,65 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +// ----- MyPtrType definition ----- +struct MyPtrType : IDifferentiablePtrType +{ + typealias Differential = MyPtrType; + + RWStructuredBuffer buffer; + uint offset; + + float load(uint idx) { return buffer[offset + idx]; } + void accumulate(uint idx, float value) { buffer[offset + idx] += value; } +} + +[BackwardDerivative(load_bwd)] +float load(MyPtrType b, uint idx) +{ + return b.load(idx); +} + +void load_bwd(DifferentialPtrPair b, uint idx, float grad) +{ + b.d.accumulate(idx, grad); +} + + +// ------ +[Differentiable] +float reduce(MyPtrType a, uint num) +{ + float sum = 0; + [MaxIters(3)] + for (uint i = 0; i < num; i++) + { + sum += load(a, i); + } + + return sum; +} + +[Differentiable] +float test(MyPtrType b, uint num) +{ + return reduce(b, num); +} + +[numthreads(1, 1, 1)] +void computeMain(uint id : SV_DispatchThreadID) +{ + outputBuffer[0] = 1; // CHECK: 1 + outputBuffer[1] = 2; // CHECK: 2 + + // Denote the first two elements in the buffer as the primal buffer and the last two elements for the derivative. + var b = DifferentialPtrPair( { outputBuffer, 0 }, { outputBuffer, 2 } ); + + bwd_diff(test)(b, 2, 1.5f); + + // Check locations [2] and [3] in the buffer + // CHECK: 1.5 + // CHECK: 1.5 +} \ No newline at end of file diff --git a/tests/autodiff/diff-ptr-type-smoke.slang b/tests/autodiff/diff-ptr-type-smoke.slang new file mode 100644 index 0000000000..e7e03c5e37 --- /dev/null +++ b/tests/autodiff/diff-ptr-type-smoke.slang @@ -0,0 +1,49 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +struct MyPtrType : IDifferentiablePtrType +{ + typealias Differential = MyPtrType; + + RWStructuredBuffer buffer; + uint offset; + + float load(uint idx) { return buffer[offset + idx]; } + void accumulate(uint idx, float value) { buffer[offset + idx] += value; } +} + +[BackwardDerivative(load_bwd)] +float load(MyPtrType b, uint idx) +{ + return b.load(idx); +} + +void load_bwd(DifferentialPtrPair b, uint idx, float grad) +{ + b.d.accumulate(idx, grad); +} + +[BackwardDifferentiable] +float test(MyPtrType b, uint idx) +{ + return load(b, idx) + load(b, idx + 1); +} + +[numthreads(1, 1, 1)] +void computeMain(uint id : SV_DispatchThreadID) +{ + outputBuffer[0] = 1; // CHECK: 1 + outputBuffer[1] = 2; // CHECK: 2 + + // Denote the first two elements in the buffer as the primal buffer and the last two elements for the derivative. + var b = DifferentialPtrPair( { outputBuffer, 0 }, { outputBuffer, 2 } ); + + bwd_diff(test)(b, id, 1.5f); + + // Check locations [2] and [3] in the buffer + // CHECK: 1.5 + // CHECK: 1.5 +} \ No newline at end of file diff --git a/tests/autodiff/reverse-checkpoint-1.slang b/tests/autodiff/reverse-checkpoint-1.slang index 5172970135..3d6e9e702f 100644 --- a/tests/autodiff/reverse-checkpoint-1.slang +++ b/tests/autodiff/reverse-checkpoint-1.slang @@ -2,6 +2,7 @@ //TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type //TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj //TEST:SIMPLE(filecheck=CHECK): -target hlsl -profile cs_5_0 -entry computeMain -line-directive-mode none +//TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates //TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer outputBuffer; @@ -16,13 +17,16 @@ float g(float x) return log(x); } +//CHK: note: checkpointing context of 4 bytes associated with function: 'f' [BackwardDifferentiable] float f(int p, float x) { float y = 1.0; // Test that phi parameter can be restored. if (p == 0) + //CHK: note: 4 bytes (float) used to checkpoint the following item: y = g(x); + return y * y; } @@ -41,3 +45,5 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) __bwd_diff(f)(0, dpa, 1.0f); outputBuffer[0] = dpa.d; // Expect: 1 } + +//CHK-NOT: note \ No newline at end of file diff --git a/tests/autodiff/reverse-checkpoint-2.slang b/tests/autodiff/reverse-checkpoint-2.slang index 8a7262aa4d..1dd3f29638 100644 --- a/tests/autodiff/reverse-checkpoint-2.slang +++ b/tests/autodiff/reverse-checkpoint-2.slang @@ -41,3 +41,5 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) __bwd_diff(f)(0, dpa, 1.0f); outputBuffer[0] = dpa.d; // Expect: 1 } + +//CHK-NOT: note \ No newline at end of file diff --git a/tests/autodiff/reverse-continue-loop.slang b/tests/autodiff/reverse-continue-loop.slang index 0f95026734..0b6e56f783 100644 --- a/tests/autodiff/reverse-continue-loop.slang +++ b/tests/autodiff/reverse-continue-loop.slang @@ -1,6 +1,7 @@ //TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type //TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type //TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj +//TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates //TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer outputBuffer; @@ -8,11 +9,14 @@ RWStructuredBuffer outputBuffer; typedef DifferentialPair dpfloat; typedef float.Differential dfloat; +//CHK: note: checkpointing context of 24 bytes associated with function: 'test_loop_with_continue' [BackwardDifferentiable] float test_loop_with_continue(float y) { + //CHK: note: 20 bytes (FixedArray ) used to checkpoint the following item: float t = y; + //CHK: note: 4 bytes (int32_t) used for a loop counter here: for (int i = 0; i < 3; i++) { if (t > 4.0) @@ -41,3 +45,5 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) outputBuffer[1] = dpa.d; // Expect: 0.0131072 } } + +//CHK-NOT: note \ No newline at end of file diff --git a/tests/autodiff/reverse-control-flow-1.slang b/tests/autodiff/reverse-control-flow-1.slang index 7d2f518be9..334de4137e 100644 --- a/tests/autodiff/reverse-control-flow-1.slang +++ b/tests/autodiff/reverse-control-flow-1.slang @@ -1,5 +1,6 @@ //TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type //TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates //TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer outputBuffer; @@ -40,3 +41,5 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) outputBuffer[1] = dpa.d; // Expect: 1.0 } } + +//CHK: (0): note: no checkpoint contexts to report \ No newline at end of file diff --git a/tests/autodiff/reverse-control-flow-2.slang b/tests/autodiff/reverse-control-flow-2.slang index cde707b4d3..c3790367cf 100644 --- a/tests/autodiff/reverse-control-flow-2.slang +++ b/tests/autodiff/reverse-control-flow-2.slang @@ -1,5 +1,6 @@ //TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type //TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates //TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer outputBuffer; @@ -73,3 +74,5 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) outputBuffer[1] = dpx.d; } } + +//CHK: (0): note: no checkpoint contexts to report \ No newline at end of file diff --git a/tests/autodiff/reverse-control-flow-3.slang b/tests/autodiff/reverse-control-flow-3.slang index 01b5332793..b4fa68e3a3 100644 --- a/tests/autodiff/reverse-control-flow-3.slang +++ b/tests/autodiff/reverse-control-flow-3.slang @@ -1,4 +1,5 @@ //TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates //TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer @@ -75,7 +76,8 @@ void d_getParam(uint id, MaterialParam.Differential diff) outputBuffer[id] += diff.roughness; } - +//CHK-DAG: note: checkpointing context of 8 bytes associated with function: 'updatePathThroughput' +//CHK-DAG: note: 8 bytes (PathResult_0) used to checkpoint the following item: [BackwardDifferentiable] void updatePathThroughput(inout PathResult path, const float weight) { @@ -122,9 +124,13 @@ bool generateScatterRay(const BSDFSample bs, const MaterialParam bsdfParams, ino \param[in,out] path The path state. \return True if a ray was generated, false otherwise. */ + +//CHK-DAG: note: checkpointing context of 16 bytes associated with function: 'generateScatterRay' [BackwardDifferentiable] bool generateScatterRay(const BSDFSample bs, const MaterialParam bsdfParams, inout PathState path, inout PathResult pathRes) { + //CHK-DAG: note: 8 bytes (s_bwd_prop_updatePathThroughput_Intermediates_0) used to checkpoint the following item: + //CHK-DAG: note: 8 bytes (PathResult_0) used to checkpoint the following item: updatePathThroughput(pathRes, bs.val); return true; } @@ -215,5 +221,6 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) var dpx = diffPair(pathRes, pathResD); __bwd_diff(tracePath)(1, dpx); // Expect: 5.0 in outputBuffer[3] } - } + +//CHK-NOT: note \ No newline at end of file diff --git a/tests/autodiff/reverse-loop-checkpoint-test.slang b/tests/autodiff/reverse-loop-checkpoint-test.slang index fc206e1289..68ad823ac2 100644 --- a/tests/autodiff/reverse-loop-checkpoint-test.slang +++ b/tests/autodiff/reverse-loop-checkpoint-test.slang @@ -1,5 +1,6 @@ //TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type //TEST:SIMPLE(filecheck=CHECK): -target hlsl -profile cs_5_0 -entry computeMain -line-directive-mode none +//TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates //TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer outputBuffer; @@ -44,13 +45,18 @@ float3 infinitesimal(float3 x) return x - detach(x); } +//CHK: note: checkpointing context of 20 bytes associated with function: 'computeLoop' [BackwardDifferentiable] [PreferRecompute] float3 computeLoop(float y) { + //CHK: note: 4 bytes (float) used to checkpoint the following item: float w = 0; + + //CHK: note: 12 bytes (Vector ) used to checkpoint the following item: float3 w3 = float3(0, 0, 0); + //CHK: note: 4 bytes (int32_t) used for a loop counter here: for (int i = 0; i < 8; i++) { float k = compute(i, y); @@ -93,3 +99,5 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) outputBuffer[2] = computeLoop(1.0).x; } + +//CHK-NOT: note \ No newline at end of file diff --git a/tests/autodiff/reverse-loop.slang b/tests/autodiff/reverse-loop.slang index a2c826be98..2ba8535bee 100644 --- a/tests/autodiff/reverse-loop.slang +++ b/tests/autodiff/reverse-loop.slang @@ -1,6 +1,7 @@ //TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type //TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type //TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj +//TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates //TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer outputBuffer; @@ -8,11 +9,14 @@ RWStructuredBuffer outputBuffer; typedef DifferentialPair dpfloat; typedef float.Differential dfloat; +//CHK: note: checkpointing context of 24 bytes associated with function: 'test_simple_loop' [Differentiable] float test_simple_loop(float y) { + //CHK: note: 20 bytes (FixedArray ) used to checkpoint the following item: float t = y; + //CHK: note: 4 bytes (int32_t) used for a loop counter here: for (int i = 0; i < 3; i++) { t = t * t; @@ -38,3 +42,5 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) outputBuffer[1] = dpa.d; // Expect: 0.0131072 } } + +//CHK-NOT: note \ No newline at end of file diff --git a/tests/autodiff/reverse-nested-calls.slang b/tests/autodiff/reverse-nested-calls.slang index caf2df6f8c..3c1a52c21c 100644 --- a/tests/autodiff/reverse-nested-calls.slang +++ b/tests/autodiff/reverse-nested-calls.slang @@ -1,6 +1,7 @@ //TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type //TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type //TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj +//TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates //TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer outputBuffer; @@ -15,9 +16,11 @@ float g(float y) return result * result; } +//CHK: note: checkpointing context of 4 bytes associated with function: 'f' [BackwardDifferentiable] float f(float x) { + //CHK: note: 4 bytes (float) used to checkpoint the following item: return 3.0f * g(2.0f * x); } @@ -29,3 +32,5 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) __bwd_diff(f)(dpa, 1.0f); outputBuffer[0] = dpa.d; // Expect: 96.0 } + +//CHK-NOT: note \ No newline at end of file diff --git a/tests/bugs/overload-ambiguous-1.slang b/tests/bugs/overload-ambiguous-1.slang new file mode 100644 index 0000000000..9f9c6e5bc5 --- /dev/null +++ b/tests/bugs/overload-ambiguous-1.slang @@ -0,0 +1,65 @@ +// https://github.com/shader-slang/slang/issues/4476 + +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-cuda -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-cpu -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-slang -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-slang -compute -dx12 -shaderobj +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-vk -compute -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +namespace A1 +{ + uint func() + { + return 1u; + } + + namespace A2 + { + uint func() + { + return 2u; + } + + namespace A3 + { + uint func() + { + return 3u; + } + + uint test2() + { + return func(); // choose A3::func() + } + } + + namespace A4 + { + uint test() + { + return func(); // choose A2::func() + } + } + } +} + +[numthreads(1, 1, 1)] +[shader("compute")] +void computeMain(uint3 threadID: SV_DispatchThreadID) +{ + using namespace A1; + using namespace A1::A2; + using namespace A1::A2::A3; + using namespace A1::A2::A4; + outputBuffer[0] = test(); + // BUF: 2 + + outputBuffer[1] = func(); // choose the A1::func() + // BUF-NEXT: 1 + + outputBuffer[2] = test2(); + // BUF-NEXT: 3 +} diff --git a/tests/bugs/overload-ambiguous-2.slang b/tests/bugs/overload-ambiguous-2.slang new file mode 100644 index 0000000000..46af9f0919 --- /dev/null +++ b/tests/bugs/overload-ambiguous-2.slang @@ -0,0 +1,67 @@ +// https://github.com/shader-slang/slang/issues/4476 + +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-cuda -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-cpu -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-slang -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-slang -compute -dx12 -shaderobj +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-vk -compute -shaderobj + +//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +namespace A +{ + struct Struct1 + { + uint data; + }; + + Struct1 myFunc(Struct1 inputS1) + { + Struct1 s1; + s1.data = inputS1.data + 2U; + return s1; + } +}; + + +A::Struct1 myFunc(A::Struct1 inputS1) +{ + A::Struct1 s1; + s1.data = inputS1.data + 5U; + return s1; +} + +namespace A +{ + struct Struct2 + { + Struct1 s1; + } + + Struct2 myFunc(Struct2 inputS2) + { + Struct2 s2; + // We want to cover a corner case in our compiler where: + // when looking up "myFunc", the compiler should find + // Struct1 A::myFunc(Struct1 inputS1) + // and it won't be ambiguous with the global "myFunc". + s2.s1 = myFunc(inputS2.s1); + return s2; + } +}; + +[numthreads(1, 1, 1)] +[shader("compute")] +void computeMain(uint3 threadID: SV_DispatchThreadID) +{ + using namespace A; + + Struct2<10> input = {threadID.x}; + + Struct2<20> output; + output = myFunc<10, 20>(input); + outputBuffer[0] = output.s1.data; + + // BUF: 2 +} diff --git a/tests/bugs/overload-ambiguous.slang b/tests/bugs/overload-ambiguous.slang index 1b74cb68c2..d764f72e42 100644 --- a/tests/bugs/overload-ambiguous.slang +++ b/tests/bugs/overload-ambiguous.slang @@ -6,7 +6,7 @@ //TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-slang -compute -dx12 -shaderobj //TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-vk -compute -shaderobj -//TEST_INPUT:ubuffer(data=[0 0], stride=4):out,name=outputBuffer +//TEST_INPUT:ubuffer(data=[0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer outputBuffer; @@ -34,7 +34,18 @@ struct DataObtainer } } -RWStructuredBuffer output; +uint myFunc(uint a) +{ + return a + 1u; +} + +__generic +uint myFunc(T a) +{ + uint b = __intCast(a); + return b + 2u; +} + [numthreads(1, 1, 1)] [shader("compute")] @@ -43,6 +54,10 @@ void computeMain(uint3 threadID: SV_DispatchThreadID) DataObtainer obtainer = {2u}; outputBuffer[0] = obtainer.getValue(); outputBuffer[1] = obtainer.getValue2(); + + uint a = 1u; + outputBuffer[2] = myFunc(a); // will call myFunc(uint) which more specialized // BUF: 2 // BUF-NEXT: 1 + // BUF-NEXT: 2 } diff --git a/tests/diagnostics/inout-never-written.slang b/tests/diagnostics/inout-never-written.slang deleted file mode 100644 index f4d4bce7e1..0000000000 --- a/tests/diagnostics/inout-never-written.slang +++ /dev/null @@ -1,74 +0,0 @@ -//TEST:SIMPLE(filecheck=CHK): -target spirv - -struct State -{ - float3 v; - float3 n; - int rnd; -}; - -//CHK-DAG: ([[# @LINE + 1]]): warning 41022: inout parameter 'x' is never written to -void int_never_assigned(inout int x) {} - -//CHK-DAG: ([[# @LINE + 1]]): warning 41022: inout parameter 'state' is never written to -void state_never_assigned(inout State state, inout float v) -{ - v = state.v.x; -} - -void state_assigned(inout State state) -{ - state.rnd = (int) dot(state.v, state.n); -} - -struct A -{ - int state; - - //CHK-DAG: ([[# @LINE + 1]]): warning 41023: method marked `[mutable]` but never modifies `this` - [mutating] int next() { return state; } - - [mutating] int progress() - { - unmodified(state); - return state; - } -}; - -__generic -struct B -{ - int state; - - //CHK-DAG: ([[# @LINE + 1]]): warning 41023: method marked `[mutable]` but never modifies `this` - [mutating] int next() { return state; } -}; - -// Sometimes an inOutImplicitCast is done, -// this needs to be tracked as an alias; -// none of the following functions should -// generate warnings -uint lcg(inout uint prev) -{ - const uint LCG_A = 1664525u; - const uint LCG_C = 1013904223u; - prev = (LCG_A * prev + LCG_C); - return prev & 0x00FFFFFF; -} - -float rnd(inout uint prev) -{ - return ((float) lcg(prev) / (float) 0x01000000); -} - -float3 sample(inout int seed) -{ - float3 xi; - xi.x = rnd(seed); - xi.y = rnd(seed); - xi.z = rnd(seed); - return xi.z; -} - -//CHK-NOT: warning 41022 -//CHK-NOT: warning 41023 diff --git a/tests/diagnostics/overload-ambiguous.slang b/tests/diagnostics/overload-ambiguous.slang new file mode 100644 index 0000000000..0c8f7bd216 --- /dev/null +++ b/tests/diagnostics/overload-ambiguous.slang @@ -0,0 +1,45 @@ +// https://github.com/shader-slang/slang/issues/4476 + +//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK): +RWStructuredBuffer outputBuffer; + +namespace A1 +{ + uint func() + { + return 1u; + } + + namespace A2 + { + uint func() + { + return 2u; + } + } +} +namespace B1 +{ + uint func() + { + return 4u; + } +} + +[numthreads(1, 1, 1)] +[shader("compute")] +void computeMain(uint3 threadID: SV_DispatchThreadID) +{ + using namespace A1; + using namespace A1::A2; + using namespace B1; + using namespace C1; + + // Only A1::func() and B1::func() will cause ambiguity because the distance from + // the reference site to those two functions declaration are the same. + outputBuffer[0] = func(); + // CHECK-NOT: {{.*}}A2::func() -> uint + // CHECK: ambiguous call to 'func' with arguments of type () + // CHECK: candidate: func B1::func() -> uint + // CHECK: candidate: func A1::func() -> uint +} diff --git a/tests/language-feature/generics/irwarray.slang b/tests/language-feature/generics/irwarray.slang new file mode 100644 index 0000000000..47109f7b08 --- /dev/null +++ b/tests/language-feature/generics/irwarray.slang @@ -0,0 +1,34 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type + +void writeToArray>(inout T array, int index, U value) { array[index] = value; } +void writeToBuffer>(T array, int index, U value) { array[index] = value; } +U readFromArray>(T array, int index) { return array[index]; } + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +[numthreads(1, 1, 1)] +void computeMain(int3 dispatchThreadID: SV_DispatchThreadID) +{ + float arr[3] = { 1.0, 2.0, 3.0 }; + float4 v = float4(1.0, 2.0, 3.0, 4.0); + float2x2 m = float2x2(1.0, 2.0, 3.0, 4.0); + + // CHECK: 1.0 + writeToBuffer(outputBuffer, 0, 1.0f); + + // CHECK: 4.0 + writeToArray(arr, 0, 4.0f); + outputBuffer[1] = readFromArray(arr, 0); + + // CHECK: 3.0 + writeToArray(v, 3, 3.0f); + outputBuffer[2] = readFromArray(v, 3); + + // CHECK: 30.0 + writeToArray(m, 1, float2(10.0f, 20.0f)); + outputBuffer[3] = readFromArray(m, 1).x + readFromArray(m, 1).y; + + writeToBuffer(outputBuffer, 0, readFromArray(outputBuffer, 0)); +} diff --git a/tools/render-test/options.cpp b/tools/render-test/options.cpp index 2879acfaca..dc1850d379 100644 --- a/tools/render-test/options.cpp +++ b/tools/render-test/options.cpp @@ -34,7 +34,7 @@ static rhi::DeviceType _toRenderType(Slang::RenderApiType apiType) case RenderApiType::CPU: return rhi::DeviceType::CPU; case RenderApiType::CUDA: return rhi::DeviceType::CUDA; default: - return rhi::DeviceType::Unknown; + return rhi::DeviceType::Default; } } @@ -244,7 +244,7 @@ static rhi::DeviceType _toRenderType(Slang::RenderApiType apiType) UnownedStringSlice argName = argSlice.tail(1); DeviceType deviceType = _toRenderType(RenderApiUtil::findApiTypeByName(argName)); - if (deviceType != DeviceType::Unknown) + if (deviceType != DeviceType::Default) { outOptions.deviceType = deviceType; continue; @@ -253,7 +253,7 @@ static rhi::DeviceType _toRenderType(Slang::RenderApiType apiType) // Lookup the target language type DeviceType targetLanguageDeviceType = _toRenderType(RenderApiUtil::findImplicitLanguageRenderApiType(argName)); - if (targetLanguageDeviceType != DeviceType::Unknown || argName == "glsl") + if (targetLanguageDeviceType != DeviceType::Default || argName == "glsl") { outOptions.targetLanguageDeviceType = targetLanguageDeviceType; outOptions.inputLanguageID = (argName == "hlsl" || argName == "glsl" || argName == "cpp" || argName == "cxx" || argName == "c") ? InputLanguageID::Native : InputLanguageID::Slang; @@ -266,7 +266,7 @@ static rhi::DeviceType _toRenderType(Slang::RenderApiType apiType) } // If a render option isn't set use defaultRenderType - outOptions.deviceType = (outOptions.deviceType == DeviceType::Unknown) + outOptions.deviceType = (outOptions.deviceType == DeviceType::Default) ? outOptions.targetLanguageDeviceType : outOptions.deviceType; diff --git a/tools/render-test/options.h b/tools/render-test/options.h index 6b0841c7d3..bd5e65a1ae 100644 --- a/tools/render-test/options.h +++ b/tools/render-test/options.h @@ -53,9 +53,9 @@ struct Options ShaderProgramType shaderType = ShaderProgramType::Graphics; /// The renderer type inferred from the target language type. Used if a rendererType is not explicitly set. - DeviceType targetLanguageDeviceType = DeviceType::Unknown; + DeviceType targetLanguageDeviceType = DeviceType::Default; /// The set render type - DeviceType deviceType = DeviceType::Unknown; + DeviceType deviceType = DeviceType::Default; InputLanguageID inputLanguageID = InputLanguageID::Slang; SlangSourceLanguage sourceLanguage = SLANG_SOURCE_LANGUAGE_UNKNOWN; diff --git a/tools/render-test/render-test-main.cpp b/tools/render-test/render-test-main.cpp index 5712485cc4..370225c7ce 100644 --- a/tools/render-test/render-test-main.cpp +++ b/tools/render-test/render-test-main.cpp @@ -127,18 +127,18 @@ class RenderTestApp IDevice* m_device; ComPtr m_queue; ComPtr m_transientHeap; - ComPtr m_renderPass; ComPtr m_inputLayout; - ComPtr m_vertexBuffer; + ComPtr m_vertexBuffer; ComPtr m_shaderProgram; - ComPtr m_pipelineState; - ComPtr m_framebufferLayout; - ComPtr m_framebuffer; - ComPtr m_colorBuffer; + ComPtr m_pipeline; + ComPtr m_depthBuffer; + ComPtr m_depthBufferView; + ComPtr m_colorBuffer; + ComPtr m_colorBufferView; - ComPtr m_blasBuffer; + ComPtr m_blasBuffer; ComPtr m_bottomLevelAccelerationStructure; - ComPtr m_tlasBuffer; + ComPtr m_tlasBuffer; ComPtr m_topLevelAccelerationStructure; ShaderCompilerUtil::OutputAndLayout m_compilationOutput; @@ -212,10 +212,10 @@ struct AssignValsFromLayoutContext for (size_t i = bufferData.getCount(); i < bufferSize / sizeof(uint32_t); i++) bufferData.add(0); - ComPtr bufferResource; - SLANG_RETURN_ON_FAIL(ShaderRendererUtil::createBufferResource(srcBuffer, /*entry.isOutput,*/ bufferSize, bufferData.getBuffer(), device, bufferResource)); + ComPtr bufferResource; + SLANG_RETURN_ON_FAIL(ShaderRendererUtil::createBuffer(srcBuffer, /*entry.isOutput,*/ bufferSize, bufferData.getBuffer(), device, bufferResource)); - ComPtr counterResource; + ComPtr counterResource; const auto explicitCounterCursor = dstCursor.getExplicitCounter(); if(srcBuffer.counter != ~0u) { @@ -238,7 +238,7 @@ struct AssignValsFromLayoutContext 1, Format::Unknown, }; - SLANG_RETURN_ON_FAIL(ShaderRendererUtil::createBufferResource( + SLANG_RETURN_ON_FAIL(ShaderRendererUtil::createBuffer( counterBufferDesc, sizeof(srcBuffer.counter), &srcBuffer.counter, @@ -254,11 +254,14 @@ struct AssignValsFromLayoutContext return SLANG_E_INVALID_ARG; } - IResourceView::Desc viewDesc = {}; - viewDesc.type = IResourceView::Type::UnorderedAccess; - viewDesc.format = srcBuffer.format; - auto bufferView = device->createBufferView(bufferResource, counterResource, viewDesc); - dstCursor.setResource(bufferView); + if (counterResource) + { + dstCursor.setBinding(Binding(bufferResource, counterResource)); + } + else + { + dstCursor.setBinding(bufferResource); + } maybeAddOutput(dstCursor, srcVal, bufferResource); return SLANG_OK; @@ -269,19 +272,13 @@ struct AssignValsFromLayoutContext auto& textureEntry = srcVal->textureVal; auto& samplerEntry = srcVal->samplerVal; - ComPtr texture; - SLANG_RETURN_ON_FAIL(ShaderRendererUtil::generateTextureResource( + ComPtr texture; + SLANG_RETURN_ON_FAIL(ShaderRendererUtil::generateTexture( textureEntry->textureDesc, ResourceState::ShaderResource, device, texture)); - auto sampler = _createSamplerState(device, samplerEntry->samplerDesc); - - IResourceView::Desc viewDesc = {}; - viewDesc.type = IResourceView::Type::ShaderResource; - auto textureView = device->createTextureView( - texture, - viewDesc); + auto sampler = _createSampler(device, samplerEntry->samplerDesc); - dstCursor.setCombinedTextureSampler(textureView, sampler); + dstCursor.setBinding(Binding(texture, sampler)); maybeAddOutput(dstCursor, srcVal, texture); return SLANG_OK; @@ -289,41 +286,23 @@ struct AssignValsFromLayoutContext SlangResult assignTexture(ShaderCursor const& dstCursor, ShaderInputLayout::TextureVal* srcVal) { - ComPtr texture; - ResourceState defaultState = ResourceState::ShaderResource; - IResourceView::Type viewType = IResourceView::Type::ShaderResource; + ComPtr texture; + ResourceState defaultState = srcVal->textureDesc.isRWTexture ? + ResourceState::UnorderedAccess : ResourceState::ShaderResource; - if (srcVal->textureDesc.isRWTexture) - { - defaultState = ResourceState::UnorderedAccess; - viewType = IResourceView::Type::UnorderedAccess; - } - - SLANG_RETURN_ON_FAIL(ShaderRendererUtil::generateTextureResource( + SLANG_RETURN_ON_FAIL(ShaderRendererUtil::generateTexture( srcVal->textureDesc, defaultState, device, texture)); - IResourceView::Desc viewDesc = {}; - viewDesc.type = viewType; - viewDesc.format = texture->getDesc()->format; - auto textureView = device->createTextureView( - texture, - viewDesc); - - if (!textureView) - { - return SLANG_FAIL; - } - - dstCursor.setResource(textureView); + dstCursor.setBinding(texture); maybeAddOutput(dstCursor, srcVal, texture); return SLANG_OK; } SlangResult assignSampler(ShaderCursor const& dstCursor, ShaderInputLayout::SamplerVal* srcVal) { - auto sampler = _createSamplerState(device, srcVal->samplerDesc); + auto sampler = _createSampler(device, srcVal->samplerDesc); - dstCursor.setSampler(sampler); + dstCursor.setBinding(sampler); return SLANG_OK; } @@ -437,7 +416,7 @@ struct AssignValsFromLayoutContext ShaderCursor const& dstCursor, ShaderInputLayout::AccelerationStructureVal* srcVal) { - dstCursor.setResource(accelerationStructure); + dstCursor.setBinding(accelerationStructure); return SLANG_OK; } @@ -511,7 +490,7 @@ Result RenderTestApp::applyBinding(PipelineType pipelineType, ICommandEncoder* e case PipelineType::Compute: { IComputeCommandEncoder* computeEncoder = static_cast(encoder); - auto rootObject = computeEncoder->bindPipeline(m_pipelineState); + auto rootObject = computeEncoder->bindPipeline(m_pipeline); SLANG_RETURN_ON_FAIL(_assignVarsFromLayout( m_device, slangSession, @@ -525,7 +504,7 @@ Result RenderTestApp::applyBinding(PipelineType pipelineType, ICommandEncoder* e case PipelineType::Graphics: { IRenderCommandEncoder* renderEncoder = static_cast(encoder); - auto rootObject = renderEncoder->bindPipeline(m_pipelineState); + auto rootObject = renderEncoder->bindPipeline(m_pipeline); SLANG_RETURN_ON_FAIL(_assignVarsFromLayout( m_device, slangSession, @@ -559,7 +538,7 @@ SlangResult RenderTestApp::initialize( // Once the shaders have been compiled we load them via the underlying API. // ComPtr outDiagnostics; - auto result = device->createProgram(m_compilationOutput.output.desc, m_shaderProgram.writeRef(), outDiagnostics.writeRef()); + auto result = device->createShaderProgram(m_compilationOutput.output.desc, m_shaderProgram.writeRef(), outDiagnostics.writeRef()); // If there was a failure creating a program, we can't continue // Special case SLANG_E_NOT_AVAILABLE error code to make it a failure, @@ -585,10 +564,10 @@ SlangResult RenderTestApp::initialize( case Options::ShaderProgramType::Compute: { - ComputePipelineStateDesc desc; + ComputePipelineDesc desc; desc.program = m_shaderProgram; - m_pipelineState = device->createComputePipelineState(desc); + m_pipeline = device->createComputePipeline(desc); } break; @@ -614,37 +593,44 @@ SlangResult RenderTestApp::initialize( SLANG_RETURN_ON_FAIL(device->createInputLayout( sizeof(Vertex), inputElements, SLANG_COUNT_OF(inputElements), inputLayout.writeRef())); - IBufferResource::Desc vertexBufferDesc; - vertexBufferDesc.type = IResource::Type::Buffer; - vertexBufferDesc.sizeInBytes = kVertexCount * sizeof(Vertex); + BufferDesc vertexBufferDesc; + vertexBufferDesc.size = kVertexCount * sizeof(Vertex); vertexBufferDesc.memoryType = MemoryType::Upload; + vertexBufferDesc.usage = BufferUsage::VertexBuffer; vertexBufferDesc.defaultState = ResourceState::VertexBuffer; - vertexBufferDesc.allowedStates = ResourceStateSet(ResourceState::VertexBuffer); - SLANG_RETURN_ON_FAIL(device->createBufferResource( + SLANG_RETURN_ON_FAIL(device->createBuffer( vertexBufferDesc, kVertexData, m_vertexBuffer.writeRef())); - GraphicsPipelineStateDesc desc; + ColorTargetState colorTarget; + colorTarget.format = Format::R8G8B8A8_UNORM; + RenderPipelineDesc desc; desc.program = m_shaderProgram; desc.inputLayout = inputLayout; - desc.framebufferLayout = m_framebufferLayout; - m_pipelineState = device->createGraphicsPipelineState(desc); + desc.targets = &colorTarget; + desc.targetCount = 1; + desc.depthStencil.format = Format::D32_FLOAT; + m_pipeline = device->createRenderPipeline(desc); } break; case Options::ShaderProgramType::GraphicsMeshCompute: case Options::ShaderProgramType::GraphicsTaskMeshCompute: { - GraphicsPipelineStateDesc desc; + ColorTargetState colorTarget; + colorTarget.format = Format::R8G8B8A8_UNORM; + RenderPipelineDesc desc; desc.program = m_shaderProgram; - desc.framebufferLayout = m_framebufferLayout; - m_pipelineState = device->createGraphicsPipelineState(desc); + desc.targets = &colorTarget; + desc.targetCount = 1; + desc.depthStencil.format = Format::D32_FLOAT; + m_pipeline = device->createRenderPipeline(desc); } } } // If success must have a pipeline state - return m_pipelineState ? SLANG_OK : SLANG_FAIL; + return m_pipeline ? SLANG_OK : SLANG_FAIL; } Result RenderTestApp::_initializeShaders( @@ -655,7 +641,7 @@ Result RenderTestApp::_initializeShaders( { SLANG_RETURN_ON_FAIL(ShaderCompilerUtil::compileWithLayout(device->getSlangSession()->getGlobalSession(), m_options, input, m_compilationOutput)); m_shaderInputLayout = m_compilationOutput.layout; - m_shaderProgram = device->createProgram(m_compilationOutput.output.desc); + m_shaderProgram = device->createShaderProgram(m_compilationOutput.output.desc); return m_shaderProgram ? SLANG_OK : SLANG_FAIL; } @@ -670,102 +656,54 @@ void RenderTestApp::_initializeRenderPass() m_queue = m_device->createCommandQueue(queueDesc); SLANG_ASSERT(m_queue); - rhi::ITextureResource::Desc depthBufferDesc; - depthBufferDesc.type = IResource::Type::Texture2D; + rhi::TextureDesc depthBufferDesc; + depthBufferDesc.type = TextureType::Texture2D; depthBufferDesc.size.width = gWindowWidth; depthBufferDesc.size.height = gWindowHeight; depthBufferDesc.size.depth = 1; depthBufferDesc.numMipLevels = 1; depthBufferDesc.format = Format::D32_FLOAT; + depthBufferDesc.usage = TextureUsage::DepthWrite; depthBufferDesc.defaultState = ResourceState::DepthWrite; - depthBufferDesc.allowedStates = ResourceState::DepthWrite; - - ComPtr depthBufferResource = - m_device->createTextureResource(depthBufferDesc, nullptr); - SLANG_ASSERT(depthBufferResource); + m_depthBuffer = m_device->createTexture(depthBufferDesc, nullptr); + SLANG_ASSERT(m_depthBuffer); + m_depthBufferView = m_device->createTextureView(m_depthBuffer, {}); + SLANG_ASSERT(m_depthBufferView); - rhi::ITextureResource::Desc colorBufferDesc; - colorBufferDesc.type = IResource::Type::Texture2D; + rhi::TextureDesc colorBufferDesc; + colorBufferDesc.type = TextureType::Texture2D; colorBufferDesc.size.width = gWindowWidth; colorBufferDesc.size.height = gWindowHeight; colorBufferDesc.size.depth = 1; colorBufferDesc.numMipLevels = 1; colorBufferDesc.format = Format::R8G8B8A8_UNORM; + colorBufferDesc.usage = TextureUsage::RenderTarget; colorBufferDesc.defaultState = ResourceState::RenderTarget; - colorBufferDesc.allowedStates = ResourceState::RenderTarget; - m_colorBuffer = m_device->createTextureResource(colorBufferDesc, nullptr); + m_colorBuffer = m_device->createTexture(colorBufferDesc, nullptr); SLANG_ASSERT(m_colorBuffer); - - rhi::IResourceView::Desc colorBufferViewDesc = {}; - memset(&colorBufferViewDesc, 0, sizeof(colorBufferViewDesc)); - colorBufferViewDesc.format = rhi::Format::R8G8B8A8_UNORM; - colorBufferViewDesc.renderTarget.shape = rhi::IResource::Type::Texture2D; - colorBufferViewDesc.type = rhi::IResourceView::Type::RenderTarget; - ComPtr rtv = - m_device->createTextureView(m_colorBuffer.get(), colorBufferViewDesc); - SLANG_ASSERT(rtv); - - rhi::IResourceView::Desc depthBufferViewDesc = {}; - memset(&depthBufferViewDesc, 0, sizeof(depthBufferViewDesc)); - depthBufferViewDesc.format = rhi::Format::D32_FLOAT; - depthBufferViewDesc.renderTarget.shape = rhi::IResource::Type::Texture2D; - depthBufferViewDesc.type = rhi::IResourceView::Type::DepthStencil; - ComPtr dsv = - m_device->createTextureView(depthBufferResource.get(), depthBufferViewDesc); - SLANG_ASSERT(dsv); - - IFramebufferLayout::TargetLayout colorTarget = {rhi::Format::R8G8B8A8_UNORM, 1}; - IFramebufferLayout::TargetLayout depthTarget = {rhi::Format::D32_FLOAT, 1}; - rhi::IFramebufferLayout::Desc framebufferLayoutDesc; - framebufferLayoutDesc.renderTargetCount = 1; - framebufferLayoutDesc.renderTargets = &colorTarget; - framebufferLayoutDesc.depthStencil = &depthTarget; - m_device->createFramebufferLayout(framebufferLayoutDesc, m_framebufferLayout.writeRef()); - - rhi::IFramebuffer::Desc framebufferDesc; - framebufferDesc.renderTargetCount = 1; - framebufferDesc.depthStencilView = dsv.get(); - framebufferDesc.renderTargetViews = rtv.readRef(); - framebufferDesc.layout = m_framebufferLayout; - m_device->createFramebuffer(framebufferDesc, m_framebuffer.writeRef()); - - IRenderPassLayout::Desc renderPassDesc = {}; - renderPassDesc.framebufferLayout = m_framebufferLayout; - renderPassDesc.renderTargetCount = 1; - IRenderPassLayout::TargetAccessDesc renderTargetAccess = {}; - IRenderPassLayout::TargetAccessDesc depthStencilAccess = {}; - renderTargetAccess.loadOp = IRenderPassLayout::TargetLoadOp::Clear; - renderTargetAccess.storeOp = IRenderPassLayout::TargetStoreOp::Store; - renderTargetAccess.initialState = ResourceState::Undefined; - renderTargetAccess.finalState = ResourceState::RenderTarget; - depthStencilAccess.loadOp = IRenderPassLayout::TargetLoadOp::Clear; - depthStencilAccess.storeOp = IRenderPassLayout::TargetStoreOp::Store; - depthStencilAccess.initialState = ResourceState::Undefined; - depthStencilAccess.finalState = ResourceState::DepthWrite; - renderPassDesc.renderTargetAccess = &renderTargetAccess; - renderPassDesc.depthStencilAccess = &depthStencilAccess; - m_device->createRenderPassLayout(renderPassDesc, m_renderPass.writeRef()); + m_colorBufferView = m_device->createTextureView(m_colorBuffer, {}); + SLANG_ASSERT(m_colorBufferView); } void RenderTestApp::_initializeAccelerationStructure() { if (!m_device->hasFeature("ray-tracing")) return; - IBufferResource::Desc vertexBufferDesc = {}; - vertexBufferDesc.type = IResource::Type::Buffer; - vertexBufferDesc.sizeInBytes = kVertexCount * sizeof(Vertex); - vertexBufferDesc.defaultState = ResourceState::ShaderResource; - ComPtr vertexBuffer = - m_device->createBufferResource(vertexBufferDesc, &kVertexData[0]); - - IBufferResource::Desc transformBufferDesc = {}; - transformBufferDesc.type = IResource::Type::Buffer; - transformBufferDesc.sizeInBytes = sizeof(float) * 12; - transformBufferDesc.defaultState = ResourceState::ShaderResource; + BufferDesc vertexBufferDesc = {}; + vertexBufferDesc.size = kVertexCount * sizeof(Vertex); + vertexBufferDesc.usage = BufferUsage::AccelerationStructureBuildInput; + vertexBufferDesc.defaultState = ResourceState::AccelerationStructureBuildInput; + ComPtr vertexBuffer = + m_device->createBuffer(vertexBufferDesc, &kVertexData[0]); + + BufferDesc transformBufferDesc = {}; + transformBufferDesc.size = sizeof(float) * 12; + transformBufferDesc.usage = BufferUsage::AccelerationStructureBuildInput; + transformBufferDesc.defaultState = ResourceState::AccelerationStructureBuildInput; float transformData[12] = { 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f}; - ComPtr transformBuffer = - m_device->createBufferResource(transformBufferDesc, &transformData); + ComPtr transformBuffer = + m_device->createBuffer(transformBufferDesc, &transformData); // Build bottom level acceleration structure. { @@ -792,20 +730,20 @@ void RenderTestApp::_initializeAccelerationStructure() m_device->getAccelerationStructurePrebuildInfo( accelerationStructureBuildInputs, &accelerationStructurePrebuildInfo); // Allocate buffers for acceleration structure. - IBufferResource::Desc asDraftBufferDesc = {}; - asDraftBufferDesc.type = IResource::Type::Buffer; + BufferDesc asDraftBufferDesc = {}; + asDraftBufferDesc.usage = BufferUsage::AccelerationStructure; asDraftBufferDesc.defaultState = ResourceState::AccelerationStructure; - asDraftBufferDesc.sizeInBytes = accelerationStructurePrebuildInfo.resultDataMaxSize; - ComPtr draftBuffer = m_device->createBufferResource(asDraftBufferDesc); - IBufferResource::Desc scratchBufferDesc = {}; - scratchBufferDesc.type = IResource::Type::Buffer; + asDraftBufferDesc.size = accelerationStructurePrebuildInfo.resultDataMaxSize; + ComPtr draftBuffer = m_device->createBuffer(asDraftBufferDesc); + BufferDesc scratchBufferDesc = {}; + scratchBufferDesc.usage = BufferUsage::UnorderedAccess; scratchBufferDesc.defaultState = ResourceState::UnorderedAccess; - scratchBufferDesc.sizeInBytes = accelerationStructurePrebuildInfo.scratchDataSize; - ComPtr scratchBuffer = m_device->createBufferResource(scratchBufferDesc); + scratchBufferDesc.size = accelerationStructurePrebuildInfo.scratchDataSize; + ComPtr scratchBuffer = m_device->createBuffer(scratchBufferDesc); // Build acceleration structure. ComPtr compactedSizeQuery; - IQueryPool::Desc queryPoolDesc = {}; + QueryPoolDesc queryPoolDesc = {}; queryPoolDesc.count = 1; queryPoolDesc.type = QueryType::AccelerationStructureCompactedSize; m_device->createQueryPool(queryPoolDesc, compactedSizeQuery.writeRef()); @@ -837,11 +775,11 @@ void RenderTestApp::_initializeAccelerationStructure() uint64_t compactedSize = 0; compactedSizeQuery->getResult(0, 1, &compactedSize); - IBufferResource::Desc asBufferDesc = {}; - asBufferDesc.type = IResource::Type::Buffer; + BufferDesc asBufferDesc = {}; + asBufferDesc.usage = BufferUsage::AccelerationStructure; asBufferDesc.defaultState = ResourceState::AccelerationStructure; - asBufferDesc.sizeInBytes = (Size)compactedSize; - m_blasBuffer = m_device->createBufferResource(asBufferDesc); + asBufferDesc.size = (Size)compactedSize; + m_blasBuffer = m_device->createBuffer(asBufferDesc); IAccelerationStructure::CreateDesc createDesc; createDesc.buffer = m_blasBuffer; createDesc.kind = IAccelerationStructure::Kind::BottomLevel; @@ -874,13 +812,13 @@ void RenderTestApp::_initializeAccelerationStructure() 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f}; memcpy(&instanceDescs[0].transform[0][0], transformMatrix, sizeof(float) * 12); - IBufferResource::Desc instanceBufferDesc = {}; - instanceBufferDesc.type = IResource::Type::Buffer; - instanceBufferDesc.sizeInBytes = + BufferDesc instanceBufferDesc = {}; + instanceBufferDesc.size = instanceDescs.getCount() * sizeof(IAccelerationStructure::InstanceDesc); + instanceBufferDesc.usage = BufferUsage::AccelerationStructureBuildInput; instanceBufferDesc.defaultState = ResourceState::AccelerationStructureBuildInput; - ComPtr instanceBuffer = - m_device->createBufferResource(instanceBufferDesc, instanceDescs.getBuffer()); + ComPtr instanceBuffer = + m_device->createBuffer(instanceBufferDesc, instanceDescs.getBuffer()); IAccelerationStructure::BuildInputs accelerationStructureBuildInputs = {}; IAccelerationStructure::PrebuildInfo accelerationStructurePrebuildInfo = {}; @@ -892,17 +830,17 @@ void RenderTestApp::_initializeAccelerationStructure() m_device->getAccelerationStructurePrebuildInfo( accelerationStructureBuildInputs, &accelerationStructurePrebuildInfo); - IBufferResource::Desc asBufferDesc = {}; - asBufferDesc.type = IResource::Type::Buffer; + BufferDesc asBufferDesc = {}; + asBufferDesc.usage = BufferUsage::AccelerationStructure; asBufferDesc.defaultState = ResourceState::AccelerationStructure; - asBufferDesc.sizeInBytes = (size_t)accelerationStructurePrebuildInfo.resultDataMaxSize; - m_tlasBuffer = m_device->createBufferResource(asBufferDesc); + asBufferDesc.size = (size_t)accelerationStructurePrebuildInfo.resultDataMaxSize; + m_tlasBuffer = m_device->createBuffer(asBufferDesc); - IBufferResource::Desc scratchBufferDesc = {}; - scratchBufferDesc.type = IResource::Type::Buffer; + BufferDesc scratchBufferDesc = {}; + scratchBufferDesc.usage = BufferUsage::UnorderedAccess; scratchBufferDesc.defaultState = ResourceState::UnorderedAccess; - scratchBufferDesc.sizeInBytes = (size_t)accelerationStructurePrebuildInfo.scratchDataSize; - ComPtr scratchBuffer = m_device->createBufferResource(scratchBufferDesc); + scratchBufferDesc.size = (size_t)accelerationStructurePrebuildInfo.scratchDataSize; + ComPtr scratchBuffer = m_device->createBuffer(scratchBufferDesc); IAccelerationStructure::CreateDesc createDesc = {}; createDesc.buffer = m_tlasBuffer; @@ -987,48 +925,16 @@ Result RenderTestApp::writeBindingOutput(const String& fileName) for(auto outputItem : m_outputPlan.items) { auto resource = outputItem.resource; - if (resource && resource->getType() == IResource::Type::Buffer) + IBuffer* buffer = nullptr; + resource->queryInterface(IBuffer::getTypeGuid(), (void**)&buffer); + if (buffer) { - IBufferResource* bufferResource = static_cast(resource.get()); - auto bufferDesc = *bufferResource->getDesc(); - const size_t bufferSize = bufferDesc.sizeInBytes; + const BufferDesc& bufferDesc = buffer->getDesc(); + const size_t bufferSize = bufferDesc.size; ComPtr blob; - if(bufferDesc.memoryType == MemoryType::ReadBack) - { - // The buffer is already allocated for CPU access, so we can read it back directly. - // - m_device->readBufferResource(bufferResource, 0, bufferSize, blob.writeRef()); - } - else - { - // The buffer is not CPU-readable, so we will copy it using a staging buffer. - - auto stagingBufferDesc = bufferDesc; - stagingBufferDesc.memoryType = MemoryType::ReadBack; - stagingBufferDesc.allowedStates = - ResourceStateSet(ResourceState::CopyDestination, ResourceState::CopySource); - stagingBufferDesc.defaultState = ResourceState::CopyDestination; - - ComPtr stagingBuffer; - SLANG_RETURN_ON_FAIL(m_device->createBufferResource(stagingBufferDesc, nullptr, stagingBuffer.writeRef())); - - ComPtr commandBuffer; - SLANG_RETURN_ON_FAIL( - m_transientHeap->createCommandBuffer(commandBuffer.writeRef())); - - IResourceCommandEncoder* encoder = nullptr; - commandBuffer->encodeResourceCommands(&encoder); - encoder->copyBuffer(stagingBuffer, 0, bufferResource, 0, bufferSize); - encoder->endEncoding(); - - commandBuffer->close(); - m_queue->executeCommandBuffer(commandBuffer); - m_transientHeap->finish(); - m_transientHeap->synchronizeAndReset(); - - SLANG_RETURN_ON_FAIL(m_device->readBufferResource(stagingBuffer, 0, bufferSize, blob.writeRef())); - } + m_device->readBuffer(buffer, 0, bufferSize, blob.writeRef()); + buffer->release(); if (!blob) { @@ -1054,7 +960,7 @@ Result RenderTestApp::writeScreen(const String& filename) { size_t rowPitch, pixelSize; ComPtr blob; - SLANG_RETURN_ON_FAIL(m_device->readTextureResource( + SLANG_RETURN_ON_FAIL(m_device->readTexture( m_colorBuffer, ResourceState::RenderTarget, blob.writeRef(), &rowPitch, &pixelSize)); auto bufferSize = blob->getBufferSize(); uint32_t width = static_cast(rowPitch / pixelSize); @@ -1073,7 +979,24 @@ Result RenderTestApp::update() } else { - auto encoder = commandBuffer->encodeRenderCommands(m_renderPass, m_framebuffer); + RenderPassColorAttachment colorAttachment = {}; + colorAttachment.view = m_colorBufferView; + colorAttachment.loadOp = LoadOp::Clear; + colorAttachment.storeOp = StoreOp::Store; + colorAttachment.initialState = ResourceState::Undefined; + colorAttachment.finalState = ResourceState::RenderTarget; + RenderPassDepthStencilAttachment depthStencilAttachment = {}; + depthStencilAttachment.view = m_depthBufferView; + depthStencilAttachment.depthLoadOp = LoadOp::Clear; + depthStencilAttachment.depthStoreOp = StoreOp::Store; + depthStencilAttachment.initialState = ResourceState::Undefined; + depthStencilAttachment.finalState = ResourceState::DepthWrite; + RenderPassDesc renderPass = {}; + renderPass.colorAttachments = &colorAttachment; + renderPass.colorAttachmentCount = 1; + renderPass.depthStencilAttachment = &depthStencilAttachment; + + auto encoder = commandBuffer->encodeRenderCommands(renderPass); rhi::Viewport viewport = {}; viewport.maxZ = 1.0f; viewport.extentX = (float)gWindowWidth; @@ -1113,7 +1036,7 @@ Result RenderTestApp::update() if (binding.resource && binding.resource->isBuffer()) { BufferResource* bufferResource = static_cast(binding.resource.Ptr()); - const size_t bufferSize = bufferResource->getDesc().sizeInBytes; + const size_t bufferSize = bufferResource->getDesc().size; unsigned int* ptr = (unsigned int*)m_renderer->map(bufferResource, MapFlavor::HostRead); if (!ptr) { @@ -1252,7 +1175,7 @@ static SlangResult _innerMain(Slang::StdWriters* stdWriters, SlangSession* sessi // Parse command-line options SLANG_RETURN_ON_FAIL(Options::parse(argcIn, argvIn, StdWriters::getError(), options)); - if (options.deviceType == DeviceType::Unknown) + if (options.deviceType == DeviceType::Default) { return SLANG_OK; } diff --git a/tools/render-test/shader-input-layout.cpp b/tools/render-test/shader-input-layout.cpp index 97b82e1e97..61951db50c 100644 --- a/tools/render-test/shader-input-layout.cpp +++ b/tools/render-test/shader-input-layout.cpp @@ -19,9 +19,15 @@ namespace renderer_test Format _getFormatFromName(const UnownedStringSlice& slice) { -#define SLANG_FORMAT_CASE(name, blockSizeInBytes, pixelsPerBlock) if (slice == #name) return Format::name; else - - SLANG_RHI_FORMAT(SLANG_FORMAT_CASE) + for (int i = 0; i < int(Format::_Count); ++i) + { + FormatInfo info; + rhiGetFormatInfo(Format(i), &info); + if (slice == info.name) + { + return Format(i); + } + } return Format::Unknown; } diff --git a/tools/render-test/shader-renderer-util.cpp b/tools/render-test/shader-renderer-util.cpp index c2f7583a7a..5448979b42 100644 --- a/tools/render-test/shader-renderer-util.cpp +++ b/tools/render-test/shader-renderer-util.cpp @@ -13,49 +13,26 @@ inline int calcMipSize(int size, int level) return size > 0 ? size : 1; } -inline ITextureResource::Extents calcMipSize(ITextureResource::Extents size, int mipLevel) +inline Extents calcMipSize(Extents size, int mipLevel) { - ITextureResource::Extents rs; + Extents rs; rs.width = calcMipSize(size.width, mipLevel); rs.height = calcMipSize(size.height, mipLevel); rs.depth = calcMipSize(size.depth, mipLevel); return rs; } -/// Calculate the effective array size - in essence the amount if mip map sets needed. -/// In practice takes into account if the arraySize is 0 (it's not an array, but it will still have -/// at least one mip set) and if the type is a cubemap (multiplies the amount of mip sets by 6) -inline int calcEffectiveArraySize(const ITextureResource::Desc& desc) -{ - const int arrSize = (desc.arraySize > 0) ? desc.arraySize : 1; - - switch (desc.type) - { - case IResource::Type::Texture1D: // fallthru - case IResource::Type::Texture2D: - { - return arrSize; - } - case IResource::Type::TextureCube: - return arrSize * 6; - case IResource::Type::Texture3D: - return 1; - default: - return 0; - } -} - /// Given the type works out the maximum dimension size -inline int calcMaxDimension(ITextureResource::Extents size, IResource::Type type) +inline int calcMaxDimension(Extents size, TextureType type) { switch (type) { - case IResource::Type::Texture1D: + case TextureType::Texture1D: return size.width; - case IResource::Type::Texture3D: + case TextureType::Texture3D: return Math::Max(Math::Max(size.width, size.height), size.depth); - case IResource::Type::TextureCube: // fallthru - case IResource::Type::Texture2D: + case TextureType::TextureCube: // fallthru + case TextureType::Texture2D: { return Math::Max(size.width, size.height); } @@ -65,115 +42,103 @@ inline int calcMaxDimension(ITextureResource::Extents size, IResource::Type type } /// Given the type, calculates the number of mip maps. 0 on error -inline int calcNumMipLevels(IResource::Type type, ITextureResource::Extents size) +inline int calcNumMipLevels(TextureType type, Extents size) { const int maxDimensionSize = calcMaxDimension(size, type); return (maxDimensionSize > 0) ? (Math::Log2Floor(maxDimensionSize) + 1) : 0; } -/// Calculate the total number of sub resources. 0 on error. -inline int calcNumSubResources(const ITextureResource::Desc& desc) -{ - const int numMipMaps = - (desc.numMipLevels > 0) ? desc.numMipLevels : calcNumMipLevels(desc.type, desc.size); - const int arrSize = (desc.arraySize > 0) ? desc.arraySize : 1; - - switch (desc.type) - { - case IResource::Type::Texture1D: - case IResource::Type::Texture2D: - case IResource::Type::Texture3D: - { - return numMipMaps * arrSize; - } - case IResource::Type::TextureCube: - { - // There are 6 faces to a cubemap - return numMipMaps * arrSize * 6; - } - default: - return 0; - } -} - -/* static */ Result ShaderRendererUtil::generateTextureResource( +/* static */ Result ShaderRendererUtil::generateTexture( const InputTextureDesc& inputDesc, ResourceState defaultState, IDevice* device, - ComPtr& textureOut) + ComPtr& textureOut) { TextureData texData; generateTextureData(texData, inputDesc); - return createTextureResource(inputDesc, texData, defaultState, device, textureOut); + return createTexture(inputDesc, texData, defaultState, device, textureOut); } -/* static */ Result ShaderRendererUtil::createTextureResource( +/* static */ Result ShaderRendererUtil::createTexture( const InputTextureDesc& inputDesc, const TextureData& texData, ResourceState defaultState, IDevice* device, - ComPtr& textureOut) + ComPtr& textureOut) { - ITextureResource::Desc textureResourceDesc = {}; + TextureDesc textureDesc = {}; // Default to R8G8B8A8_UNORM const Format format = (inputDesc.format == Format::Unknown) ? Format::R8G8B8A8_UNORM : inputDesc.format; - textureResourceDesc.sampleDesc = ITextureResource::SampleDesc{inputDesc.sampleCount, 0}; - textureResourceDesc.format = format; - textureResourceDesc.numMipLevels = texData.m_mipLevels; - textureResourceDesc.arraySize = inputDesc.arrayLength; - textureResourceDesc.allowedStates = - ResourceStateSet(defaultState, ResourceState::CopyDestination, ResourceState::CopySource); - textureResourceDesc.defaultState = defaultState; + textureDesc.sampleCount = inputDesc.sampleCount; + textureDesc.format = format; + textureDesc.numMipLevels = texData.m_mipLevels; + textureDesc.arrayLength = inputDesc.arrayLength > 0 ? inputDesc.arrayLength : 1; + textureDesc.usage = TextureUsage::CopyDestination | TextureUsage::CopySource; + switch (defaultState) + { + case ResourceState::ShaderResource: + textureDesc.usage |= TextureUsage::ShaderResource; + break; + case ResourceState::UnorderedAccess: + textureDesc.usage |= TextureUsage::UnorderedAccess; + break; + default: + return SLANG_FAIL; + } + textureDesc.defaultState = defaultState; // It's the same size in all dimensions switch (inputDesc.dimension) { case 1: { - textureResourceDesc.type = IResource::Type::Texture1D; - textureResourceDesc.size.width = inputDesc.size; - textureResourceDesc.size.height = 1; - textureResourceDesc.size.depth = 1; + textureDesc.type = TextureType::Texture1D; + textureDesc.size.width = inputDesc.size; + textureDesc.size.height = 1; + textureDesc.size.depth = 1; break; } case 2: { - textureResourceDesc.type = inputDesc.isCube ? IResource::Type::TextureCube : IResource::Type::Texture2D; - textureResourceDesc.size.width = inputDesc.size; - textureResourceDesc.size.height = inputDesc.size; - textureResourceDesc.size.depth = 1; + textureDesc.type = inputDesc.isCube ? TextureType::TextureCube : TextureType::Texture2D; + textureDesc.size.width = inputDesc.size; + textureDesc.size.height = inputDesc.size; + textureDesc.size.depth = 1; break; } case 3: { - textureResourceDesc.type = IResource::Type::Texture3D; - textureResourceDesc.size.width = inputDesc.size; - textureResourceDesc.size.height = inputDesc.size; - textureResourceDesc.size.depth = inputDesc.size; + textureDesc.type = TextureType::Texture3D; + textureDesc.size.width = inputDesc.size; + textureDesc.size.height = inputDesc.size; + textureDesc.size.depth = inputDesc.size; break; } } - const int effectiveArraySize = calcEffectiveArraySize(textureResourceDesc); - const int numSubResources = calcNumSubResources(textureResourceDesc); + if (textureDesc.numMipLevels == 0) + { + textureDesc.numMipLevels = calcNumMipLevels(textureDesc.type, textureDesc.size); + } - List initSubresources; + List initSubresources; + int arrayLayerCount = textureDesc.arrayLength * (textureDesc.type == TextureType::TextureCube ? 6 : 1); int subResourceCounter = 0; - for( int a = 0; a < effectiveArraySize; ++a ) + for( int a = 0; a < arrayLayerCount; ++a ) { - for( int m = 0; m < textureResourceDesc.numMipLevels; ++m ) + for( int m = 0; m < textureDesc.numMipLevels; ++m ) { int subResourceIndex = subResourceCounter++; - const int mipWidth = calcMipSize(textureResourceDesc.size.width, m); - const int mipHeight = calcMipSize(textureResourceDesc.size.height, m); + const int mipWidth = calcMipSize(textureDesc.size.width, m); + const int mipHeight = calcMipSize(textureDesc.size.height, m); auto strideY = mipWidth * sizeof(uint32_t); auto strideZ = mipHeight * strideY; - ITextureResource::SubresourceData subresourceData; + SubresourceData subresourceData; subresourceData.data = texData.m_slices[subResourceIndex].values; subresourceData.strideY = strideY; subresourceData.strideZ = strideZ; @@ -182,31 +147,26 @@ inline int calcNumSubResources(const ITextureResource::Desc& desc) } } - textureOut = device->createTextureResource(textureResourceDesc, initSubresources.getBuffer()); + textureOut = device->createTexture(textureDesc, initSubresources.getBuffer()); return textureOut ? SLANG_OK : SLANG_FAIL; } -/* static */ Result ShaderRendererUtil::createBufferResource( +/* static */ Result ShaderRendererUtil::createBuffer( const InputBufferDesc& inputDesc, size_t bufferSize, const void* initData, IDevice* device, - Slang::ComPtr& bufferOut) + Slang::ComPtr& bufferOut) { - IBufferResource::Desc srcDesc; - srcDesc.type = IResource::Type::Buffer; - srcDesc.sizeInBytes = bufferSize; - srcDesc.format = inputDesc.format; - srcDesc.elementSize = inputDesc.stride; - srcDesc.defaultState = ResourceState::UnorderedAccess; - srcDesc.allowedStates = ResourceStateSet( - ResourceState::CopyDestination, - ResourceState::CopySource, - ResourceState::UnorderedAccess, - ResourceState::ShaderResource); - - ComPtr bufferResource = device->createBufferResource(srcDesc, initData); + BufferDesc bufferDesc; + bufferDesc.size = bufferSize; + bufferDesc.format = inputDesc.format; + bufferDesc.elementSize = inputDesc.stride; + bufferDesc.usage = BufferUsage::CopyDestination | BufferUsage::CopySource | BufferUsage::ShaderResource | BufferUsage::UnorderedAccess; + bufferDesc.defaultState = ResourceState::UnorderedAccess; + + ComPtr bufferResource = device->createBuffer(bufferDesc, initData); if (!bufferResource) { return SLANG_FAIL; @@ -216,21 +176,21 @@ inline int calcNumSubResources(const ITextureResource::Desc& desc) return SLANG_OK; } -static ISamplerState::Desc _calcSamplerDesc(const InputSamplerDesc& srcDesc) +static SamplerDesc _calcSamplerDesc(const InputSamplerDesc& srcDesc) { - ISamplerState::Desc dstDesc; + SamplerDesc samplerDesc; if (srcDesc.isCompareSampler) { - dstDesc.reductionOp = TextureReductionOp::Comparison; - dstDesc.comparisonFunc = ComparisonFunc::Less; + samplerDesc.reductionOp = TextureReductionOp::Comparison; + samplerDesc.comparisonFunc = ComparisonFunc::Less; } - return dstDesc; + return samplerDesc; } -ComPtr _createSamplerState(IDevice* device, +ComPtr _createSampler(IDevice* device, const InputSamplerDesc& srcDesc) { - return device->createSamplerState(_calcSamplerDesc(srcDesc)); + return device->createSampler(_calcSamplerDesc(srcDesc)); } } // renderer_test diff --git a/tools/render-test/shader-renderer-util.h b/tools/render-test/shader-renderer-util.h index 188562fa8f..8d0075f3f2 100644 --- a/tools/render-test/shader-renderer-util.h +++ b/tools/render-test/shader-renderer-util.h @@ -8,33 +8,33 @@ namespace renderer_test { using namespace Slang; -ComPtr _createSamplerState(IDevice* device, const InputSamplerDesc& srcDesc); +ComPtr _createSampler(IDevice* device, const InputSamplerDesc& srcDesc); /// Utility class containing functions that construct items on the renderer using the ShaderInputLayout representation struct ShaderRendererUtil { - /// Generate a texture using the InputTextureDesc and construct a TextureResource using the Renderer with the contents - static Slang::Result generateTextureResource( + /// Generate a texture using the InputTextureDesc and construct a Texture using the Renderer with the contents + static Slang::Result generateTexture( const InputTextureDesc& inputDesc, ResourceState defaultState, IDevice* device, - ComPtr& textureOut); + ComPtr& textureOut); /// Create texture resource using inputDesc, and texData to describe format, and contents - static Slang::Result createTextureResource( + static Slang::Result createTexture( const InputTextureDesc& inputDesc, const TextureData& texData, ResourceState defaultState, IDevice* device, - ComPtr& textureOut); + ComPtr& textureOut); /// Create the BufferResource using the renderer from the contents of inputDesc - static Slang::Result createBufferResource( + static Slang::Result createBuffer( const InputBufferDesc& inputDesc, size_t bufferSize, const void* initData, IDevice* device, - ComPtr& bufferOut); + ComPtr& bufferOut); }; } // renderer_test diff --git a/tools/render-test/slang-support.cpp b/tools/render-test/slang-support.cpp index c819aede2e..1f174dc72b 100644 --- a/tools/render-test/slang-support.cpp +++ b/tools/render-test/slang-support.cpp @@ -23,36 +23,6 @@ static const char rtEntryPointName[] = "raygenMain"; static const char taskEntryPointName[] = "taskMain"; static const char meshEntryPointName[] = "meshMain"; -rhi::StageType translateStage(SlangStage slangStage) -{ - switch(slangStage) - { - default: - SLANG_ASSERT(!"unhandled case"); - return rhi::StageType::Unknown; - -#define CASE(FROM, TO) \ - case SLANG_STAGE_##FROM: return rhi::StageType::TO - - CASE(VERTEX, Vertex); - CASE(HULL, Hull); - CASE(DOMAIN, Domain); - CASE(GEOMETRY, Geometry); - CASE(FRAGMENT, Fragment); - - CASE(COMPUTE, Compute); - - CASE(RAY_GENERATION, RayGeneration); - CASE(INTERSECTION, Intersection); - CASE(ANY_HIT, AnyHit); - CASE(CLOSEST_HIT, ClosestHit); - CASE(MISS, Miss); - CASE(CALLABLE, Callable); - -#undef CASE - } -} - void ShaderCompilerUtil::Output::set( slang::IComponentType* inSlangProgram) { diff --git a/tools/render-test/slang-support.h b/tools/render-test/slang-support.h index e08440377a..c930089993 100644 --- a/tools/render-test/slang-support.h +++ b/tools/render-test/slang-support.h @@ -10,8 +10,6 @@ namespace renderer_test { -rhi::StageType translateStage(SlangStage slangStage); - struct ShaderCompileRequest { struct SourceInfo @@ -71,7 +69,7 @@ struct ShaderCompilerUtil } ComPtr slangProgram; - IShaderProgram::Desc desc = {}; + ShaderProgramDesc desc = {}; /// Compile request that owns the lifetime of compiled kernel code. ComPtr m_requestForKernels = nullptr;