From 143930999f57dba804847157f10df57d0e5c9ee7 Mon Sep 17 00:00:00 2001 From: kaizhangNV Date: Wed, 18 Sep 2024 11:30:03 -0700 Subject: [PATCH] Merge branch 'master' into arial/initializer-list-visibility --- .gitignore | 5 + .../002-type-equality-constraints.md | 191 +++ docs/scripts/release-note.sh | 4 +- include/slang.h | 9 + .../slang-artifact-desc-util.cpp | 3 + source/compiler-core/slang-artifact.h | 1 + source/core/slang-type-text-util.cpp | 1 + source/slang-glslang/slang-glslang.cpp | 7 +- .../slang-record-replay/util/emum-to-string.h | 1 + source/slang/core.meta.slang | 2 +- source/slang/hlsl.meta.slang | 629 ++++++---- source/slang/slang-capabilities.capdef | 34 +- source/slang/slang-check-constraint.cpp | 4 +- source/slang/slang-check-impl.h | 2 +- source/slang/slang-check-overload.cpp | 82 ++ source/slang/slang-compiler.cpp | 1 + source/slang/slang-compiler.h | 15 +- source/slang/slang-doc-markdown-writer.cpp | 4 + source/slang/slang-emit-c-like.cpp | 104 +- source/slang/slang-emit-c-like.h | 19 +- source/slang/slang-emit-wgsl.cpp | 1023 +++++++++++++++++ source/slang/slang-emit-wgsl.h | 78 ++ source/slang/slang-emit.cpp | 46 +- source/slang/slang-ir-inst-defs.h | 5 + source/slang/slang-ir-insts.h | 16 + source/slang/slang-ir-link.cpp | 1 + .../slang-ir-lower-buffer-element-type.cpp | 19 +- source/slang/slang-ir-specialize.cpp | 121 +- .../slang-ir-use-uninitialized-values.cpp | 6 + source/slang/slang-ir-util.cpp | 4 + source/slang/slang-ir-wgsl-legalize.cpp | 347 ++++++ source/slang/slang-ir-wgsl-legalize.h | 10 + source/slang/slang-profile.h | 1 + source/slang/slang-reflection-api.cpp | 69 +- source/slang/slang-type-layout.cpp | 5 + source/slang/slang.cpp | 104 +- test-record-replay.sh | 147 --- tests/bugs/gh-5026.slang | 23 + tests/bugs/overload-ambiguous.slang | 48 + ...ninitialized-struct-from-constructor.slang | 24 + .../generic-return-type-requirement.slang | 39 + .../nested-gen-value-param-inference-2.slang | 41 + .../nested-gen-value-param-inference.slang | 32 + .../default-construct-conformance.slang | 8 +- .../overloaded-subscript.slang | 48 + tests/wgsl/math.slang | 279 +++++ tools/slang-test/slang-test-main.cpp | 5 + tools/slang-test/test-context.h | 3 +- .../unit-test-decl-tree-reflection.cpp | 67 +- 49 files changed, 3242 insertions(+), 495 deletions(-) create mode 100644 docs/proposals/002-type-equality-constraints.md create mode 100644 source/slang/slang-emit-wgsl.cpp create mode 100644 source/slang/slang-emit-wgsl.h create mode 100644 source/slang/slang-ir-wgsl-legalize.cpp create mode 100644 source/slang/slang-ir-wgsl-legalize.h delete mode 100755 test-record-replay.sh create mode 100644 tests/bugs/gh-5026.slang create mode 100644 tests/bugs/overload-ambiguous.slang create mode 100644 tests/diagnostics/uninitialized-struct-from-constructor.slang create mode 100644 tests/language-feature/generics/generic-return-type-requirement.slang create mode 100644 tests/language-feature/generics/nested-gen-value-param-inference-2.slang create mode 100644 tests/language-feature/generics/nested-gen-value-param-inference.slang create mode 100644 tests/language-feature/overloaded-subscript.slang create mode 100644 tests/wgsl/math.slang diff --git a/.gitignore b/.gitignore index 2c4d607779..941b0b745c 100644 --- a/.gitignore +++ b/.gitignore @@ -86,3 +86,8 @@ vkd3d-proton.cache vkd3d-proton.cache.write *_d3d11.log *_dxgi.log + +# Vim temporary files +*~ +.*.swp +.*.swo diff --git a/docs/proposals/002-type-equality-constraints.md b/docs/proposals/002-type-equality-constraints.md 
new file mode 100644 index 0000000000..44562075ee --- /dev/null +++ b/docs/proposals/002-type-equality-constraints.md @@ -0,0 +1,191 @@
+Allow Type Equality Constraints on Generics
+===========================================
+
+We propose to allow *type equality* constraints in `where` clauses.
+
+Status
+------
+
+In progress.
+
+Background
+----------
+
+As of proposal [001](001-where-clauses.md), Slang allows for generic declarations to include a *`where` clause* which enumerates constraints on the generic parameters that must be satisfied by any arguments provided to that generic:
+
+    V findOrDefault<K, V>( HashTable<K, V> table, K key )
+        where K : IHashable,
+              V : IDefaultInitializable
+    { ... }
+
+Currently, the language only accepts *conformance* constraints of the form `T : IFoo`, where `T` is one of the parameters of the generic, and `IFoo` is either an `interface` or a conjunction of interfaces, indicating that the type `T` must conform to `IFoo`.
+
+This proposal is motivated by the observation that when an interface has associated types, there is currently no way for a programmer to introduce a generic that is only applicable when an associated type satisfies certain constraints.
+
+As an example, consider an interface for types that can be "packed" into a smaller representation for in-memory storage (instead of a default representation optimized for access from registers):
+
+    interface IPackable
+    {
+        associatedtype Packed;
+
+        init(Packed packed);
+        Packed pack();
+    }
+
+Next, consider a hypothetical interface for types that can be deserialized from a stream:
+
+    interface IDeserializable
+    {
+        init( InputStream stream );
+    }
+
+Given these definitions, we might want to define a function that takes a packable type, and deserializes it from a stream:
+
+    T deserializePackable<T>( InputStream stream )
+        where T : IPackable
+    {
+        return T( T.Packed(stream) );
+    }
+
+As written, this function will fail to compile because the compiler cannot assume that `T.Packed` conforms to `IDeserializable`, in order to support initialization from a stream.
+
+A brute-force solution would be to add the `IDeserializable` constraint to the `IPackable.Packed` associated type, but doing so may not be consistent with the vision the designer of `IPackable` had in mind. Indeed, there is no reason to assume that `IPackable` and `IDeserializable` even have the same author, or are things that the programmer trying to write `deserializePackable` can change.
+
+It might seem that we could improve the situation by introducing another generic type parameter, so that we can explicitly constrain it to be deserializable:
+
+    T deserializePackable<T, P>( InputStream stream )
+        where T : IPackable,
+              P : IDeserializable
+    {
+        return T( P(stream) );
+    }
+
+This second attempt *also* fails to compile.
+In this case, there is no way for the compiler to know that `T` can be initialized from a `P`, because it cannot intuit that `P` is meant to be `T.Packed`.
+
+Our two failed attempts can each be fixed by introducing two new kinds of constraints (a sketch applying them to this example appears below):
+
+* Conformance constraints on associated types: `T.A : IFoo`
+
+* Equality constraints on associated types: `T.A == X`
+
+Related Work
+------------
+
+Both Rust and Swift support additional kinds of constraints on generics, including the cases proposed here.
+The syntax in those languages matches what we propose.
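+Returning to the motivating example, here is a sketch (using the syntax detailed in the next section) of how each new constraint kind would make `deserializePackable` compile:
+
+    // Using a conformance constraint on the associated type:
+    T deserializePackable<T>( InputStream stream )
+        where T : IPackable,
+              T.Packed : IDeserializable
+    {
+        return T( T.Packed(stream) );
+    }
+
+    // Using an equality constraint to tie the extra parameter to the associated type:
+    T deserializePackable<T, P>( InputStream stream )
+        where T : IPackable,
+              P : IDeserializable,
+              T.Packed == P
+    {
+        return T( P(stream) );
+    }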
+
+Proposed Approach
+-----------------
+
+In addition to conformance constraints on generic type parameters (`T : IFoo`), the compiler will also support constraints on associated types of those parameters (`T.A : IFoo`), and associated types of those associated types (`T.A.B : IFoo`), etc.
+
+In addition, the compiler will accept constraints that restrict an associated type (`T.A`, `T.A.B`, etc.) to be equal to some other type.
+The other type may be a concrete type, another generic parameter, or another associated type.
+
+Detailed Explanation
+--------------------
+
+### Parser
+
+The parser already supports nearly arbitrary type expressions on both sides of a conformance constraint, and then validates that the types used are allowed during semantic checking.
+The only change needed at that level is to split `GenericTypeConstraintDecl` into two cases: one for conformance constraints, and another for equality constraints, and then to support constraints with `==` instead of `:`.
+
+### Semantic Checking
+
+During semantic checking, instead of checking that the left-hand type in a constraint is always one of the generic type parameters, we could instead check that the left-hand type expression is either a generic type parameter or `X.AssociatedType` where `X` would be a valid left-hand type.
+
+The right-hand type for conformance constraints should be checked the same as before.
+
+The right-hand type for an equality constraint should be allowed to be an arbitrary type expression that names a proper (and non-`interface`) type.
+
+One subtlety is that in a type expression like `T.A.B` where both `A` and `B` are associated types, it may be that the `B` member of `T.A` can only be looked up because of another constraint like `T.A : IFoo`.
+When performing semantic checking of a constraint in a `where` clause, we need to decide which of the constraints may inform lookup when resolving a type expression like `X.A`.
+Some options are:
+
+* We could consider only constraints that appear before the constraint that includes that type expression. In this case, a programmer must always introduce a constraint `X : IFoo` before a constraint that names `X.A`, if `A` is an associated type introduced by `IFoo`.
+
+* We could consider *all* of the constraints simultaneously (except, perhaps, the constraint that we are in the middle of checking).
+
+The latter option is more flexible, but may be (much) harder to implement in practice.
+We propose that for now we use the first option, but remain open to implementing the more general case in the future.
+
+Given an equality constraint like `T.A.B == X`, semantic checking needs to detect cases where an `X` is used and a `T.A.B` is expected, or vice versa.
+These cases should introduce some kind of cast-like expression, which references the type equality witness as evidence that the cast is valid (and should, in theory, be a no-op).
+
+Semantic checking of equality constraints should identify contradictory sets of constraints.
+Such contradictions can be simple to spot:
+
+    interface IThing { associatedtype A; }
+    void f<T>()
+        where T : IThing,
+              T.A == String,
+              T.A == Float
+    { ... }
+
+but they can also be more complicated:
+
+    void f<T, U>()
+        where T : IThing,
+              U : IThing,
+              T.A == String,
+              U.A == Float,
+              T.A == U.A
+    { ... }
+
+In each case, an associated type is being constrained to be equal to two *different* concrete types.
+There is no possible set of generic arguments that could satisfy these constraints, so declarations like these should be rejected.
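+To see why the second example is just as contradictory as the first, it helps to group the constrained types together (this is essentially what the canonicalization step below does):
+
+    // Constraints:           T.A == String,  U.A == Float,  T.A == U.A
+    // One equivalence class: { T.A, U.A, String, Float }
+    // The class contains two distinct concrete types (String and Float),
+    // so no choice of T and U can satisfy all three constraints.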
+
+We propose that the simplest way to identify and diagnose contradictory constraints like this is during canonicalization, as described below.
+
+### IR
+
+At the IR level, a conformance constraint on an associated type is no different than any other conformance constraint: it lowers to an explicit generic parameter that will accept a witness table as an argument.
+
+The choice of how to represent equality constraints is more subtle.
+One option is to lower an equality constraint to *nothing* at the IR level, under the assumption that the casts that reference these constraints should lower to nothing.
+Doing so would introduce yet another case where the IR we generate doesn't "type-check."
+The other option is to lower a type equality constraint to an explicit generic parameter which is then applied via an explicit op to convert between the associated type and its known concrete equivalent.
+The representation of the witnesses required to provide *arguments* for such parameters is something that hasn't been fully explored, so for now we propose to take the first (easier) option.
+
+### Canonicalization
+
+Adding new kinds of constraints affects *canonicalization*, which was discussed in proposal [001](001-where-clauses.md).
+Conformance constraints involving associated types should already be orderable according to the rules in that proposal, so we primarily need to concern ourselves with equality constraints.
+
+We propose the following approach:
+
+* Take all of the equality constraints that arise after any expansion steps.
+* Divide the types named on either side of any equality constraint into *equivalence classes*, where if `X == Y` is a constraint, then `X` and `Y` must be in the same equivalence class.
+    * Each type in an equivalence class will either be an associated type of the form `T.A.B...Z`, derived from a generic type parameter, or an *independent* type, which here means anything other than those associated types.
+    * Because of the rules enforced during semantic checking, each equivalence class must have at least one associated type in it.
+    * Each equivalence class may have zero or more independent types in it.
+* For each equivalence class with more than one independent type in it, diagnose an error; the application is attempting to constrain one or more associated types to be equal to multiple distinct types at once.
+* For each equivalence class with exactly one independent type in it, produce new constraints of the form `T.A.B...Z == C`, one for each associated type in the equivalence class, where `C` is the independent type.
+* For each equivalence class with zero independent types in it, pick the *minimal* associated type (according to the type ordering), and produce new constraints of the form `T.A... == U.B...` for each *other* associated type in the equivalence class, where `U.B...` is the minimal associated type.
+* Sort the new constraints by the associated type on their left-hand side.
+
+Alternatives Considered
+-----------------------
+
+The main alternative here would be to simply not have these kinds of constraints, and push programmers to use type parameters instead of associated types in cases where they want to be able to enforce constraints on those types.
+E.g., the `IPackable` interface from earlier could be rewritten into this form:
+
+    interface IPackable<Packed>
+    {
+        init(Packed packed);
+        Packed pack();
+    }
+
+With this form for `IPackable`, it becomes possible to use additional type parameters to constrain the `Packed` type:
+
+    T deserializePackable<T, P>( InputStream stream )
+        where T : IPackable<P>,
+              P : IDeserializable
+    {
+        return T( P(stream) );
+    }
+
+While this workaround may seem reasonable in an isolated example like this, there is a strong reason why languages like Slang choose to have both generic type parameters (which act as *inputs* to an abstraction) and associated types (which act as *outputs*).
+We believe that associated types are an important feature, and that they justify the complexity of these new kinds of constraints.
\ No newline at end of file
diff --git a/docs/scripts/release-note.sh b/docs/scripts/release-note.sh index f1a93137ac..3def707ccc 100644 --- a/docs/scripts/release-note.sh +++ b/docs/scripts/release-note.sh
@@ -16,7 +16,7 @@ verbose=true $verbose && echo "Reminder: PLEASE make sure your local repo is up-to-date before running the script." >&2 gh="" -for candidate in "$(which gh.exe)" "/mnt/c/Program Files/GitHub CLI/gh.exe" "/c/Program Files/GitHub CLI/gh.exe" +for candidate in "$(which gh.exe)" "/mnt/c/Program Files/GitHub CLI/gh.exe" "/c/Program Files/GitHub CLI/gh.exe" "/cygdrive/c/Program Files/GitHub CLI/gh.exe" do if [ -x "$candidate" ] then
@@ -51,7 +51,7 @@ do # Get PR number from the git commit title pr="$(echo "$line" | grep '#[1-9][0-9][0-9][0-9][0-9]*' | sed 's|.* (\#\([1-9][0-9][0-9][0-9][0-9]*\))|\1|')" - [ "x$pr" = "x" ] && break + [ "x$pr" = "x" ] && continue # Check if the PR is marked as a breaking change if "$gh" issue view $pr --json labels | grep -q 'pr: breaking change'
diff --git a/include/slang.h b/include/slang.h index 9755415b3f..3024aa8844 100644 --- a/include/slang.h +++ b/include/slang.h
@@ -603,6 +603,7 @@ extern "C" SLANG_METAL_LIB, ///< Metal library SLANG_METAL_LIB_ASM, ///< Metal library assembly SLANG_HOST_SHARED_LIBRARY, ///< A shared library/Dll for host code (for hosting CPU/OS) + SLANG_WGSL, ///< WebGPU shading language SLANG_TARGET_COUNT_OF, };
@@ -636,6 +637,7 @@ extern "C" SLANG_PASS_THROUGH_LLVM, ///< LLVM 'compiler' - includes LLVM and Clang SLANG_PASS_THROUGH_SPIRV_OPT, ///< SPIRV-opt SLANG_PASS_THROUGH_METAL, ///< Metal compiler + SLANG_PASS_THROUGH_WGSL, ///< WGSL compiler SLANG_PASS_THROUGH_COUNT_OF, };
@@ -735,6 +737,7 @@ extern "C" SLANG_SOURCE_LANGUAGE_CUDA, SLANG_SOURCE_LANGUAGE_SPIRV, SLANG_SOURCE_LANGUAGE_METAL, + SLANG_SOURCE_LANGUAGE_WGSL, SLANG_SOURCE_LANGUAGE_COUNT_OF, };
@@ -2587,6 +2590,7 @@ extern "C" SLANG_API SlangReflectionType* spReflectionFunction_GetResultType(SlangReflectionFunction* func); SLANG_API SlangReflectionGeneric* spReflectionFunction_GetGenericContainer(SlangReflectionFunction* func); SLANG_API SlangReflectionFunction* spReflectionFunction_applySpecializations(SlangReflectionFunction* func, SlangReflectionGeneric* generic); + SLANG_API SlangReflectionFunction* spReflectionFunction_specializeWithArgTypes(SlangReflectionFunction* func, SlangInt argTypeCount, SlangReflectionType* const* argTypes); // Abstract Decl Reflection
@@ -3585,6 +3589,11 @@ namespace slang { return (FunctionReflection*)spReflectionFunction_applySpecializations((SlangReflectionFunction*)this, (SlangReflectionGeneric*)generic); } + + FunctionReflection* specializeWithArgTypes(unsigned int argCount, TypeReflection* const* types) + { + return
(FunctionReflection*)spReflectionFunction_specializeWithArgTypes((SlangReflectionFunction*)this, argCount, (SlangReflectionType* const*)types); + } }; struct GenericReflection diff --git a/source/compiler-core/slang-artifact-desc-util.cpp b/source/compiler-core/slang-artifact-desc-util.cpp index a4190992cf..9794cc90e9 100644 --- a/source/compiler-core/slang-artifact-desc-util.cpp +++ b/source/compiler-core/slang-artifact-desc-util.cpp @@ -197,6 +197,7 @@ SLANG_HIERARCHICAL_ENUM(ArtifactKind, SLANG_ARTIFACT_KIND, SLANG_ARTIFACT_KIND_E x(CUDA, Source) \ x(Metal, Source) \ x(Slang, Source) \ + x(WGSL, Source) \ x(KernelLike, Base) \ x(DXIL, KernelLike) \ x(DXBC, KernelLike) \ @@ -288,6 +289,7 @@ SLANG_HIERARCHICAL_ENUM(ArtifactStyle, SLANG_ARTIFACT_STYLE, SLANG_ARTIFACT_STYL case SLANG_METAL: return Desc::make(Kind::Source, Payload::Metal, Style::Kernel, 0); case SLANG_METAL_LIB: return Desc::make(Kind::Executable, Payload::MetalAIR, Style::Kernel, 0); case SLANG_METAL_LIB_ASM: return Desc::make(Kind::Assembly, Payload::MetalAIR, Style::Kernel, 0); + case SLANG_WGSL: return Desc::make(Kind::Source, Payload::WGSL, Style::Kernel, 0); default: break; } @@ -330,6 +332,7 @@ SLANG_HIERARCHICAL_ENUM(ArtifactStyle, SLANG_ARTIFACT_STYLE, SLANG_ARTIFACT_STYL case Payload::Cpp: return (desc.style == Style::Host) ? SLANG_HOST_CPP_SOURCE : SLANG_CPP_SOURCE; case Payload::CUDA: return SLANG_CUDA_SOURCE; case Payload::Metal: return SLANG_METAL; + case Payload::WGSL: return SLANG_WGSL; default: break; } break; diff --git a/source/compiler-core/slang-artifact.h b/source/compiler-core/slang-artifact.h index 400c85b2eb..6d65aafba4 100644 --- a/source/compiler-core/slang-artifact.h +++ b/source/compiler-core/slang-artifact.h @@ -143,6 +143,7 @@ enum class ArtifactPayload : uint8_t CUDA, ///< CUDA source Metal, ///< Metal source Slang, ///< Slang source + WGSL, ///< WGSL source KernelLike, ///< GPU Kernel like diff --git a/source/core/slang-type-text-util.cpp b/source/core/slang-type-text-util.cpp index 9fa91abf6c..9f9deb92c8 100644 --- a/source/core/slang-type-text-util.cpp +++ b/source/core/slang-type-text-util.cpp @@ -63,6 +63,7 @@ static const TypeTextUtil::CompileTargetInfo s_compileTargetInfos[] = { SLANG_METAL, "metal", "metal", "Metal shader source" }, { SLANG_METAL_LIB, "metallib", "metallib", "Metal Library Bytecode" }, { SLANG_METAL_LIB_ASM, "metallib-asm" "metallib-asm", "Metal Library Bytecode assembly" }, + { SLANG_WGSL, "wgsl", "wgsl", "WebGPU shading language source" }, }; static const NamesDescriptionValue s_languageInfos[] = diff --git a/source/slang-glslang/slang-glslang.cpp b/source/slang-glslang/slang-glslang.cpp index d29d22e894..2dd1b46ab6 100644 --- a/source/slang-glslang/slang-glslang.cpp +++ b/source/slang-glslang/slang-glslang.cpp @@ -3,16 +3,10 @@ #include "glslang/Public/ResourceLimits.h" -#include "StandAlone/Worklist.h" -#include "glslang/Include/ShHandle.h" #include "glslang/Public/ShaderLang.h" #include "SPIRV/GlslangToSpv.h" -#include "SPIRV/GLSL.std.450.h" -#include "SPIRV/doc.h" #include "SPIRV/disassemble.h" -#include "glslang/MachineIndependent/localintermediate.h" - #include "slang.h" #include "spirv-tools/optimizer.hpp" @@ -23,6 +17,7 @@ #endif #include +#include #include // This is a wrapper to allow us to run the `glslang` compiler diff --git a/source/slang-record-replay/util/emum-to-string.h b/source/slang-record-replay/util/emum-to-string.h index 7a79525552..7226edc04c 100644 --- a/source/slang-record-replay/util/emum-to-string.h +++ 
b/source/slang-record-replay/util/emum-to-string.h @@ -34,6 +34,7 @@ namespace SlangRecord CASE(SLANG_METAL_LIB); CASE(SLANG_METAL_LIB_ASM); CASE(SLANG_HOST_SHARED_LIBRARY); + CASE(SLANG_WGSL); CASE(SLANG_TARGET_COUNT_OF); default: Slang::StringBuilder str; diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 4e85296664..03dda0fe5a 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -568,7 +568,7 @@ ${{{{ __intrinsic_op($(kIROp_Sub)) This sub(This other); __intrinsic_op($(kIROp_Mul)) This mul(This other); __intrinsic_op($(kIROp_Div)) This div(This other); - __intrinsic_op($(kIROp_FRem)) This mod(This other); + __intrinsic_op($(kIROp_IRem)) This mod(This other); __intrinsic_op($(kIROp_Neg)) This neg(); __intrinsic_op($(kIROp_Lsh)) This shl(int other); __intrinsic_op($(kIROp_Rsh)) This shr(int other); diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 414cb2afe5..56a249be7f 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -5445,7 +5445,7 @@ void abort(); __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T abs(T x) { __target_switch @@ -5458,6 +5458,7 @@ T abs(T x) case spirv: return spirv_asm { result:$$T = OpExtInst glsl450 SAbs $x }; + case wgsl: __intrinsic_asm "abs"; //default: // Note: this simple definition may not be appropriate for floating-point inputs // return x < 0 ? -x : x; @@ -5466,7 +5467,7 @@ T abs(T x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector abs(vector x) { __target_switch @@ -5477,6 +5478,7 @@ vector abs(vector x) case spirv: return spirv_asm { result:$$vector = OpExtInst glsl450 SAbs $x; }; + case wgsl: __intrinsic_asm "abs"; default: VECTOR_MAP_UNARY(T, N, abs, x); } @@ -5484,7 +5486,7 @@ vector abs(vector x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] matrix abs(matrix x) { __target_switch @@ -5497,7 +5499,7 @@ matrix abs(matrix x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T abs(T x) { __target_switch @@ -5510,12 +5512,13 @@ T abs(T x) case spirv: return spirv_asm { result:$$T = OpExtInst glsl450 FAbs $x; }; + case wgsl: __intrinsic_asm "abs"; } } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector abs(vector x) { __target_switch @@ -5526,6 +5529,7 @@ vector abs(vector x) case spirv: return spirv_asm { result:$$vector = OpExtInst glsl450 FAbs $x; }; + case wgsl: __intrinsic_asm "abs"; default: VECTOR_MAP_UNARY(T, N, abs, x); } @@ -5533,7 +5537,7 @@ vector abs(vector x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] matrix abs(matrix x) { __target_switch @@ -5547,7 +5551,7 @@ matrix abs(matrix x) __generic [__readNone] [ForceInline] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T fabs(T x) { __target_switch @@ -5561,7 +5565,7 @@ T fabs(T x) __generic [__readNone] [ForceInline] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, 
sm_4_0_version)] vector fabs(vector x) { __target_switch @@ -5577,7 +5581,7 @@ vector fabs(vector x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T acos(T x) { __target_switch @@ -5590,12 +5594,13 @@ T acos(T x) case spirv: return spirv_asm { OpExtInst $$T result glsl450 Acos $x }; + case wgsl: __intrinsic_asm "acos"; } } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector acos(vector x) { __target_switch @@ -5606,6 +5611,7 @@ vector acos(vector x) case spirv: return spirv_asm { OpExtInst $$vector result glsl450 Acos $x }; + case wgsl: __intrinsic_asm "acos"; default: VECTOR_MAP_UNARY(T, N, acos, x); } @@ -5613,7 +5619,7 @@ vector acos(vector x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] matrix acos(matrix x) { __target_switch @@ -5629,7 +5635,7 @@ matrix acos(matrix x) __generic [__readNone] [ForceInline] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T acosh(T x) { __target_switch @@ -5641,6 +5647,7 @@ T acosh(T x) case spirv: return spirv_asm { OpExtInst $$T result glsl450 Acosh $x }; + case wgsl: __intrinsic_asm "acosh"; default: return log(x + sqrt( x * x - T(1))); } @@ -5649,7 +5656,7 @@ T acosh(T x) __generic [__readNone] [ForceInline] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector acosh(vector x) { __target_switch @@ -5659,6 +5666,7 @@ vector acosh(vector x) case spirv: return spirv_asm { OpExtInst $$vector result glsl450 Acosh $x }; + case wgsl: __intrinsic_asm "acosh"; default: VECTOR_MAP_UNARY(T, N, acosh, x); } @@ -5668,7 +5676,7 @@ vector acosh(vector x) // Test if all components are non-zero (HLSL SM 1.0) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] bool all(T x) { __target_switch @@ -5679,6 +5687,8 @@ bool all(T x) __intrinsic_asm "all"; case metal: __intrinsic_asm "all"; + case wgsl: + __intrinsic_asm "all"; case spirv: let zero = __default(); if (__isInt()) @@ -5700,7 +5710,7 @@ bool all(T x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] bool all(vector x) { if(N == 1) @@ -5737,6 +5747,8 @@ bool all(vector x) OpAll $$bool result %castResult }; } + case wgsl: + __intrinsic_asm "all"; default: bool result = true; for(int i = 0; i < N; ++i) @@ -5806,7 +5818,7 @@ int3 WorkgroupSize(); __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] bool any(T x) { __target_switch @@ -5817,6 +5829,8 @@ bool any(T x) __intrinsic_asm "any"; case metal: __intrinsic_asm "any"; + case wgsl: + __intrinsic_asm "any"; case spirv: let zero = __default(); if (__isInt()) @@ -5837,7 +5851,7 @@ bool any(T x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] bool any(vector x) { if(N == 1) @@ -5874,6 +5888,8 @@ bool any(vector x) OpAny $$bool result %castResult }; } + case wgsl: + __intrinsic_asm "any"; default: bool result = false; for(int i = 0; i < N; ++i) @@ -5936,7 +5952,7 @@ double2 asdouble(uint2 lowbits, uint2 highbits) // Reinterpret bits as a float (HLSL SM 4.0) [__readNone] 
-[require(cpp_cuda_glsl_hlsl_metal_spirv, shader5_sm_4_0)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_4_0)] float asfloat(int x) { __target_switch @@ -5949,11 +5965,12 @@ float asfloat(int x) case spirv: return spirv_asm { OpBitcast $$float result $x }; + case wgsl: __intrinsic_asm "bitcast<$TR>($0)"; } } [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, shader5_sm_4_0)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_4_0)] float asfloat(uint x) { __target_switch @@ -5966,12 +5983,13 @@ float asfloat(uint x) case spirv: return spirv_asm { OpBitcast $$float result $x }; + case wgsl: __intrinsic_asm "bitcast<$TR>($0)"; } } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, shader5_sm_4_0)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_4_0)] vector asfloat(vector< int, N> x) { __target_switch @@ -5982,6 +6000,7 @@ vector asfloat(vector< int, N> x) case spirv: return spirv_asm { OpBitcast $$vector result $x }; + case wgsl: __intrinsic_asm "bitcast<$TR>($0)"; default: VECTOR_MAP_UNARY(float, N, asfloat, x); } @@ -5989,7 +6008,7 @@ vector asfloat(vector< int, N> x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, shader5_sm_4_0)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_4_0)] vector asfloat(vector x) { __target_switch @@ -6000,6 +6019,7 @@ vector asfloat(vector x) case spirv: return spirv_asm { OpBitcast $$vector result $x }; + case wgsl: __intrinsic_asm "bitcast<$TR>($0)"; default: VECTOR_MAP_UNARY(float, N, asfloat, x); } @@ -6052,7 +6072,7 @@ matrix asfloat(matrix x) // Inverse sine (HLSL SM 1.0) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T asin(T x) { __target_switch @@ -6065,12 +6085,13 @@ T asin(T x) case spirv: return spirv_asm { OpExtInst $$T result glsl450 Asin $x }; + case wgsl: __intrinsic_asm "asin"; } } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector asin(vector x) { __target_switch @@ -6081,6 +6102,7 @@ vector asin(vector x) case spirv: return spirv_asm { OpExtInst $$vector result glsl450 Asin $x }; + case wgsl: __intrinsic_asm "asin"; default: VECTOR_MAP_UNARY(T,N,asin,x); } @@ -6088,7 +6110,7 @@ vector asin(vector x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] matrix asin(matrix x) { __target_switch @@ -6104,7 +6126,7 @@ matrix asin(matrix x) __generic [__readNone] [ForceInline] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T asinh(T x) { __target_switch @@ -6116,6 +6138,7 @@ T asinh(T x) case spirv: return spirv_asm { OpExtInst $$T result glsl450 Asinh $x }; + case wgsl: __intrinsic_asm "asinh"; default: return log(x + sqrt(x * x + T(1))); } @@ -6124,7 +6147,7 @@ T asinh(T x) __generic [__readNone] [ForceInline] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector asinh(vector x) { __target_switch @@ -6134,6 +6157,7 @@ vector asinh(vector x) case spirv: return spirv_asm { OpExtInst $$vector result glsl450 Asinh $x }; + case wgsl: __intrinsic_asm "asinh"; default: VECTOR_MAP_UNARY(T, N, asinh, x); } @@ -6142,7 +6166,7 @@ vector asinh(vector x) // Reinterpret bits as an int (HLSL SM 4.0) [__readNone] 
-[require(cpp_cuda_glsl_hlsl_metal_spirv, shader5_sm_4_0)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_4_0)] int asint(float x) { __target_switch @@ -6152,6 +6176,7 @@ int asint(float x) case glsl: __intrinsic_asm "floatBitsToInt"; case hlsl: __intrinsic_asm "asint"; case metal: __intrinsic_asm "as_type<$TR>($0)"; + case wgsl: __intrinsic_asm "bitcast<$TR>($0)"; case spirv: return spirv_asm { OpBitcast $$int result $x }; @@ -6159,7 +6184,7 @@ int asint(float x) } [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, shader5_sm_4_0)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_4_0)] int asint(uint x) { __target_switch @@ -6172,12 +6197,13 @@ int asint(uint x) case spirv: return spirv_asm { OpBitcast $$int result $x }; + case wgsl: __intrinsic_asm "bitcast<$TR>($0)"; } } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, shader5_sm_4_0)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_4_0)] vector asint(vector x) { __target_switch @@ -6188,6 +6214,7 @@ vector asint(vector x) case spirv: return spirv_asm { OpBitcast $$vector result $x }; + case wgsl: __intrinsic_asm "bitcast<$TR>($0)"; default: VECTOR_MAP_UNARY(int, N, asint, x); } @@ -6195,7 +6222,7 @@ vector asint(vector x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, shader5_sm_4_0)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_4_0)] vector asint(vector x) { if(N == 1) @@ -6208,6 +6235,7 @@ vector asint(vector x) case spirv: return spirv_asm { OpBitcast $$vector result $x }; + case wgsl: __intrinsic_asm "bitcast<$TR>($0)"; default: VECTOR_MAP_UNARY(int, N, asint, x); } @@ -6285,7 +6313,7 @@ void asuint(double value, out uint lowbits, out uint highbits) // Reinterpret bits as a uint (HLSL SM 4.0) [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, shader5_sm_4_0)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_4_0)] uint asuint(float x) { __target_switch @@ -6295,6 +6323,7 @@ uint asuint(float x) case glsl: __intrinsic_asm "floatBitsToUint"; case hlsl: __intrinsic_asm "asuint"; case metal: __intrinsic_asm "as_type<$TR>($0)"; + case wgsl: __intrinsic_asm "bitcast<$TR>($0)"; case spirv: return spirv_asm { OpBitcast $$uint result $x }; @@ -6302,7 +6331,7 @@ uint asuint(float x) } [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, shader5_sm_4_0)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_4_0)] uint asuint(int x) { __target_switch @@ -6315,12 +6344,13 @@ uint asuint(int x) case spirv: return spirv_asm { OpBitcast $$uint result $x }; + case wgsl: __intrinsic_asm "bitcast<$TR>($0)"; } } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, shader5_sm_4_0)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_4_0)] vector asuint(vector x) { __target_switch @@ -6333,12 +6363,13 @@ vector asuint(vector x) }; default: VECTOR_MAP_UNARY(uint, N, asuint, x); + case wgsl: __intrinsic_asm "bitcast<$TR>($0)"; } } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, shader5_sm_4_0)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_4_0)] vector asuint(vector x) { if(N == 1) @@ -6351,6 +6382,7 @@ vector asuint(vector x) case spirv: return spirv_asm { OpBitcast $$vector result $x }; + case wgsl: __intrinsic_asm "bitcast<$TR>($0)"; default: VECTOR_MAP_UNARY(uint, N, asuint, x); } @@ -6595,7 +6627,7 @@ matrix asfloat16(matrix va // Inverse tangent (HLSL SM 1.0) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T 
atan(T x) { __target_switch @@ -6608,12 +6640,13 @@ T atan(T x) case spirv: return spirv_asm { OpExtInst $$T result glsl450 Atan $x }; + case wgsl: __intrinsic_asm "atan"; } } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector atan(vector x) { __target_switch @@ -6624,6 +6657,7 @@ vector atan(vector x) case spirv: return spirv_asm { OpExtInst $$vector result glsl450 Atan $x }; + case wgsl: __intrinsic_asm "atan"; default: VECTOR_MAP_UNARY(T, N, atan, x); } @@ -6631,7 +6665,7 @@ vector atan(vector x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] matrix atan(matrix x) { __target_switch @@ -6644,7 +6678,7 @@ matrix atan(matrix x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T atan2(T y, T x) { __target_switch @@ -6657,12 +6691,13 @@ T atan2(T y, T x) case spirv: return spirv_asm { OpExtInst $$T result glsl450 Atan2 $y $x }; + case wgsl: __intrinsic_asm "atan2"; } } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector atan2(vector y, vector x) { __target_switch @@ -6673,6 +6708,7 @@ vector atan2(vector y, vector x) case spirv: return spirv_asm { OpExtInst $$vector result glsl450 Atan2 $y $x }; + case wgsl: __intrinsic_asm "atan2"; default: VECTOR_MAP_BINARY(T, N, atan2, y, x); } @@ -6680,7 +6716,7 @@ vector atan2(vector y, vector x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] matrix atan2(matrix y, matrix x) { __target_switch @@ -6696,7 +6732,7 @@ matrix atan2(matrix y, matrix x) __generic [__readNone] [ForceInline] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T atanh(T x) { __target_switch @@ -6708,6 +6744,7 @@ T atanh(T x) case spirv: return spirv_asm { OpExtInst $$T result glsl450 Atanh $x }; + case wgsl: __intrinsic_asm "atanh"; default: return T(0.5) * log((T(1) + x) / (T(1) - x)); } @@ -6716,7 +6753,7 @@ T atanh(T x) __generic [__readNone] [ForceInline] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector atanh(vector x) { __target_switch @@ -6726,6 +6763,7 @@ vector atanh(vector x) case spirv: return spirv_asm { OpExtInst $$vector result glsl450 Atanh $x }; + case wgsl: __intrinsic_asm "atanh"; default: VECTOR_MAP_UNARY(T, N, atanh, x); } @@ -6734,7 +6772,7 @@ vector atanh(vector x) // Ceiling (HLSL SM 1.0) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T ceil(T x) { __target_switch @@ -6747,12 +6785,13 @@ T ceil(T x) case spirv: return spirv_asm { OpExtInst $$T result glsl450 Ceil $x }; + case wgsl: __intrinsic_asm "ceil"; } } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector ceil(vector x) { __target_switch @@ -6763,6 +6802,7 @@ vector ceil(vector x) case spirv: return spirv_asm { OpExtInst $$vector result glsl450 Ceil $x }; + case wgsl: __intrinsic_asm "ceil"; default: VECTOR_MAP_UNARY(T, N, ceil, x); } @@ -6770,7 +6810,7 @@ vector 
ceil(vector x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] matrix ceil(matrix x) { __target_switch @@ -6877,7 +6917,7 @@ bool CheckAccessFullyMapped(out uint status) // Clamp (HLSL SM 1.0) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T clamp(T x, T minBound, T maxBound) { __target_switch @@ -6894,6 +6934,7 @@ T clamp(T x, T minBound, T maxBound) return spirv_asm { result:$$T = OpExtInst glsl450 UClamp $x $minBound $maxBound }; + case wgsl: __intrinsic_asm "clamp"; default: return min(max(x, minBound), maxBound); } @@ -6901,7 +6942,7 @@ T clamp(T x, T minBound, T maxBound) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector clamp(vector x, vector minBound, vector maxBound) { __target_switch @@ -6918,6 +6959,7 @@ vector clamp(vector x, vector minBound, vector maxBound) return spirv_asm { result:$$vector = OpExtInst glsl450 UClamp $x $minBound $maxBound }; + case wgsl: __intrinsic_asm "clamp"; default: return min(max(x, minBound), maxBound); } @@ -6925,7 +6967,7 @@ vector clamp(vector x, vector minBound, vector maxBound) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] matrix clamp(matrix x, matrix minBound, matrix maxBound) { __target_switch @@ -6938,7 +6980,7 @@ matrix clamp(matrix x, matrix minBound, matrix maxBo __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T clamp(T x, T minBound, T maxBound) { __target_switch @@ -6949,6 +6991,7 @@ T clamp(T x, T minBound, T maxBound) case spirv: return spirv_asm { result:$$T = OpExtInst glsl450 FClamp $x $minBound $maxBound }; + case wgsl: __intrinsic_asm "clamp"; default: return min(max(x, minBound), maxBound); } @@ -6956,7 +6999,7 @@ T clamp(T x, T minBound, T maxBound) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector clamp(vector x, vector minBound, vector maxBound) { __target_switch @@ -6967,6 +7010,7 @@ vector clamp(vector x, vector minBound, vector maxBound) case spirv: return spirv_asm { result:$$vector = OpExtInst glsl450 FClamp $x $minBound $maxBound }; + case wgsl: __intrinsic_asm "clamp"; default: return min(max(x, minBound), maxBound); } @@ -6974,7 +7018,7 @@ vector clamp(vector x, vector minBound, vector maxBound) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] matrix clamp(matrix x, matrix minBound, matrix maxBound) { __target_switch @@ -7025,7 +7069,7 @@ void clip(matrix x) // Cosine __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T cos(T x) { __target_switch @@ -7038,12 +7082,13 @@ T cos(T x) case spirv: return spirv_asm { OpExtInst $$T result glsl450 Cos $x }; + case wgsl: __intrinsic_asm "cos"; } } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector cos(vector x) { __target_switch @@ -7054,6 +7099,7 @@ vector cos(vector x) case spirv: return spirv_asm 
{ OpExtInst $$vector result glsl450 Cos $x }; + case wgsl: __intrinsic_asm "cos"; default: VECTOR_MAP_UNARY(T,N, cos, x); } @@ -7061,7 +7107,7 @@ vector cos(vector x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] matrix cos(matrix x) { __target_switch @@ -7075,7 +7121,7 @@ matrix cos(matrix x) // Hyperbolic cosine __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T cosh(T x) { __target_switch @@ -7088,12 +7134,13 @@ T cosh(T x) case spirv: return spirv_asm { OpExtInst $$T result glsl450 Cosh $x }; + case wgsl: __intrinsic_asm "cosh"; } } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector cosh(vector x) { __target_switch @@ -7104,6 +7151,7 @@ vector cosh(vector x) case spirv: return spirv_asm { OpExtInst $$vector result glsl450 Cosh $x }; + case wgsl: __intrinsic_asm "cosh"; default: VECTOR_MAP_UNARY(T,N, cosh, x); } @@ -7111,7 +7159,7 @@ vector cosh(vector x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] matrix cosh(matrix x) { __target_switch @@ -7126,7 +7174,7 @@ matrix cosh(matrix x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T cospi(T x) { __target_switch @@ -7139,7 +7187,7 @@ T cospi(T x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector cospi(vector x) { __target_switch @@ -7154,7 +7202,7 @@ vector cospi(vector x) // Population count [__readNone] [ForceInline] -[require(cpp_cuda_glsl_hlsl_metal_spirv, shader5_sm_5_0)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_5_0)] uint countbits(uint value) { __target_switch @@ -7170,13 +7218,15 @@ uint countbits(uint value) __intrinsic_asm "$P_countbits($0)"; case spirv: return spirv_asm {OpBitCount $$uint result $value}; + case wgsl: + __intrinsic_asm "countOneBits"; } } __generic [__readNone] [ForceInline] -[require(cpp_cuda_glsl_hlsl_metal_spirv, shader5_sm_5_0)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_5_0)] vector countbits(vector value) { __target_switch @@ -7189,6 +7239,8 @@ vector countbits(vector value) __intrinsic_asm "popcount"; case spirv: return spirv_asm {OpBitCount $$vector result $value}; + case wgsl: + __intrinsic_asm "countOneBits"; default: VECTOR_MAP_UNARY(uint, N, countbits, value); } @@ -7198,7 +7250,7 @@ vector countbits(vector value) // TODO: SPIRV does not support integer vectors. 
__generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector cross(vector left, vector right) { __target_switch @@ -7209,6 +7261,7 @@ vector cross(vector left, vector right) case spirv: return spirv_asm { OpExtInst $$vector result glsl450 Cross $left $right }; + case wgsl: __intrinsic_asm "cross"; default: return vector( left.y * right.z - left.z * right.y, @@ -7219,7 +7272,7 @@ vector cross(vector left, vector right) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector cross(vector left, vector right) { __target_switch @@ -7229,6 +7282,7 @@ vector cross(vector left, vector right) case spirv: return spirv_asm { OpExtInst $$vector result glsl450 Cross $left $right }; + case wgsl: __intrinsic_asm "cross"; default: return vector( left.y * right.z - left.z * right.y, @@ -7239,12 +7293,13 @@ vector cross(vector left, vector right) // Convert encoded color [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] int4 D3DCOLORtoUBYTE4(float4 color) { __target_switch { case hlsl: __intrinsic_asm "D3DCOLORtoUBYTE4"; + case wgsl: __intrinsic_asm "bitcast(pack4x8unorm($0)).zyxw"; default: let scaled = color.zyxw * 255.001999f; return int4(scaled); @@ -7258,7 +7313,7 @@ for (auto xOrY : diffDimensions) { }}}} __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, fragmentprocessing)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, fragmentprocessing)] T dd$(xOrY)(T x) { __requireComputeDerivative(); @@ -7274,12 +7329,14 @@ T dd$(xOrY)(T x) __intrinsic_asm "dfd$(xOrY)"; case spirv: return spirv_asm {OpDPd$(xOrY) $$T result $x}; + case wgsl: + __intrinsic_asm "dpd$(xOrY)"; } } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, fragmentprocessing)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, fragmentprocessing)] vector dd$(xOrY)(vector x) { __requireComputeDerivative(); @@ -7295,12 +7352,14 @@ vector dd$(xOrY)(vector x) __intrinsic_asm "dfd$(xOrY)"; case spirv: return spirv_asm {OpDPd$(xOrY) $$vector result $x}; + case wgsl: + __intrinsic_asm "dpd$(xOrY)"; } } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, fragmentprocessing)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, fragmentprocessing)] matrix dd$(xOrY)(matrix x) { __requireComputeDerivative(); @@ -7412,7 +7471,7 @@ ${{{{ __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] T degrees(T x) { __target_switch @@ -7422,6 +7481,7 @@ T degrees(T x) case spirv: return spirv_asm { OpExtInst $$T result glsl450 Degrees $x }; + case wgsl: __intrinsic_asm "degrees"; default: return x * (T(180) / T.getPi()); } @@ -7429,7 +7489,7 @@ T degrees(T x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] vector degrees(vector x) { __target_switch @@ -7439,6 +7499,7 @@ vector degrees(vector x) case spirv: return spirv_asm { OpExtInst $$vector result glsl450 Degrees $x }; + case wgsl: __intrinsic_asm "degrees"; default: VECTOR_MAP_UNARY(T, N, degrees, x); } @@ -7446,7 +7507,7 @@ vector degrees(vector x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] matrix degrees(matrix x) { __target_switch @@ -7462,7 +7523,7 @@ matrix degrees(matrix x) __generic [__readNone] 
[PreferCheckpoint] -[require(glsl_hlsl_metal_spirv)] +[require(glsl_hlsl_metal_spirv_wgsl)] T determinant(matrix m) { __target_switch @@ -7473,6 +7534,7 @@ T determinant(matrix m) case spirv: return spirv_asm { OpExtInst $$T result glsl450 Determinant $m }; + case wgsl: __intrinsic_asm "determinant"; } } @@ -7515,7 +7577,7 @@ void DeviceMemoryBarrierWithGroupSync() __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T distance(vector x, vector y) { __target_switch @@ -7526,6 +7588,7 @@ T distance(vector x, vector y) case spirv: return spirv_asm { OpExtInst $$T result glsl450 Distance $x $y }; + case wgsl: __intrinsic_asm "distance"; default: return length(x - y); } @@ -7533,7 +7596,7 @@ T distance(vector x, vector y) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T distance(T x, T y) { __target_switch @@ -7542,6 +7605,7 @@ T distance(T x, T y) case spirv: return spirv_asm { OpExtInst $$T result glsl450 Distance $x $y }; + case wgsl: __intrinsic_asm "distance"; default: return length(x - y); } @@ -7609,13 +7673,14 @@ vector divide(vector x, vector y) __generic [__readNone] [ForceInline] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T dot(T x, T y) { __target_switch { case glsl: __intrinsic_asm "dot"; case hlsl: __intrinsic_asm "dot"; + case wgsl: __intrinsic_asm "dot"; default: return x * y; } @@ -7623,7 +7688,7 @@ T dot(T x, T y) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T dot(vector x, vector y) { __target_switch @@ -7634,6 +7699,7 @@ T dot(vector x, vector y) case spirv: return spirv_asm { OpDot $$T result $x $y }; + case wgsl: __intrinsic_asm "dot"; default: T result = T(0); for(int i = 0; i < N; ++i) @@ -7644,12 +7710,13 @@ T dot(vector x, vector y) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T dot(vector x, vector y) { __target_switch { case hlsl: __intrinsic_asm "dot"; + case wgsl: __intrinsic_asm "dot"; default: T result = T(0); for(int i = 0; i < N; ++i) @@ -7834,7 +7901,7 @@ matrix EvaluateAttributeSnapped(matrix x, int2 offset) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T exp(T x) { __target_switch @@ -7847,12 +7914,13 @@ T exp(T x) case spirv: return spirv_asm { OpExtInst $$T result glsl450 Exp $x }; + case wgsl: __intrinsic_asm "exp"; } } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector exp(vector x) { __target_switch @@ -7863,6 +7931,7 @@ vector exp(vector x) case spirv: return spirv_asm { OpExtInst $$vector result glsl450 Exp $x }; + case wgsl: __intrinsic_asm "exp"; default: VECTOR_MAP_UNARY(T, N, exp, x); } @@ -7870,7 +7939,7 @@ vector exp(vector x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] matrix exp(matrix x) { __target_switch @@ -7885,7 +7954,7 @@ matrix exp(matrix x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, 
sm_4_0_version)] T exp2(T x) { __target_switch @@ -7911,13 +7980,15 @@ T exp2(T x) __intrinsic_asm "$P_exp2($0)"; case cuda: __intrinsic_asm "$P_exp2($0)"; + case wgsl: + __intrinsic_asm "exp2"; } } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector exp2(vector x) { __target_switch @@ -7929,6 +8000,7 @@ vector exp2(vector x) case spirv: return spirv_asm { OpExtInst $$vector result glsl450 Exp2 $x }; + case wgsl: __intrinsic_asm "exp2"; default: VECTOR_MAP_UNARY(T, N, exp2, x); } @@ -7936,7 +8008,7 @@ vector exp2(vector x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] matrix exp2(matrix x) { __target_switch @@ -7951,7 +8023,7 @@ matrix exp2(matrix x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T exp10(T x) { __target_switch @@ -7965,7 +8037,7 @@ T exp10(T x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector exp10(vector x) { __target_switch @@ -7982,7 +8054,7 @@ vector exp10(vector x) __glsl_version(420) __cuda_sm_version(6.0) [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, shader5_sm_5_0)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_5_0)] float f16tof32(uint value) { __target_switch @@ -8000,12 +8072,13 @@ float f16tof32(uint value) result:$$float = OpFConvert %half }; } + case wgsl: __intrinsic_asm "unpack2x16float($0).x"; } } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, shader5_sm_5_0)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_5_0)] vector f16tof32(vector value) { __target_switch @@ -8030,7 +8103,7 @@ vector f16tof32(vector value) __glsl_version(420) __cuda_sm_version(6.0) [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, shader5_sm_5_0)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_5_0)] uint f32tof16(float value) { __target_switch @@ -8048,12 +8121,13 @@ uint f32tof16(float value) result:$$uint = OpUConvert %lowBits }; } + case wgsl: __intrinsic_asm "pack2x16float(vec2f($0,0.0))"; } } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, shader5_sm_5_0)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_5_0)] vector f32tof16(vector value) { __target_switch @@ -8079,7 +8153,7 @@ vector f32tof16(vector value) __glsl_version(420) [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, shader5_sm_5_0)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_5_0)] float f16tof32(float16_t value) { __target_switch @@ -8095,12 +8169,13 @@ float f16tof32(float16_t value) result:$$float = OpFConvert $value }; } + case wgsl: __intrinsic_asm "f32($0)"; } } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, shader5_sm_5_0)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_5_0)] vector f16tof32(vector value) { __target_switch @@ -8119,7 +8194,7 @@ vector f16tof32(vector value) // Convert to float16_t __glsl_version(420) [__readNone] -[require(cuda_glsl_metal_spirv, shader5_sm_5_0)] +[require(cuda_glsl_metal_spirv_wgsl, shader5_sm_5_0)] float16_t f32tof16_(float value) { __target_switch @@ -8130,12 +8205,13 @@ float16_t f32tof16_(float value) case spirv: return spirv_asm { OpFConvert $$float16_t result $value }; + case wgsl: __intrinsic_asm "f16($0)"; } } __generic [__readNone] 
-[require(cuda_glsl_metal_spirv, shader5_sm_5_0)] +[require(cuda_glsl_metal_spirv_wgsl, shader5_sm_5_0)] vector f32tof16_(vector value) { __target_switch @@ -8155,7 +8231,7 @@ vector f32tof16_(vector value) // Flip surface normal to face forward, if needed __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector faceforward(vector n, vector i, vector ng) { __target_switch @@ -8166,6 +8242,7 @@ vector faceforward(vector n, vector i, vector ng) case spirv: return spirv_asm { OpExtInst $$vector result glsl450 FaceForward $n $i $ng }; + case wgsl: __intrinsic_asm "faceForward"; default: return dot(ng, i) < T(0.0f) ? n : -n; } @@ -8173,7 +8250,7 @@ vector faceforward(vector n, vector i, vector ng) // Find first set bit starting at high bit and working down [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, shader5_sm_5_0)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_5_0)] int firstbithigh(int value) { __target_switch @@ -8186,12 +8263,13 @@ int firstbithigh(int value) case spirv: return spirv_asm { OpExtInst $$int result glsl450 FindSMsb $value }; + case wgsl: __intrinsic_asm "firstLeadingBit"; } } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, shader5_sm_5_0)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_5_0)] vector firstbithigh(vector value) { __target_switch @@ -8202,13 +8280,14 @@ vector firstbithigh(vector value) case spirv: return spirv_asm { OpExtInst $$vector result glsl450 FindSMsb $value }; + case wgsl: __intrinsic_asm "firstLeadingBit"; default: VECTOR_MAP_UNARY(int, N, firstbithigh, value); } } [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, shader5_sm_5_0)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_5_0)] uint firstbithigh(uint value) { __target_switch @@ -8221,12 +8300,13 @@ uint firstbithigh(uint value) case spirv: return spirv_asm { OpExtInst $$uint result glsl450 FindUMsb $value }; + case wgsl: __intrinsic_asm "firstLeadingBit"; } } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, shader5_sm_5_0)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_5_0)] vector firstbithigh(vector value) { __target_switch @@ -8237,6 +8317,7 @@ vector firstbithigh(vector value) case spirv: return spirv_asm { OpExtInst $$vector result glsl450 FindUMsb $value }; + case wgsl: __intrinsic_asm "firstLeadingBit"; default: VECTOR_MAP_UNARY(uint, N, firstbithigh, value); } @@ -8244,7 +8325,7 @@ vector firstbithigh(vector value) // Find first set bit starting at low bit and working up [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, shader5_sm_5_0)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_5_0)] int firstbitlow(int value) { __target_switch @@ -8257,12 +8338,13 @@ int firstbitlow(int value) case spirv: return spirv_asm { OpExtInst $$int result glsl450 FindILsb $value }; + case wgsl: __intrinsic_asm "firstTrailingBit"; } } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, shader5_sm_5_0)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_5_0)] vector firstbitlow(vector value) { __target_switch @@ -8273,13 +8355,14 @@ vector firstbitlow(vector value) case spirv: return spirv_asm { OpExtInst $$vector result glsl450 FindILsb $value }; + case wgsl: __intrinsic_asm "firstTrailingBit"; default: VECTOR_MAP_UNARY(int, N, firstbitlow, value); } } [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, shader5_sm_5_0)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_5_0)] 
uint firstbitlow(uint value) { __target_switch @@ -8292,12 +8375,13 @@ uint firstbitlow(uint value) case spirv: return spirv_asm { OpExtInst $$uint result glsl450 FindILsb $value }; + case wgsl: __intrinsic_asm "firstTrailingBit"; } } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, shader5_sm_5_0)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_5_0)] vector firstbitlow(vector value) { __target_switch @@ -8308,6 +8392,7 @@ vector firstbitlow(vector value) case spirv: return spirv_asm { OpExtInst $$vector result glsl450 FindILsb $value }; + case wgsl: __intrinsic_asm "firstTrailingBit"; default: VECTOR_MAP_UNARY(uint, N, firstbitlow, value); } @@ -8317,7 +8402,7 @@ vector firstbitlow(vector value) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T floor(T x) { __target_switch @@ -8330,12 +8415,13 @@ T floor(T x) case spirv: return spirv_asm { OpExtInst $$T result glsl450 Floor $x }; + case wgsl: __intrinsic_asm "floor"; } } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector floor(vector x) { __target_switch @@ -8346,6 +8432,7 @@ vector floor(vector x) case spirv: return spirv_asm { OpExtInst $$vector result glsl450 Floor $x }; + case wgsl: __intrinsic_asm "floor"; default: VECTOR_MAP_UNARY(T, N, floor, x); } @@ -8353,7 +8440,7 @@ vector floor(vector x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] matrix floor(matrix x) { __target_switch @@ -8367,7 +8454,7 @@ matrix floor(matrix x) // Fused multiply-add __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, shader5_sm_5_0)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_5_0)] T fma(T a, T b, T c) { __target_switch @@ -8384,6 +8471,7 @@ T fma(T a, T b, T c) case spirv: return spirv_asm { OpExtInst $$T result glsl450 Fma $a $b $c }; + case wgsl: __intrinsic_asm "fma"; default: return a*b + c; } @@ -8391,7 +8479,7 @@ T fma(T a, T b, T c) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, shader5_sm_5_0)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_5_0)] vector fma(vector a, vector b, vector c) { __target_switch @@ -8402,6 +8490,7 @@ vector fma(vector a, vector b, vector c) case spirv: return spirv_asm { OpExtInst $$vector result glsl450 Fma $a $b $c }; + case wgsl: __intrinsic_asm "fma"; default: VECTOR_MAP_TRINARY(T, N, fma, a, b, c); } @@ -8409,7 +8498,7 @@ vector fma(vector a, vector b, vector c) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, shader5_sm_5_0)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_5_0)] matrix fma(matrix a, matrix b, matrix c) { __target_switch @@ -8424,7 +8513,7 @@ matrix fma(matrix a, matrix b, matrix c) __generic [__readNone] [ForceInline] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T fmod(T x, T y) { // In HLSL, `fmod` returns a remainder. 
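(The hunks in this stretch all follow one recipe: widen the `[require(...)]` set to the matching `*_wgsl` alias and add a `case wgsl:` naming the corresponding WGSL builtin — `firstbithigh` -> `firstLeadingBit`, `firstbitlow` -> `firstTrailingBit`, `floor` -> `floor`, `fma` -> `fma`. A hypothetical usage sketch, not part of this patch, of Slang code that becomes compilable for the wgsl target once those cases exist:

    // Hypothetical helper, assumes mask != 0; the two intrinsics below lower to
    // WGSL's firstLeadingBit and firstTrailingBit via the cases added above.
    uint occupiedBitSpan(uint mask)
    {
        uint hi = firstbithigh(mask);
        uint lo = firstbitlow(mask);
        return (hi - lo) + 1;
    }
)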
@@ -8489,13 +8578,15 @@ T fmod(T x, T y) { result:$$T = OpFRem $x $y }; + case wgsl: + __intrinsic_asm "(($0) % ($1))"; } } __generic [__readNone] [ForceInline] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector fmod(vector x, vector y) { __target_switch @@ -8512,7 +8603,7 @@ vector fmod(vector x, vector y) __generic [__readNone] [ForceInline] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] matrix fmod(matrix x, matrix y) { __target_switch @@ -8526,7 +8617,7 @@ matrix fmod(matrix x, matrix y) // Fractional part __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T frac(T x) { __target_switch @@ -8539,12 +8630,13 @@ T frac(T x) case spirv: return spirv_asm { OpExtInst $$T result glsl450 Fract $x }; + case wgsl: __intrinsic_asm "fract"; } } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector frac(vector x) { __target_switch @@ -8555,6 +8647,7 @@ vector frac(vector x) case spirv: return spirv_asm { OpExtInst $$vector result glsl450 Fract $x }; + case wgsl: __intrinsic_asm "fract"; default: VECTOR_MAP_UNARY(T, N, frac, x); } @@ -8570,7 +8663,7 @@ matrix frac(matrix x) __generic [__readNone] [ForceInline] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T fract(T x) { return frac(x); @@ -8579,7 +8672,7 @@ T fract(T x) __generic [__readNone] [ForceInline] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector fract(vector x) { return frac(x); @@ -8589,7 +8682,7 @@ vector fract(vector x) // Split float into mantissa and exponent __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T frexp(T x, out int exp) { __target_switch @@ -8602,12 +8695,24 @@ T frexp(T x, out int exp) case spirv: return spirv_asm { result:$$T = OpExtInst glsl450 Frexp $x &exp }; + case wgsl: + T fract; + __wgsl_frexp(x, fract, exp); + return fract; } } +__generic +[__readNone] +[require(wgsl)] +void __wgsl_frexp(T x, out T fract, out int exp) +{ + __intrinsic_asm "{ var s = frexp($0); $1 = s.fract; $2 = s.exp; }"; +} + __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector frexp(vector x, out vector exp) { __target_switch @@ -8625,7 +8730,7 @@ vector frexp(vector x, out vector exp) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] matrix frexp(matrix x, out matrix exp) { __target_switch @@ -8639,7 +8744,7 @@ matrix frexp(matrix x, out matrix exp) // Texture filter width __generic [__readNone] -[require(glsl_hlsl_metal_spirv, fragmentprocessing)] +[require(glsl_hlsl_metal_spirv_wgsl, fragmentprocessing)] T fwidth(T x) { __requireComputeDerivative(); @@ -8656,12 +8761,14 @@ T fwidth(T x) { OpFwidth $$T result $x; }; + case wgsl: + __intrinsic_asm "fwidth($0)"; } } __generic [__readNone] -[require(glsl_hlsl_spirv, fragmentprocessing)] +[require(glsl_hlsl_spirv_wgsl, fragmentprocessing)] vector fwidth(vector x) { __requireComputeDerivative(); @@ -8676,6 
+8783,8 @@ vector fwidth(vector x) { OpFwidth $$vector result $x; }; + case wgsl: + __intrinsic_asm "fwidth($0)"; } } @@ -9986,12 +10095,13 @@ matrix isnan(matrix x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T ldexp(T x, T exp) { __target_switch { case hlsl: __intrinsic_asm "ldexp"; + case wgsl: __intrinsic_asm "ldexp"; default: return x * exp2(exp); } @@ -9999,12 +10109,13 @@ T ldexp(T x, T exp) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector ldexp(vector x, vector exp) { __target_switch { case hlsl: __intrinsic_asm "ldexp"; + case wgsl: __intrinsic_asm "ldexp"; default: return x * exp2(exp); } @@ -10012,7 +10123,7 @@ vector ldexp(vector x, vector exp) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] matrix ldexp(matrix x, matrix exp) { __target_switch @@ -10025,7 +10136,7 @@ matrix ldexp(matrix x, matrix exp) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T ldexp(T x, E exp) { __target_switch @@ -10036,6 +10147,7 @@ T ldexp(T x, E exp) case spirv: return spirv_asm { OpExtInst $$T result glsl450 Ldexp $x $exp }; + case wgsl: __intrinsic_asm "ldexp"; default: return ldexp(x, __realCast(exp)); } @@ -10043,7 +10155,7 @@ T ldexp(T x, E exp) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector ldexp(vector x, vector exp) { __target_switch @@ -10054,6 +10166,7 @@ vector ldexp(vector x, vector exp) case spirv: return spirv_asm { OpExtInst $$vector result glsl450 Ldexp $x $exp }; + case wgsl: __intrinsic_asm "ldexp"; default: vector temp; [ForceUnroll] @@ -10067,7 +10180,7 @@ vector ldexp(vector x, vector exp) // Vector length __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T length(vector x) { __target_switch @@ -10078,6 +10191,7 @@ T length(vector x) case spirv: return spirv_asm { OpExtInst $$T result glsl450 Length $x }; + case wgsl: __intrinsic_asm "length"; default: return sqrt(dot(x, x)); } @@ -10085,7 +10199,7 @@ T length(vector x) // Scalar float length __generic -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T length(T x) { __target_switch @@ -10094,6 +10208,7 @@ T length(T x) case spirv: return spirv_asm { OpExtInst $$T result glsl450 Length $x }; + case wgsl: __intrinsic_asm "length"; default: return abs(x); } @@ -10168,7 +10283,7 @@ float4 lit(float n_dot_l, float n_dot_h, float m) // Base-e logarithm __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T log(T x) { __target_switch @@ -10181,12 +10296,13 @@ T log(T x) case spirv: return spirv_asm { OpExtInst $$T result glsl450 Log $x }; + case wgsl: __intrinsic_asm "log"; } } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector log(vector x) { __target_switch @@ -10197,6 +10313,7 @@ vector log(vector x) case spirv: return spirv_asm { OpExtInst $$vector result glsl450 Log $x }; + 
case wgsl: __intrinsic_asm "log"; default: VECTOR_MAP_UNARY(T, N, log, x); } @@ -10204,7 +10321,7 @@ vector log(vector x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] matrix log(matrix x) { __target_switch @@ -10278,7 +10395,7 @@ matrix log10(matrix x) // Base-2 logarithm __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T log2(T x) { __target_switch @@ -10291,12 +10408,13 @@ T log2(T x) case spirv: return spirv_asm { OpExtInst $$T result glsl450 Log2 $x }; + case wgsl: __intrinsic_asm "log2"; } } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector log2(vector x) { __target_switch @@ -10307,6 +10425,7 @@ vector log2(vector x) case spirv: return spirv_asm { OpExtInst $$vector result glsl450 Log2 $x }; + case wgsl: __intrinsic_asm "log2"; default: VECTOR_MAP_UNARY(T, N, log2, x); } @@ -10314,7 +10433,7 @@ vector log2(vector x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] matrix log2(matrix x) { __target_switch @@ -10427,7 +10546,7 @@ matrix mad(matrix mvalue, matrix avalue, matrix [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T max(T x, T y) { // Note: a stdlib implementation of `max` (or `min`) will require splitting @@ -10457,12 +10576,13 @@ T max(T x, T y) }; } } + case wgsl: __intrinsic_asm "max"; } } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector max(vector x, vector y) { __target_switch @@ -10485,6 +10605,7 @@ vector max(vector x, vector y) }; } } + case wgsl: __intrinsic_asm "max"; default: VECTOR_MAP_BINARY(T, N, max, x, y); } @@ -10492,7 +10613,7 @@ vector max(vector x, vector y) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] matrix max(matrix x, matrix y) { __target_switch @@ -10505,7 +10626,7 @@ matrix max(matrix x, matrix y) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T max(T x, T y) { __target_switch @@ -10518,12 +10639,13 @@ T max(T x, T y) case spirv: return spirv_asm { result:$$T = OpExtInst glsl450 FMax $x $y }; + case wgsl: __intrinsic_asm "max"; } } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector max(vector x, vector y) { __target_switch @@ -10534,6 +10656,7 @@ vector max(vector x, vector y) case spirv: return spirv_asm { result:$$vector = OpExtInst glsl450 FMax $x $y }; + case wgsl: __intrinsic_asm "max"; default: VECTOR_MAP_BINARY(T, N, max, x, y); } @@ -10541,7 +10664,7 @@ vector max(vector x, vector y) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] matrix max(matrix x, matrix y) { __target_switch @@ -10656,7 +10779,7 @@ vector fmax3(vector x, vector y, vector z) // minimum __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] 
+[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T min(T x, T y) { __target_switch @@ -10679,12 +10802,13 @@ T min(T x, T y) result:$$T = OpExtInst glsl450 UMin $x $y }; } + case wgsl: __intrinsic_asm "min"; } } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector min(vector x, vector y) { __target_switch @@ -10703,6 +10827,7 @@ vector min(vector x, vector y) result:$$vector = OpExtInst glsl450 UMin $x $y }; } + case wgsl: __intrinsic_asm "min"; default: VECTOR_MAP_BINARY(T, N, min, x, y); } @@ -10710,7 +10835,7 @@ vector min(vector x, vector y) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] matrix min(matrix x, matrix y) { __target_switch @@ -10723,7 +10848,7 @@ matrix min(matrix x, matrix y) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T min(T x, T y) { __target_switch @@ -10736,12 +10861,13 @@ T min(T x, T y) case spirv: return spirv_asm { result:$$T = OpExtInst glsl450 FMin $x $y }; + case wgsl: __intrinsic_asm "min"; } } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector min(vector x, vector y) { __target_switch @@ -10752,6 +10878,7 @@ vector min(vector x, vector y) case spirv: return spirv_asm { result:$$vector = OpExtInst glsl450 FMin $x $y }; + case wgsl: __intrinsic_asm "min"; default: VECTOR_MAP_BINARY(T, N, min, x, y); } @@ -10759,7 +10886,7 @@ vector min(vector x, vector y) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] matrix min(matrix x, matrix y) { __target_switch @@ -10966,7 +11093,7 @@ vector fmedian3(vector x, vector y, vector z) // split into integer and fractional parts (both with same sign) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T modf(T x, out T ip) { __target_switch @@ -10979,12 +11106,24 @@ T modf(T x, out T ip) case spirv: return spirv_asm { result:$$T = OpExtInst glsl450 Modf $x &ip }; + case wgsl: + T fract; + __wgsl_modf(x, fract, ip); + return fract; } } +__generic +[__readNone] +[require(wgsl)] +void __wgsl_modf(T x, out T fract, out T whole) +{ + __intrinsic_asm "{ var s = modf($0); $1 = s.fract; $2 = s.whole; }"; +} + __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector modf(vector x, out vector ip) { __target_switch @@ -11002,7 +11141,7 @@ vector modf(vector x, out vector ip) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] matrix modf(matrix x, out matrix ip) { __target_switch @@ -11075,7 +11214,7 @@ matrix mul(T x, matrix y); // vector-vector (dot product) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T mul(vector x, vector y) { __target_switch @@ -11083,13 +11222,14 @@ T mul(vector x, vector y) case glsl: __intrinsic_asm "dot"; case metal: __intrinsic_asm "dot"; case hlsl: __intrinsic_asm "mul"; + case wgsl: __intrinsic_asm "dot"; default: return 
dot(x, y); } } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T mul(vector x, vector y) { __target_switch @@ -11103,7 +11243,7 @@ T mul(vector x, vector y) // vector-matrix __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector mul(vector left, matrix right) { __target_switch @@ -11114,6 +11254,7 @@ vector mul(vector left, matrix right) case spirv: return spirv_asm { OpMatrixTimesVector $$vector result $right $left }; + case wgsl: __intrinsic_asm "($1 * $0)"; default: vector result; for( int j = 0; j < M; ++j ) @@ -11130,7 +11271,7 @@ vector mul(vector left, matrix right) } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector mul(vector left, matrix right) { __target_switch @@ -11138,6 +11279,7 @@ vector mul(vector left, matrix right) case glsl: __intrinsic_asm "($1 * $0)"; case metal: __intrinsic_asm "($1 * $0)"; case hlsl: __intrinsic_asm "mul"; + case wgsl: __intrinsic_asm "($1 * $0)"; default: vector result; for( int j = 0; j < M; ++j ) @@ -11154,7 +11296,7 @@ vector mul(vector left, matrix right) } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector mul(vector left, matrix right) { __target_switch @@ -11162,6 +11304,7 @@ vector mul(vector left, matrix right) case glsl: __intrinsic_asm "($1 * $0)"; case metal: __intrinsic_asm "($1 * $0)"; case hlsl: __intrinsic_asm "mul"; + case wgsl: __intrinsic_asm "($1 * $0)"; default: vector result; for( int j = 0; j < M; ++j ) @@ -11180,7 +11323,7 @@ vector mul(vector left, matrix right) // matrix-vector __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector mul(matrix left, vector right) { __target_switch @@ -11191,6 +11334,7 @@ vector mul(matrix left, vector right) case spirv: return spirv_asm { OpVectorTimesMatrix $$vector result $right $left }; + case wgsl: __intrinsic_asm "($1 * $0)"; default: vector result; for( int i = 0; i < N; ++i ) @@ -11207,7 +11351,7 @@ vector mul(matrix left, vector right) } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector mul(matrix left, vector right) { __target_switch @@ -11215,6 +11359,7 @@ vector mul(matrix left, vector right) case glsl: __intrinsic_asm "($1 * $0)"; case metal: __intrinsic_asm "($1 * $0)"; case hlsl: __intrinsic_asm "mul"; + case wgsl: __intrinsic_asm "($1 * $0)"; default: vector result; for( int i = 0; i < N; ++i ) @@ -11232,7 +11377,7 @@ vector mul(matrix left, vector right) __generic [__readNone] [OverloadRank(-1)] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector mul(matrix left, vector right) { __target_switch @@ -11240,6 +11385,7 @@ vector mul(matrix left, vector right) case glsl: __intrinsic_asm "($1 * $0)"; case metal: __intrinsic_asm "($1 * $0)"; case hlsl: __intrinsic_asm "mul"; + case wgsl: __intrinsic_asm "($1 * $0)"; default: vector result; for( int i = 0; i < N; ++i ) @@ -11258,7 +11404,7 @@ vector mul(matrix left, vector right) // matrix-matrix __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, 
sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] matrix mul(matrix left, matrix right) { __target_switch @@ -11269,6 +11415,7 @@ matrix mul(matrix left, matrix right) case spirv: return spirv_asm { OpMatrixTimesMatrix $$matrix result $right $left }; + case wgsl: __intrinsic_asm "($1 * $0)"; default: matrix result; for( int r = 0; r < R; ++r) @@ -11286,7 +11433,7 @@ matrix mul(matrix left, matrix right) } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] matrix mul(matrix left, matrix right) { __target_switch @@ -11294,6 +11441,7 @@ matrix mul(matrix left, matrix right) case glsl: __intrinsic_asm "($1 * $0)"; case metal: __intrinsic_asm "($1 * $0)"; case hlsl: __intrinsic_asm "mul"; + case wgsl: __intrinsic_asm "($1 * $0)"; default: matrix result; for( int r = 0; r < R; ++r) @@ -11311,7 +11459,7 @@ matrix mul(matrix left, matrix right) } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] matrix mul(matrix left, matrix right) { __target_switch @@ -11319,6 +11467,7 @@ matrix mul(matrix left, matrix right) case glsl: __intrinsic_asm "($1 * $0)"; case metal: __intrinsic_asm "($1 * $0)"; case hlsl: __intrinsic_asm "mul"; + case wgsl: __intrinsic_asm "($1 * $0)"; default: matrix result; for( int r = 0; r < R; ++r) @@ -11442,7 +11591,7 @@ T NonUniformResourceIndex(T value) { return value; } // Normalize a vector __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector normalize(vector x) { __target_switch @@ -11453,6 +11602,7 @@ vector normalize(vector x) case spirv: return spirv_asm { OpExtInst $$vector result glsl450 Normalize $x }; + case wgsl: __intrinsic_asm "normalize"; default: return x / length(x); } @@ -11460,7 +11610,7 @@ vector normalize(vector x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T normalize(T x) { __target_switch @@ -11471,6 +11621,7 @@ T normalize(T x) case spirv: return spirv_asm { OpExtInst $$T result glsl450 Normalize $x }; + case wgsl: __intrinsic_asm "normalize"; default: return x / length(x); } @@ -11479,7 +11630,7 @@ T normalize(T x) // Raise to a power __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T pow(T x, T y) { __target_switch @@ -11492,12 +11643,13 @@ T pow(T x, T y) case spirv: return spirv_asm { OpExtInst $$T result glsl450 Pow $x $y }; + case wgsl: __intrinsic_asm "pow"; } } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector pow(vector x, vector y) { __target_switch @@ -11508,6 +11660,7 @@ vector pow(vector x, vector y) case spirv: return spirv_asm { OpExtInst $$vector result glsl450 Pow $x $y }; + case wgsl: __intrinsic_asm "pow"; default: VECTOR_MAP_BINARY(T, N, pow, x, y); } @@ -11515,7 +11668,7 @@ vector pow(vector x, vector y) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] matrix pow(matrix x, matrix y) { __target_switch @@ -11528,7 +11681,7 @@ matrix pow(matrix x, matrix y) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, 
sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T powr(T x, T y) { __target_switch @@ -11541,7 +11694,7 @@ T powr(T x, T y) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector powr(vector x, vector y) { __target_switch @@ -11687,7 +11840,7 @@ void ProcessTriTessFactorsMin( // Degrees to radians __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T radians(T x) { __target_switch @@ -11697,6 +11850,7 @@ T radians(T x) case spirv: return spirv_asm { OpExtInst $$T result glsl450 Radians $x }; + case wgsl: __intrinsic_asm "radians"; default: return x * (T.getPi() / T(180.0f)); } @@ -11704,7 +11858,7 @@ T radians(T x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector radians(vector x) { __target_switch @@ -11714,6 +11868,7 @@ vector radians(vector x) case spirv: return spirv_asm { OpExtInst $$vector result glsl450 Radians $x }; + case wgsl: __intrinsic_asm "radians"; default: return x * (T.getPi() / T(180.0f)); } @@ -11721,7 +11876,7 @@ vector radians(vector x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] matrix radians(matrix x) { __target_switch @@ -11735,7 +11890,7 @@ matrix radians(matrix x) // Approximate reciprocal __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T rcp(T x) { __target_switch @@ -11748,7 +11903,7 @@ T rcp(T x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector rcp(vector x) { __target_switch @@ -11756,6 +11911,7 @@ vector rcp(vector x) case hlsl: __intrinsic_asm "rcp"; case glsl: case spirv: + case wgsl: return T(1.0) / x; default: VECTOR_MAP_UNARY(T, N, rcp, x); @@ -11764,7 +11920,7 @@ vector rcp(vector x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] matrix rcp(matrix x) { __target_switch @@ -11778,7 +11934,7 @@ matrix rcp(matrix x) // Reflect incident vector across plane with given normal __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T reflect(T i, T n) { __target_switch @@ -11789,6 +11945,7 @@ T reflect(T i, T n) case spirv: return spirv_asm { OpExtInst $$T result glsl450 Reflect $i $n }; + case wgsl: __intrinsic_asm "reflect"; default: return i - T(2) * dot(n,i) * n; } @@ -11796,7 +11953,7 @@ T reflect(T i, T n) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector reflect(vector i, vector n) { __target_switch @@ -11807,6 +11964,7 @@ vector reflect(vector i, vector n) case spirv: return spirv_asm { OpExtInst $$vector result glsl450 Reflect $i $n }; + case wgsl: __intrinsic_asm "reflect"; default: return i - T(2) * dot(n,i) * n; } @@ -11815,7 +11973,7 @@ vector reflect(vector i, vector n) // Refract incident vector given surface normal and index of refraction __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] 
+[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector refract(vector i, vector n, T eta) { __target_switch @@ -11826,6 +11984,7 @@ vector refract(vector i, vector n, T eta) case spirv: return spirv_asm { OpExtInst $$vector result glsl450 Refract $i $n $eta }; + case wgsl: __intrinsic_asm "refract"; default: let dotNI = dot(n,i); let k = T(1) - eta*eta*(T(1) - dotNI * dotNI); @@ -11836,7 +11995,7 @@ vector refract(vector i, vector n, T eta) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T refract(T i, T n, T eta) { __target_switch @@ -11847,6 +12006,7 @@ T refract(T i, T n, T eta) case spirv: return spirv_asm { OpExtInst $$T result glsl450 Refract $i $n $eta }; + case wgsl: __intrinsic_asm "refract"; default: let dotNI = dot(n,i); let k = T(1) - eta*eta*(T(1) - dotNI * dotNI); @@ -11857,7 +12017,7 @@ T refract(T i, T n, T eta) // Reverse order of bits [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, shader5_sm_5_0)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_5_0)] uint reversebits(uint value) { __target_switch @@ -11873,12 +12033,13 @@ uint reversebits(uint value) __intrinsic_asm "reverse_bits"; case spirv: return spirv_asm {OpBitReverse $$uint result $value}; + case wgsl: __intrinsic_asm "reverseBits"; } } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, shader5_sm_5_0)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_5_0)] vector reversebits(vector value) { __target_switch @@ -11891,6 +12052,7 @@ vector reversebits(vector value) __intrinsic_asm "reverse_bits"; case spirv: return spirv_asm {OpBitReverse $$vector result $value}; + case wgsl: __intrinsic_asm "reverseBits"; } } @@ -11947,7 +12109,7 @@ vector rint(vector x) // Round-to-nearest __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T round(T x) { __target_switch @@ -11960,12 +12122,13 @@ T round(T x) case spirv: return spirv_asm { OpExtInst $$T result glsl450 Round $x }; + case wgsl: __intrinsic_asm "round"; } } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector round(vector x) { __target_switch @@ -11976,6 +12139,7 @@ vector round(vector x) case spirv: return spirv_asm { OpExtInst $$vector result glsl450 Round $x }; + case wgsl: __intrinsic_asm "round"; default: VECTOR_MAP_UNARY(T, N, round, x); } @@ -11983,7 +12147,7 @@ vector round(vector x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] matrix round(matrix x) { __target_switch @@ -11997,7 +12161,7 @@ matrix round(matrix x) // Reciprocal of square root __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T rsqrt(T x) { __target_switch @@ -12017,7 +12181,7 @@ T rsqrt(T x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector rsqrt(vector x) { __target_switch @@ -12035,7 +12199,7 @@ vector rsqrt(vector x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] matrix rsqrt(matrix x) { __target_switch @@ -12050,13 +12214,14 @@ matrix rsqrt(matrix x) 
__generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T saturate(T x) { __target_switch { case hlsl: __intrinsic_asm "saturate"; case metal: __intrinsic_asm "saturate"; + case wgsl: __intrinsic_asm "saturate"; default: return clamp(x, T(0), T(1)); } @@ -12064,13 +12229,14 @@ T saturate(T x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector saturate(vector x) { __target_switch { case hlsl: __intrinsic_asm "saturate"; case metal: __intrinsic_asm "saturate"; + case wgsl: __intrinsic_asm "saturate"; default: return clamp(x, vector(T(0)), @@ -12080,7 +12246,7 @@ vector saturate(vector x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] matrix saturate(matrix x) { __target_switch @@ -12121,6 +12287,7 @@ int sign(T x) }; else return __int_cast(spirv_asm {OpExtInst $$T result glsl450 SSign $x}); + case wgsl: __intrinsic_asm "sign"; } } @@ -12144,6 +12311,7 @@ vector sign(vector x) }; else return __int_cast(spirv_asm {OpExtInst $$vector result glsl450 SSign $x}); + case wgsl: __intrinsic_asm "sign"; default: VECTOR_MAP_UNARY(int, N, sign, x); } @@ -12151,7 +12319,7 @@ vector sign(vector x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_spirv_wgsl, sm_4_0_version)] matrix sign(matrix x) { __target_switch @@ -12166,7 +12334,7 @@ matrix sign(matrix x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T sin(T x) { __target_switch @@ -12179,12 +12347,13 @@ T sin(T x) case spirv: return spirv_asm { OpExtInst $$T result glsl450 Sin $x }; + case wgsl: __intrinsic_asm "sin"; } } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector sin(vector x) { __target_switch @@ -12195,6 +12364,7 @@ vector sin(vector x) case spirv: return spirv_asm { OpExtInst $$vector result glsl450 Sin $x }; + case wgsl: __intrinsic_asm "sin"; default: VECTOR_MAP_UNARY(T, N, sin, x); } @@ -12202,7 +12372,7 @@ vector sin(vector x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] matrix sin(matrix x) { __target_switch @@ -12239,7 +12409,7 @@ vector __sincos_metal(vector x, out vector c) __generic [__readNone] [ForceInline] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] void sincos(T x, out T s, out T c) { __target_switch @@ -12259,7 +12429,7 @@ void sincos(T x, out T s, out T c) __generic [__readNone] [ForceInline] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] void sincos(vector x, out vector s, out vector c) { __target_switch @@ -12278,7 +12448,7 @@ void sincos(vector x, out vector s, out vector c) __generic [__readNone] [ForceInline] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] void sincos(matrix x, out matrix s, out matrix c) { __target_switch @@ -12293,7 +12463,7 @@ void sincos(matrix x, out matrix s, out matrix c) // Hyperbolic Sine __generic [__readNone] 
-[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T sinh(T x) { __target_switch @@ -12306,12 +12476,13 @@ T sinh(T x) case spirv: return spirv_asm { OpExtInst $$T result glsl450 Sinh $x }; + case wgsl: __intrinsic_asm "sinh"; } } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector sinh(vector x) { __target_switch @@ -12322,6 +12493,7 @@ vector sinh(vector x) case spirv: return spirv_asm { OpExtInst $$vector result glsl450 Sinh $x }; + case wgsl: __intrinsic_asm "sinh"; default: VECTOR_MAP_UNARY(T, N, sinh, x); } @@ -12329,7 +12501,7 @@ vector sinh(vector x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] matrix sinh(matrix x) { __target_switch @@ -12344,7 +12516,7 @@ matrix sinh(matrix x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T sinpi(T x) { __target_switch @@ -12357,7 +12529,7 @@ T sinpi(T x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector sinpi(vector x) { __target_switch @@ -12372,7 +12544,7 @@ vector sinpi(vector x) // Smooth step (Hermite interpolation) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T smoothstep(T min, T max, T x) { __target_switch @@ -12383,6 +12555,7 @@ T smoothstep(T min, T max, T x) case spirv: return spirv_asm { OpExtInst $$T result glsl450 SmoothStep $min $max $x }; + case wgsl: __intrinsic_asm "smoothstep"; default: let t = saturate((x - min) / (max - min)); return t * t * (T(3.0f) - (t + t)); @@ -12391,7 +12564,7 @@ T smoothstep(T min, T max, T x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector smoothstep(vector min, vector max, vector x) { __target_switch @@ -12402,6 +12575,7 @@ vector smoothstep(vector min, vector max, vector x) case spirv: return spirv_asm { OpExtInst $$vector result glsl450 SmoothStep $min $max $x }; + case wgsl: __intrinsic_asm "smoothstep"; default: VECTOR_MAP_TRINARY(T, N, smoothstep, min, max, x); } @@ -12409,7 +12583,7 @@ vector smoothstep(vector min, vector max, vector x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] matrix smoothstep(matrix min, matrix max, matrix x) { __target_switch @@ -12423,7 +12597,7 @@ matrix smoothstep(matrix min, matrix max, matrix [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T sqrt(T x) { __target_switch @@ -12436,12 +12610,13 @@ T sqrt(T x) case spirv: return spirv_asm { OpExtInst $$T result glsl450 Sqrt $x }; + case wgsl: __intrinsic_asm "sqrt"; } } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector sqrt(vector x) { __target_switch @@ -12452,6 +12627,7 @@ vector sqrt(vector x) case spirv: return spirv_asm { OpExtInst $$vector result glsl450 Sqrt $x }; + case wgsl: __intrinsic_asm "sqrt"; default: VECTOR_MAP_UNARY(T, N, sqrt, x); } @@ -12459,7 
+12635,7 @@ vector sqrt(vector x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] matrix sqrt(matrix x) { __target_switch @@ -12473,7 +12649,7 @@ matrix sqrt(matrix x) // Step function __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T step(T y, T x) { __target_switch @@ -12484,6 +12660,7 @@ T step(T y, T x) case spirv: return spirv_asm { OpExtInst $$T result glsl450 Step $y $x }; + case wgsl: __intrinsic_asm "step"; default: return x < y ? T(0.0f) : T(1.0f); } @@ -12491,7 +12668,7 @@ T step(T y, T x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector step(vector y, vector x) { __target_switch @@ -12502,6 +12679,7 @@ vector step(vector y, vector x) case spirv: return spirv_asm { OpExtInst $$vector result glsl450 Step $y $x }; + case wgsl: __intrinsic_asm "step"; default: VECTOR_MAP_BINARY(T, N, step, y, x); } @@ -12509,7 +12687,7 @@ vector step(vector y, vector x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] matrix step(matrix y, matrix x) { __target_switch @@ -12523,7 +12701,7 @@ matrix step(matrix y, matrix x) // Tangent __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T tan(T x) { __target_switch @@ -12536,12 +12714,13 @@ T tan(T x) case spirv: return spirv_asm { OpExtInst $$T result glsl450 Tan $x }; + case wgsl: __intrinsic_asm "tan"; } } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector tan(vector x) { __target_switch @@ -12552,6 +12731,7 @@ vector tan(vector x) case spirv: return spirv_asm { OpExtInst $$vector result glsl450 Tan $x }; + case wgsl: __intrinsic_asm "tan"; default: VECTOR_MAP_UNARY(T, N, tan, x); } @@ -12559,7 +12739,7 @@ vector tan(vector x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] matrix tan(matrix x) { __target_switch @@ -12573,7 +12753,7 @@ matrix tan(matrix x) // Hyperbolic tangent __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T tanh(T x) { __target_switch @@ -12586,12 +12766,13 @@ T tanh(T x) case spirv: return spirv_asm { OpExtInst $$T result glsl450 Tanh $x }; + case wgsl: __intrinsic_asm "tanh"; } } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector tanh(vector x) { __target_switch @@ -12602,6 +12783,7 @@ vector tanh(vector x) case spirv: return spirv_asm { OpExtInst $$vector result glsl450 Tanh $x }; + case wgsl: __intrinsic_asm "tanh"; default: VECTOR_MAP_UNARY(T, N, tanh, x); } @@ -12609,7 +12791,7 @@ vector tanh(vector x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] matrix tanh(matrix x) { __target_switch @@ -12624,7 +12806,7 @@ matrix tanh(matrix x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] 
+[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T tanpi(T x) { __target_switch @@ -12637,7 +12819,7 @@ T tanpi(T x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector tanpi(vector x) { __target_switch @@ -12652,7 +12834,7 @@ vector tanpi(vector x) // Matrix transpose __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_spirv_wgsl, sm_4_0_version)] [PreferRecompute] matrix transpose(matrix x) { @@ -12663,6 +12845,7 @@ matrix transpose(matrix x) case spirv: return spirv_asm { OpTranspose $$matrix result $x }; + case wgsl: __intrinsic_asm "transpose"; default: matrix result; for(int r = 0; r < M; ++r) @@ -12673,7 +12856,7 @@ matrix transpose(matrix x) } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_spirv_wgsl, sm_4_0_version)] [PreferRecompute] matrix transpose(matrix x) { @@ -12684,6 +12867,7 @@ matrix transpose(matrix x) case spirv: return spirv_asm { OpTranspose $$matrix result $x }; + case wgsl: __intrinsic_asm "transpose"; default: matrix result; for (int r = 0; r < M; ++r) @@ -12694,7 +12878,7 @@ matrix transpose(matrix x) } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_spirv_wgsl, sm_4_0_version)] [PreferRecompute] [OverloadRank(-1)] matrix transpose(matrix x) @@ -12706,6 +12890,7 @@ matrix transpose(matrix x) case spirv: return spirv_asm { OpTranspose $$matrix result $x }; + case wgsl: __intrinsic_asm "transpose"; default: matrix result; for (int r = 0; r < M; ++r) @@ -12718,7 +12903,7 @@ matrix transpose(matrix x) // Truncate to integer __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T trunc(T x) { __target_switch @@ -12731,12 +12916,13 @@ T trunc(T x) case spirv: return spirv_asm { OpExtInst $$T result glsl450 Trunc $x }; + case wgsl: __intrinsic_asm "trunc"; } } __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector trunc(vector x) { __target_switch @@ -12747,6 +12933,7 @@ vector trunc(vector x) case spirv: return spirv_asm { OpExtInst $$vector result glsl450 Trunc $x }; + case wgsl: __intrinsic_asm "trunc"; default: VECTOR_MAP_UNARY(T, N, trunc, x); } @@ -12754,7 +12941,7 @@ vector trunc(vector x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] matrix trunc(matrix x) { __target_switch diff --git a/source/slang/slang-capabilities.capdef b/source/slang/slang-capabilities.capdef index a173a332f4..96f5996a0c 100644 --- a/source/slang/slang-capabilities.capdef +++ b/source/slang/slang-capabilities.capdef @@ -111,6 +111,10 @@ def metal : target + textualTarget; /// [Target] def spirv : target; +/// Represents the WebGPU shading language code generation target. +/// [Target] +def wgsl : target + textualTarget; + // Capabilities that stand for target SPIR-V versions for the GLSL backend. // These are not compilation targets. We will convert `_spirv_*` to `glsl_spirv_*` during compilation. 
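(Besides joining the compound aliases in the following hunks, the new `wgsl` atom can guard WGSL-only helpers directly. A sketch modeled on the `__wgsl_frexp`/`__wgsl_modf` wrappers added in hlsl.meta.slang above, which destructure WGSL's struct-returning builtins into `out` parameters; the generic clause is assumed, since this rendering of the diff drops text inside angle brackets:

    __generic<T : __BuiltinFloatingPointType>
    [__readNone]
    [require(wgsl)]
    void __wgsl_modf(T x, out T fract, out T whole)
    {
        // WGSL's modf returns a struct; split it into the two out-params.
        __intrinsic_asm "{ var s = modf($0); $1 = s.fract; $2 = s.whole; }";
    }
)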
@@ -228,15 +232,15 @@ def _cuda_sm_9_0 : _cuda_sm_8_0; /// All code-gen targets /// [Compound] -alias any_target = hlsl | metal | glsl | c | cpp | cuda | spirv; +alias any_target = hlsl | metal | glsl | c | cpp | cuda | spirv | wgsl; /// All non-asm code-gen targets /// [Compound] -alias any_textual_target = hlsl | metal | glsl | c | cpp | cuda; +alias any_textual_target = hlsl | metal | glsl | c | cpp | cuda | wgsl; /// All slang-gfx compatible code-gen targets /// [Compound] -alias any_gfx_target = hlsl | metal | glsl | spirv; +alias any_gfx_target = hlsl | metal | glsl | spirv | wgsl; /// All "cpp syntax" code-gen targets /// [Compound] @@ -262,10 +266,18 @@ alias cpp_cuda_glsl_hlsl = cpp | cuda | glsl | hlsl; /// [Compound] alias cpp_cuda_glsl_hlsl_spirv = cpp | cuda | glsl | hlsl | spirv; +/// CPP, CUDA, GLSL, HLSL, SPIRV and WGSL code-gen targets +/// [Compound] +alias cpp_cuda_glsl_hlsl_spirv_wgsl = cpp | cuda | glsl | hlsl | spirv | wgsl; + /// CPP, CUDA, GLSL, HLSL, Metal and SPIRV code-gen targets /// [Compound] alias cpp_cuda_glsl_hlsl_metal_spirv = cpp | cuda | glsl | hlsl | metal | spirv; +/// CPP, CUDA, GLSL, HLSL, Metal, SPIRV and WGSL code-gen targets +/// [Compound] +alias cpp_cuda_glsl_hlsl_metal_spirv_wgsl = cpp | cuda | glsl | hlsl | metal | spirv | wgsl; + /// CPP, CUDA, and HLSL code-gen targets /// [Compound] alias cpp_cuda_hlsl = cpp | cuda | hlsl; @@ -318,6 +330,10 @@ alias cuda_glsl_spirv = cuda | glsl | spirv; /// [Compound] alias cuda_glsl_metal_spirv = cuda | glsl | metal | spirv; +/// CUDA, GLSL, Metal, SPIRV and WGSL code-gen targets +/// [Compound] +alias cuda_glsl_metal_spirv_wgsl = cuda | glsl | metal | spirv | wgsl; + /// CUDA, and HLSL code-gen targets /// [Compound] alias cuda_hlsl = cuda | hlsl; @@ -330,10 +346,18 @@ alias cuda_hlsl_spirv = cuda | hlsl | spirv; /// [Compound] alias glsl_hlsl_spirv = glsl | hlsl | spirv; +/// GLSL, HLSL, SPIRV and WGSL code-gen targets +/// [Compound] +alias glsl_hlsl_spirv_wgsl = glsl | hlsl | spirv | wgsl; + /// GLSL, HLSL, Metal, and SPIRV code-gen targets /// [Compound] alias glsl_hlsl_metal_spirv = glsl | hlsl | metal | spirv; +/// GLSL, HLSL, Metal, SPIRV and WGSL code-gen targets +/// [Compound] +alias glsl_hlsl_metal_spirv_wgsl = glsl | hlsl | metal | spirv | wgsl; + /// GLSL, Metal, and SPIRV code-gen targets /// [Compound] alias glsl_metal_spirv = glsl | metal | spirv; @@ -1178,6 +1202,7 @@ alias sm_4_0_version = _sm_4_0 | spirv_1_0 | _cuda_sm_2_0 | metal + | wgsl | cpp ; @@ -1198,6 +1223,7 @@ alias sm_4_1_version = _sm_4_1 | spirv_1_0 | _cuda_sm_6_0 | metal + | wgsl | cpp ; /// HLSL shader model 4.1 and related capabilities of other targets. @@ -1217,6 +1243,7 @@ alias sm_5_0_version = _sm_5_0 | spirv_1_0 | _cuda_sm_9_0 | metal + | wgsl | cpp ; /// HLSL shader model 5.0 and related capabilities of other targets. 
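(With `wgsl` folded into both the compound target aliases and the `sm_*_version` aliases, the stdlib's existing requirement annotations remain satisfiable on the new target, and each builtin only needs its extra `case wgsl:`. For example, the scalar `tanh` declaration touched in the hlsl.meta.slang hunks above reads roughly as follows after the patch; the generic clause and the non-WGSL cases are filled in from the stdlib's usual pattern and are not shown verbatim in the diff:

    __generic<T : __BuiltinFloatingPointType>
    [__readNone]
    [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)]
    T tanh(T x)
    {
        __target_switch
        {
        case cpp:   __intrinsic_asm "$P_tanh($0)";
        case cuda:  __intrinsic_asm "$P_tanh($0)";
        case glsl:  __intrinsic_asm "tanh";
        case hlsl:  __intrinsic_asm "tanh";
        case metal: __intrinsic_asm "tanh";
        case spirv: return spirv_asm { OpExtInst $$T result glsl450 Tanh $x };
        case wgsl:  __intrinsic_asm "tanh";
        }
    }
)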
@@ -1686,6 +1713,7 @@ alias fragmentprocessing = fragment + _sm_5_0 | fragment + metal | fragment + cpp | fragment + cuda + | fragment + wgsl ; /// Capabilities required to use fragment derivative operations (with GLSL derivativecontrol) /// [Compound] diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp index 90b0e44f52..b213653387 100644 --- a/source/slang/slang-check-constraint.cpp +++ b/source/slang/slang-check-constraint.cpp @@ -779,8 +779,8 @@ namespace Slang SLANG_ASSERT(constraintDecl2); return TryUnifyTypes(constraints, unifyCtx, - constraintDecl1.getDecl()->getSup().type, - constraintDecl2.getDecl()->getSup().type); + getSup(m_astBuilder, constraintDecl1), + getSup(m_astBuilder, constraintDecl2)); } } diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 42a7a60e04..e56082aab9 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -598,7 +598,7 @@ namespace Slang }; /// Shared state for a semantics-checking session. - struct SharedSemanticsContext + struct SharedSemanticsContext : public RefObject { Linkage* m_linkage = nullptr; diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 2c1efb067b..9b0b56ee24 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -1,4 +1,5 @@ // slang-check-overload.cpp +#include "slang-ast-base.h" #include "slang-check-impl.h" #include "slang-lookup.h" @@ -1199,6 +1200,30 @@ namespace Slang return parent; } + void countDistanceToGloablScope(DeclRef const& leftDecl, + DeclRef const& rightDecl, + int& leftDistance, int& rightDistance) + { + leftDistance = 0; + rightDistance = 0; + + DeclRef decl = leftDecl; + while(decl) + { + leftDistance++; + decl = decl.getParent(); + } + + decl = rightDecl; + while(decl) + { + rightDistance++; + decl = decl.getParent(); + } + } + + // Returns -1 if left is preferred, 1 if right is preferred, and 0 if they are equal. + // int SemanticsVisitor::CompareLookupResultItems( LookupResultItem const& left, LookupResultItem const& right) @@ -1283,6 +1308,63 @@ namespace Slang return -1; } + // If both are subscript decls, prefer the one that provides more + // accessors. 
+ if (auto leftSubscriptDecl = left.declRef.as()) + { + if (auto rightSubscriptDecl = right.declRef.as()) + { + auto leftAccessorCount = leftSubscriptDecl.getDecl()->getMembersOfType().getCount(); + auto rightAccessorCount = rightSubscriptDecl.getDecl()->getMembersOfType().getCount(); + auto decl1IsSubsetOfDecl2 = [=](SubscriptDecl* decl1, SubscriptDecl* decl2) + { + for (auto accessorDecl1 : decl1->getMembersOfType()) + { + bool found = false; + for (auto accessorDecl2 : decl2->getMembersOfType()) + { + if (accessorDecl1->astNodeType == accessorDecl2->astNodeType) + { + found = true; + break; + } + } + if (!found) + return false; + } + return true; + }; + if (leftAccessorCount > rightAccessorCount + && decl1IsSubsetOfDecl2(rightSubscriptDecl.getDecl(), leftSubscriptDecl.getDecl())) + { + return -1; + } + else if (rightAccessorCount > leftAccessorCount + && decl1IsSubsetOfDecl2(leftSubscriptDecl.getDecl(), rightSubscriptDecl.getDecl())) + { + return 1; + } + } + } + + // We need to consider the distance of the declarations to the global scope to resolve this case: + // float f(float x); + // struct S + // { + // float f(float x); + // float g(float y) { return f(y); } // will call S::f() instead of ::f() + // } + // We don't need to know the call site of 'f(y)', but only need to count the two candidates' distance to the global scope, + // because this function will only choose the valid candidates. So if there is situation like this: + // void main() { S s; s.f(1.0);} or + // struct T { float g(y) { f(y); } }, there won't be ambiguity. + // So we just need to count which declaration is farther from the global scope and favor the farther one. + int leftDistance = 0; + int rightDistance = 0; + countDistanceToGloablScope(left.declRef, right.declRef, leftDistance, rightDistance); + if (leftDistance != rightDistance) + return leftDistance > rightDistance ? -1 : 1; + // TODO: We should generalize above rules such that in a tie a declaration // A::m is better than B::m when all other factors are equal and // A inherits from B. diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp index 4bb420fa7a..541085b4ee 100644 --- a/source/slang/slang-compiler.cpp +++ b/source/slang/slang-compiler.cpp @@ -1715,6 +1715,7 @@ namespace Slang case CodeGenTarget::PyTorchCppBinding: case CodeGenTarget::CSource: case CodeGenTarget::Metal: + case CodeGenTarget::WGSL: { RefPtr extensionTracker = _newExtensionTracker(target); diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index b8ee4dc9cd..0c788ae182 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -40,7 +40,9 @@ namespace Slang { struct PathInfo; - struct IncludeHandler; + struct IncludeHandler; + struct SharedSemanticsContext; + class ProgramLayout; class PtrType; class TargetProgram; @@ -94,6 +96,7 @@ namespace Slang Metal = SLANG_METAL, MetalLib = SLANG_METAL_LIB, MetalLibAssembly = SLANG_METAL_LIB_ASM, + WGSL = SLANG_WGSL, CountOf = SLANG_TARGET_COUNT_OF, }; @@ -2169,6 +2172,11 @@ namespace Slang DeclRef declRef, List argExprs, DiagnosticSink* sink); + + DeclRef specializeWithArgTypes( + DeclRef funcDeclRef, + List argTypes, + DiagnosticSink* sink); DiagnosticSink::Flags diagnosticSinkFlags = 0; @@ -2182,6 +2190,9 @@ namespace Slang m_retainedSession = nullptr; } + // Get shared semantics information for reflection purposes. 
+ SharedSemanticsContext* getSemanticsForReflection(); + private: /// The global Slang library session that this linkage is a child of Session* m_session = nullptr; @@ -2235,6 +2246,8 @@ namespace Slang List m_specializedTypes; + RefPtr m_semanticsForReflection; + }; /// Shared functionality between front- and back-end compile requests. diff --git a/source/slang/slang-doc-markdown-writer.cpp b/source/slang/slang-doc-markdown-writer.cpp index c9dd6d9c8a..e32a738c7f 100644 --- a/source/slang/slang-doc-markdown-writer.cpp +++ b/source/slang/slang-doc-markdown-writer.cpp @@ -483,6 +483,10 @@ static DocMarkdownWriter::Requirement _getRequirementFromTargetToken(const Token { return Requirement{ CodeGenTarget::Metal, targetName }; } + else if (isCapabilityDerivedFrom(targetCap, CapabilityAtom::wgsl)) + { + return Requirement{ CodeGenTarget::WGSL, targetName }; + } return Requirement{ CodeGenTarget::Unknown, String() }; } diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index 1893929f89..caf3613a71 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -95,6 +95,10 @@ struct CLikeSourceEmitter::ComputeEmitActionsContext { return SourceLanguage::Metal; } + case CodeGenTarget::WGSL: + { + return SourceLanguage::WGSL; + } } } @@ -151,7 +155,7 @@ void CLikeSourceEmitter::ensureTypePrelude(IRType* type) } } -void CLikeSourceEmitter::emitDeclarator(DeclaratorInfo* declarator) +void CLikeSourceEmitter::emitDeclaratorImpl(DeclaratorInfo* declarator) { if (!declarator) return; @@ -341,13 +345,18 @@ void CLikeSourceEmitter::_emitPostfixTypeAttr(IRAttr* attr) // we may need to handle it here. } +void CLikeSourceEmitter::emitSimpleTypeAndDeclaratorImpl(IRType* type, DeclaratorInfo* declarator) +{ + emitSimpleType(type); + emitDeclarator(declarator); +} + void CLikeSourceEmitter::_emitType(IRType* type, DeclaratorInfo* declarator) { switch (type->getOp()) { default: - emitSimpleType(type); - emitDeclarator(declarator); + emitSimpleTypeAndDeclarator(type, declarator); break; case kIROp_RateQualifiedType: @@ -648,7 +657,7 @@ bool CLikeSourceEmitter::maybeEmitParens(EmitOpInfo& outerPrec, const EmitOpInfo bool needParens = (prec.leftPrecedence <= outerPrec.leftPrecedence) || (prec.rightPrecedence <= outerPrec.rightPrecedence); - // While Slang correctly removes some of parentheses, DXC prints warnings + // While Slang correctly removes some of parentheses, many compilers print warnings // for common mistakes when parentheses are not used with certain combinations // of the operations. We emit parentheses to avoid the warnings. 
// @@ -676,6 +685,12 @@ bool CLikeSourceEmitter::maybeEmitParens(EmitOpInfo& outerPrec, const EmitOpInfo { needParens = true; } + // a + b & c => (a + b) & c + else if (prec.rightPrecedence == EPrecedence::kEPrecedence_Additive_Right + && outerPrec.rightPrecedence == EPrecedence::kEPrecedence_BitAnd_Left) + { + needParens = true; + } if (needParens) { @@ -1657,11 +1672,16 @@ bool CLikeSourceEmitter::shouldFoldInstIntoUseSites(IRInst* inst) return true; } +bool CLikeSourceEmitter::isPointerSyntaxRequiredImpl(IRInst* /* inst */) +{ + return doesTargetSupportPtrTypes(); +} + void CLikeSourceEmitter::emitDereferenceOperand(IRInst* inst, EmitOpInfo const& outerPrec) { EmitOpInfo newOuterPrec = outerPrec; - if (doesTargetSupportPtrTypes()) + if (isPointerSyntaxRequiredImpl(inst)) { switch (inst->getOp()) { @@ -1760,7 +1780,7 @@ void CLikeSourceEmitter::emitDereferenceOperand(IRInst* inst, EmitOpInfo const& void CLikeSourceEmitter::emitVarExpr(IRInst* inst, EmitOpInfo const& outerPrec) { - if (doesTargetSupportPtrTypes()) + if (isPointerSyntaxRequiredImpl(inst)) { auto prec = getInfo(EmitOp::Prefix); auto newOuterPrec = outerPrec; @@ -1842,7 +1862,8 @@ void CLikeSourceEmitter::emitInstResultDecl(IRInst* inst) emitRateQualifiers(inst); - if(as(inst->getParent())) + bool isConstant(as(inst->getParent())); + if(isConstant) { // "Ordinary" instructions at module scope are constants @@ -1857,6 +1878,9 @@ void CLikeSourceEmitter::emitInstResultDecl(IRInst* inst) case SourceLanguage::Metal: m_writer->emit("constant "); break; + case SourceLanguage::WGSL: + // This is handled by emitVarKeyword, below + break; default: m_writer->emit("const "); break; @@ -1864,6 +1888,8 @@ void CLikeSourceEmitter::emitInstResultDecl(IRInst* inst) } + emitVarKeyword(type, isConstant); + emitType(type, getName(inst)); m_writer->emit(" = "); } @@ -2297,7 +2323,7 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO IRFieldAddress* ii = (IRFieldAddress*) inst; - if (doesTargetSupportPtrTypes()) + if (isPointerSyntaxRequiredImpl(inst)) { auto prec = getInfo(EmitOp::Prefix); needClose = maybeEmitParens(outerPrec, prec); @@ -3117,6 +3143,8 @@ void CLikeSourceEmitter::_emitStoreImpl(IRStore* store) void CLikeSourceEmitter::_emitInstAsDefaultInitializedVar(IRInst* inst, IRType* type) { + emitVarKeyword(type, /* isConstant */ false); + emitType(type, getName(inst)); // On targets that support empty initializers, we will emit it. 
@@ -3178,6 +3206,20 @@ void CLikeSourceEmitter::emitLayoutSemantics(IRInst* inst, char const* uniformSe emitLayoutSemanticsImpl(inst, uniformSemanticSpelling, EmitLayoutSemanticOption::kPostType); } +void CLikeSourceEmitter::emitSwitchCaseSelectorsImpl(IRBasicType *const /* switchCondition */, const SwitchRegion::Case *const currentCase, const bool isDefault) +{ + for(auto caseVal : currentCase->values) + { + m_writer->emit("case "); + emitOperand(caseVal, getInfo(EmitOp::General)); + m_writer->emit(":\n"); + } + if(isDefault) + { + m_writer->emit("default:\n"); + } +} + void CLikeSourceEmitter::emitRegion(Region* inRegion) { // We will use a loop so that we can process sequential (simple) @@ -3333,17 +3375,9 @@ void CLikeSourceEmitter::emitRegion(Region* inRegion) auto defaultCase = switchRegion->defaultCase; for(auto currentCase : switchRegion->cases) { - for(auto caseVal : currentCase->values) - { - m_writer->emit("case "); - emitOperand(caseVal, getInfo(EmitOp::General)); - m_writer->emit(":\n"); - } - if(currentCase.Ptr() == defaultCase) - { - m_writer->emit("default:\n"); - } - + const bool isDefault {currentCase.Ptr() == defaultCase}; + IRBasicType *const switchConditionType {as(switchRegion->getCondition()->getDataType())}; + emitSwitchCaseSelectors(switchConditionType, currentCase.Ptr(), isDefault); m_writer->indent(); m_writer->emit("{\n"); m_writer->indent(); @@ -3449,9 +3483,16 @@ void CLikeSourceEmitter::emitSimpleFuncParamsImpl(IRFunc* func) m_writer->emit(")"); } -void CLikeSourceEmitter::emitSimpleFuncImpl(IRFunc* func) +void CLikeSourceEmitter::emitFuncHeaderImpl(IRFunc* func) { auto resultType = func->getResultType(); + auto name = getName(func); + emitType(resultType, name); + emitSimpleFuncParamsImpl(func); +} + +void CLikeSourceEmitter::emitSimpleFuncImpl(IRFunc* func) +{ // Deal with decorations that need // to be emitted as attributes @@ -3467,12 +3508,8 @@ void CLikeSourceEmitter::emitSimpleFuncImpl(IRFunc* func) emitFunctionPreambleImpl(func); - auto name = getName(func); - emitFuncDecorations(func); - - emitType(resultType, name); - emitSimpleFuncParamsImpl(func); + emitFuncHeader(func); emitSemantics(func); // TODO: encode declaration vs. 
definition @@ -3688,6 +3725,11 @@ void CLikeSourceEmitter::emitStruct(IRStructType* structType) m_writer->emit(";\n\n"); } +void CLikeSourceEmitter::emitStructDeclarationSeparatorImpl() +{ + m_writer->emit(";"); +} + void CLikeSourceEmitter::emitStructDeclarationsBlock(IRStructType* structType, bool allowOffsetLayout) { m_writer->emit("\n{\n"); @@ -3716,11 +3758,13 @@ void CLikeSourceEmitter::emitStructDeclarationsBlock(IRStructType* structType, b emitPackOffsetModifier(fieldKey, fieldType, packOffsetDecoration); } } + emitStructFieldAttributes(structType, ff); emitMemoryQualifiers(fieldKey); emitType(fieldType, getName(fieldKey)); emitSemantics(fieldKey, allowOffsetLayout); emitPostDeclarationAttributesForType(fieldType); - m_writer->emit(";\n"); + emitStructDeclarationSeparator(); + m_writer->emit("\n"); } m_writer->dedent(); @@ -3931,6 +3975,8 @@ void CLikeSourceEmitter::emitParameterGroup(IRGlobalParam* varDecl, IRUniformPar emitParameterGroupImpl(varDecl, type); } +void CLikeSourceEmitter::emitVarKeywordImpl(IRType * /* type */, bool /* isConstant */) {} + void CLikeSourceEmitter::emitVar(IRVar* varDecl) { auto allocatedType = varDecl->getDataType(); @@ -3969,6 +4015,8 @@ void CLikeSourceEmitter::emitVar(IRVar* varDecl) #endif emitRateQualifiersAndAddressSpace(varDecl); + emitVarKeyword(varType, /* isConstant */ false); + emitType(varType, getName(varDecl)); emitSemantics(varDecl); @@ -4099,6 +4147,7 @@ void CLikeSourceEmitter::emitGlobalVar(IRGlobalVar* varDecl) emitVarModifiers(layout, varDecl, varType); emitRateQualifiersAndAddressSpace(varDecl); + emitVarKeyword(varType, /* isConstant */ true); emitType(varType, getName(varDecl)); // TODO: These shouldn't be needed for ordinary @@ -4172,7 +4221,8 @@ void CLikeSourceEmitter::emitGlobalParam(IRGlobalParam* varDecl) emitDecorationLayoutSemantics(varDecl, "register"); emitRateQualifiersAndAddressSpace(varDecl); - emitType(varType, getName(varDecl)); + emitVarKeyword(varType, /* isConstant */ false); + emitGlobalParamType(varType, getName(varDecl)); emitSemantics(varDecl); diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h index 00ad156d1d..be769f31f9 100644 --- a/source/slang/slang-emit-c-like.h +++ b/source/slang/slang-emit-c-like.h @@ -248,7 +248,8 @@ class CLikeSourceEmitter: public SourceEmitterBase // void ensureTypePrelude(IRType* type); - void emitDeclarator(DeclaratorInfo* declarator); + void emitDeclarator(DeclaratorInfo* declarator) {emitDeclaratorImpl(declarator);} + virtual void emitDeclaratorImpl(DeclaratorInfo* declarator); void emitType(IRType* type, const StringSliceLoc* nameLoc) { emitTypeImpl(type, nameLoc); } void emitType(IRType* type, Name* name); @@ -256,6 +257,7 @@ class CLikeSourceEmitter: public SourceEmitterBase void emitType(IRType* type); void emitType(IRType* type, Name* name, SourceLoc const& nameLoc); void emitType(IRType* type, NameLoc const& nameAndLoc); + virtual void emitGlobalParamType(IRType* type, String const& name) {emitType(type, name);} bool hasExplicitConstantBufferOffset(IRInst* cbufferType); bool isSingleElementConstantBuffer(IRInst* cbufferType); bool shouldForceUnpackConstantBufferElements(IRInst* cbufferType); @@ -368,8 +370,11 @@ class CLikeSourceEmitter: public SourceEmitterBase /// Emit high-level statements for the body of a function. 
void emitFunctionBody(IRGlobalValueWithCode* code); + void emitFuncHeader(IRFunc* func) { emitFuncHeaderImpl(func); } void emitSimpleFunc(IRFunc* func) { emitSimpleFuncImpl(func); } + void emitSwitchCaseSelectors(IRBasicType *const switchConditionType, const SwitchRegion::Case *const currentCase, const bool isDefault) {emitSwitchCaseSelectorsImpl(switchConditionType, currentCase, isDefault);} + void emitParamType(IRType* type, String const& name) { emitParamTypeImpl(type, name); } void emitFuncDecl(IRFunc* func); @@ -394,10 +399,14 @@ class CLikeSourceEmitter: public SourceEmitterBase void emitStructDeclarationsBlock(IRStructType* structType, bool allowOffsetLayout); void emitClass(IRClassType* structType); + void emitStructDeclarationSeparator() {emitStructDeclarationSeparatorImpl();} + virtual void emitStructDeclarationSeparatorImpl(); + /// Emit type attributes that should appear after, e.g., a `struct` keyword void emitPostKeywordTypeAttributes(IRInst* inst) { emitPostKeywordTypeAttributesImpl(inst); } virtual void emitMemoryQualifiers(IRInst* /*varInst*/) {}; + virtual void emitStructFieldAttributes(IRStructType * /* structType */, IRStructField * /* field */) {}; void emitInterpolationModifiers(IRInst* varInst, IRType* valueType, IRVarLayout* layout); void emitMeshShaderModifiers(IRInst* varInst); virtual void emitPackOffsetModifier(IRInst* /*varInst*/, IRType* /*valueType*/, IRPackOffsetDecoration* /*decoration*/) {}; @@ -421,6 +430,7 @@ class CLikeSourceEmitter: public SourceEmitterBase void emitGlobalInst(IRInst* inst); virtual void emitGlobalInstImpl(IRInst* inst); + virtual bool isPointerSyntaxRequiredImpl(IRInst* inst); void ensureInstOperand(ComputeEmitActionsContext* ctx, IRInst* inst, EmitAction::Level requiredLevel = EmitAction::Level::Definition); @@ -486,6 +496,11 @@ class CLikeSourceEmitter: public SourceEmitterBase virtual void emitPreModuleImpl(); virtual void emitPostModuleImpl(); + virtual void emitSimpleTypeAndDeclaratorImpl(IRType* type, DeclaratorInfo* declarator); + void emitSimpleTypeAndDeclarator(IRType* type, DeclaratorInfo* declarator) {emitSimpleTypeAndDeclaratorImpl(type, declarator);}; + virtual void emitVarKeywordImpl(IRType * type, bool isConstant); + void emitVarKeyword(IRType * type, bool isConstant) {emitVarKeywordImpl(type, isConstant);} + virtual void beforeComputeEmitActions(IRModule* module) { SLANG_UNUSED(module); }; virtual void emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, AddressSpace addressSpace) { SLANG_UNUSED(rate); SLANG_UNUSED(addressSpace); } @@ -501,6 +516,7 @@ class CLikeSourceEmitter: public SourceEmitterBase virtual void emitTypeImpl(IRType* type, const StringSliceLoc* nameLoc); virtual void emitSimpleValueImpl(IRInst* inst); virtual void emitModuleImpl(IRModule* module, DiagnosticSink* sink); + virtual void emitFuncHeaderImpl(IRFunc* func); virtual void emitSimpleFuncImpl(IRFunc* func); virtual void emitVarExpr(IRInst* inst, EmitOpInfo const& outerPrec); virtual void emitOperandImpl(IRInst* inst, EmitOpInfo const& outerPrec); @@ -511,6 +527,7 @@ class CLikeSourceEmitter: public SourceEmitterBase virtual void emitLoopControlDecorationImpl(IRLoopControlDecoration* decl) { SLANG_UNUSED(decl); } virtual void emitIfDecorationsImpl(IRIfElse* ifInst) { SLANG_UNUSED(ifInst); } virtual void emitSwitchDecorationsImpl(IRSwitch* switchInst) { SLANG_UNUSED(switchInst); } + virtual void emitSwitchCaseSelectorsImpl(IRBasicType *const switchConditionType, const SwitchRegion::Case *const currentCase, const bool isDefault); virtual void 
emitFuncDecorationImpl(IRDecoration* decoration) { SLANG_UNUSED(decoration); } virtual void emitLivenessImpl(IRInst* inst); diff --git a/source/slang/slang-emit-wgsl.cpp b/source/slang/slang-emit-wgsl.cpp new file mode 100644 index 0000000000..40d8ace91a --- /dev/null +++ b/source/slang/slang-emit-wgsl.cpp @@ -0,0 +1,1023 @@ +#include "slang-emit-wgsl.h" + +// A note on row/column "terminology reversal". +// +// This is a "terminology reversing" implementation in the sense that +// * "column" in Slang code maps to "row" in the generated WGSL code, and +// * "row" in Slang code maps to "column" in the generated WGSL code. +// +// This means that matrices in Slang code end up getting translated to +// matrices that actually represent the transpose of what the Slang matrix +// represented. +// Both APIs adopt the standard matrix multiplication convention whereby the +// column count of the matrix on the left hand side needs to match the row count of +// the matrix on the right hand side. +// For these reasons, and due to the fact that (M_1 ... M_n)^T = M_n^T ... M_1^T, +// the order of matrix products (and vector-matrix products) must also be reversed +// in the WGSL code. +// +// This may lead to confusion (which is why this note is referenced in several +// places), but the benefit of doing this is that the WGSL code is +// simpler to generate and should be faster to compile. +// A "terminology preserving" implementation would have to generate lots of +// 'transpose' calls, or else perform more complicated transformations that +// end up duplicating expressions many times. + +namespace Slang { + +void WGSLSourceEmitter::emitSwitchCaseSelectorsImpl( + IRBasicType *const switchConditionType, + const SwitchRegion::Case *const currentCase, const bool isDefault + ) +{ + // WGSL has special syntax for blocks sharing case labels: + // "case 2, 3, 4: ...;" instead of the C-like syntax + // "case 2: case 3: case 4: ...;". + + m_writer->emit("case "); + for (auto caseVal : currentCase->values) + { + // TODO: Fix this in the front-end [1], remove the if-path and just do the else-path. + // We can't do that at the moment because it would break Falcor [2]. + // [1] https://github.com/shader-slang/slang/pull/5025/commits/a32156ef52f43b8503b2c77f2f1d51220ab9bdea + // [2] https://github.com/shader-slang/slang/pull/5025#issuecomment-2334495120 + if (caseVal->getOp() == kIROp_IntLit) + { + auto caseLitInst = static_cast(caseVal); + IRBasicType *const caseInstType = as(caseLitInst->getDataType()); + // WGSL doesn't allow switch condition and case type mismatches, see [1]. + // Thus we need to insert explicit conversions. + // Doing a wrapping cast will match Slang's de facto semantics, according to + // [2]. + // (This is just a bitcast, assuming a two's complement representation.)
+ // [1] https://www.w3.org/TR/WGSL/#switch-statement + // [2] https://github.com/shader-slang/slang/issues/4921 + const bool needBitcast = + caseInstType->getBaseType() != switchConditionType->getBaseType(); + if (needBitcast) + { + m_writer->emit("bitcast<"); + emitType(switchConditionType); + m_writer->emit(">("); + } + emitOperand(caseVal, getInfo(EmitOp::General)); + if (needBitcast) + { + m_writer->emit(")"); + } + } + else + { + emitOperand(caseVal, getInfo(EmitOp::General)); + } + m_writer->emit(", "); + } + if (isDefault) + { + m_writer->emit("default, "); + } + m_writer->emit(":\n"); +} + +void WGSLSourceEmitter::emitParameterGroupImpl( + IRGlobalParam* varDecl, IRUniformParameterGroupType* type +) +{ + auto varLayout = getVarLayout(varDecl); + SLANG_RELEASE_ASSERT(varLayout); + + for (auto attr : varLayout->getOffsetAttrs()) + { + + const LayoutResourceKind kind = attr->getResourceKind(); + switch (kind) + { + case LayoutResourceKind::VaryingInput: + case LayoutResourceKind::VaryingOutput: + m_writer->emit("@location("); + m_writer->emit(attr->getOffset()); + m_writer->emit(")"); + if (attr->getSpace()) + { + // TODO: Not sure what 'space' should map to in WGSL + SLANG_ASSERT(false); + } + break; + + case LayoutResourceKind::SpecializationConstant: + // TODO: + // Consider moving to a differently named function. + // This is not technically an attribute, but a declaration. + // + // https://www.w3.org/TR/WGSL/#override-decls + m_writer->emit("override"); + break; + + case LayoutResourceKind::Uniform: + case LayoutResourceKind::ConstantBuffer: + case LayoutResourceKind::ShaderResource: + case LayoutResourceKind::UnorderedAccess: + case LayoutResourceKind::SamplerState: + case LayoutResourceKind::DescriptorTableSlot: + m_writer->emit("@binding("); + m_writer->emit(attr->getOffset()); + m_writer->emit(") "); + m_writer->emit("@group("); + m_writer->emit(attr->getSpace()); + m_writer->emit(") "); + break; + + } + + } + + auto elementType = type->getElementType(); + m_writer->emit("var "); + m_writer->emit(getName(varDecl)); + m_writer->emit(" : "); + emitType(elementType); + m_writer->emit(";\n"); +} + +void WGSLSourceEmitter::emitEntryPointAttributesImpl( + IRFunc* irFunc, IREntryPointDecoration* entryPointDecor + ) +{ + auto stage = entryPointDecor->getProfile().getStage(); + + switch (stage) + { + + case Stage::Fragment: + m_writer->emit("@fragment\n"); + break; + case Stage::Vertex: + m_writer->emit("@vertex\n"); + break; + + case Stage::Compute: + { + m_writer->emit("@compute\n"); + + { + Int sizeAlongAxis[kThreadGroupAxisCount]; + getComputeThreadGroupSize(irFunc, sizeAlongAxis); + + m_writer->emit("@workgroup_size("); + for (int ii = 0; ii < kThreadGroupAxisCount; ++ii) + { + if (ii != 0) + m_writer->emit(", "); + m_writer->emit(sizeAlongAxis[ii]); + } + m_writer->emit(")\n"); + } + } + break; + + default: + SLANG_ABORT_COMPILATION("unsupported stage."); + } + +} + +// This is 'function_header' from the WGSL specification +void WGSLSourceEmitter::emitFuncHeaderImpl(IRFunc* func) +{ + Slang::IRType * resultType = func->getResultType(); + auto name = getName(func); + + m_writer->emit("fn "); + m_writer->emit(name); + + emitSimpleFuncParamsImpl(func); + + // An absence of return type is expressed by skipping the optional '->' part of the + // header. 
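To make the case-selector type-matching rule above concrete, here is a small self-contained sketch of how a selector list could be assembled, with mismatched literals wrapped in a bitcast to the switch condition's type. The helper name and the string-returning style are purely illustrative; the real emitter streams pieces through m_writer:

    #include <iostream>
    #include <string>
    #include <vector>

    // A case literal paired with the WGSL spelling of its type ("i32", "u32", ...).
    struct CaseLiteral { std::string text; std::string typeName; };

    // Build one WGSL selector list, bitcasting any literal whose type does not
    // match the switch condition's type (WGSL requires the two to agree).
    std::string buildSelectors(const std::string& condType,
                               const std::vector<CaseLiteral>& lits, bool isDefault) {
        std::string out = "case ";
        for (const auto& lit : lits) {
            out += (lit.typeName == condType)
                       ? lit.text
                       : "bitcast<" + condType + ">(" + lit.text + ")";
            out += ", ";
        }
        if (isDefault) out += "default, ";
        return out + ":";
    }

    int main() {
        // e.g. a switch on an i32 condition with one u32 case literal:
        std::cout << buildSelectors("i32", {{"1", "i32"}, {"2u", "u32"}}, true) << "\n";
        // prints: case 1, bitcast<i32>(2u), default, :
    }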
+ if (resultType->getOp() != kIROp_VoidType) + { + m_writer->emit(" -> "); + emitType(resultType); + } +} + +void WGSLSourceEmitter::emitSimpleFuncParamImpl(IRParam* param) +{ + if (auto sysSemanticDecor = param->findDecoration()) + { + m_writer->emit("@builtin("); + m_writer->emit(sysSemanticDecor->getSemantic()); + m_writer->emit(")"); + } + + CLikeSourceEmitter::emitSimpleFuncParamImpl(param); +} + +void WGSLSourceEmitter::emitMatrixType( + IRType *const elementType, const IRIntegerValue& rowCountWGSL, + const IRIntegerValue& colCountWGSL + ) +{ + // WGSL uses CxR convention + m_writer->emit("mat"); + m_writer->emit(colCountWGSL); + m_writer->emit("x"); + m_writer->emit(rowCountWGSL); + m_writer->emit("<"); + emitType(elementType); + m_writer->emit(">"); +} + +void WGSLSourceEmitter::emitStructDeclarationSeparatorImpl() +{ + m_writer->emit(","); +} + +static bool isPowerOf2(const uint32_t n) +{ + return (n != 0U) && ((n - 1U) & n) == 0U; +} + +void WGSLSourceEmitter::emitStructFieldAttributes( + IRStructType * structType, IRStructField * field + ) +{ + // Tint emits errors unless we explicitly spell out the layout in some cases, so emit + // offset and align attribtues for all fields. + IRSizeAndAlignmentDecoration *const sizeAndAlignmentDecoration = + structType->findDecoration(); + // NullDifferential struct doesn't have size and alignment decoration + if (sizeAndAlignmentDecoration == nullptr) + return; + SLANG_ASSERT(sizeAndAlignmentDecoration->getAlignment() > IRIntegerValue{0}); + SLANG_ASSERT( + sizeAndAlignmentDecoration->getAlignment() <= IRIntegerValue{UINT32_MAX} + ); + const uint32_t structAlignment = + static_cast(sizeAndAlignmentDecoration->getAlignment()); + IROffsetDecoration *const fieldOffsetDecoration = + field->findDecoration(); + SLANG_ASSERT(fieldOffsetDecoration->getOffset() >= IRIntegerValue{0}); + SLANG_ASSERT(fieldOffsetDecoration->getOffset() <= IRIntegerValue{UINT32_MAX}); + SLANG_ASSERT(isPowerOf2(structAlignment)); + const uint32_t fieldOffset = + static_cast(fieldOffsetDecoration->getOffset()); + // Alignment is GCD(fieldOffset, structAlignment) + // TODO: Use builtin/intrinsic (e.g. 
__builtin_ffs) + uint32_t fieldAlignment = 1U; + while (((fieldAlignment & (structAlignment | fieldOffset)) == 0U)) + fieldAlignment = fieldAlignment << 1U; + + m_writer->emit("@align("); + m_writer->emit(fieldAlignment); + m_writer->emit(")"); +} + +bool WGSLSourceEmitter::isPointerSyntaxRequiredImpl(IRInst* inst) +{ + // Structured buffers are mapped to 'array' types, which don't need dereferencing + if (inst->getOp() == kIROp_RWStructuredBufferGetElementPtr) + return false; + + // Don't emit "->" to access fields in resource structs + if (inst->getOp() == kIROp_FieldAddress) + return false; + + // Don't emit "*" to access fields in resource structs + if (inst->getOp() == kIROp_GlobalParam) + return false; + + // Emit 'globalVar' instead of "*&globalVar" + if (inst->getOp() == kIROp_GlobalVar) + return false; + + return true; +} + +void WGSLSourceEmitter::emit(const AddressSpace addressSpace) +{ + switch (addressSpace) + { + case AddressSpace::Uniform: + m_writer->emit("uniform"); + break; + + case AddressSpace::StorageBuffer: + m_writer->emit("storage"); + break; + + case AddressSpace::Generic: + m_writer->emit("function"); + break; + + case AddressSpace::ThreadLocal: + m_writer->emit("private"); + break; + + case AddressSpace::GroupShared: + m_writer->emit("workgroup"); + break; + } +} + +void WGSLSourceEmitter::emitSimpleTypeImpl(IRType* type) +{ + switch (type->getOp()) + { + + case kIROp_HLSLRWStructuredBufferType: + { + auto structuredBufferType = as(type); + m_writer->emit("ptr<"); + emit(AddressSpace::StorageBuffer); + m_writer->emit(", "); + m_writer->emit("array"); + m_writer->emit("<"); + emitType(structuredBufferType->getElementType()); + m_writer->emit(">"); + m_writer->emit(", read_write"); + m_writer->emit(">"); + } + break; + + case kIROp_HLSLStructuredBufferType: + { + auto structuredBufferType = as(type); + m_writer->emit("ptr<"); + emit(AddressSpace::StorageBuffer); + m_writer->emit(", "); + m_writer->emit("array"); + m_writer->emit("<"); + emitType(structuredBufferType->getElementType()); + m_writer->emit(">"); + m_writer->emit(", read"); + m_writer->emit(">"); + } + break; + + case kIROp_VoidType: + { + // There is no void type in WGSL. + // A return type of "void" is expressed by skipping the end part of the + // 'function_header' term: + // " + // function_header : + // 'fn' ident '(' param_list ? ')' + // ( '->' attribute * template_elaborated_ident ) ? + // " + // In other words, in WGSL we should never even get to the point where we're + // asking to emit 'void'. 
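The @align value computed above is the largest power of two that divides both the struct's alignment and the field's offset, which is their GCD when both are powers of two, i.e. the lowest set bit of structAlignment | fieldOffset. A runnable restatement of that loop with a worked example (standalone helper, for illustration only):

    #include <cassert>
    #include <cstdint>
    #include <iostream>

    // Same loop as emitStructFieldAttributes: start at 1 and double until a
    // set bit of (structAlignment | fieldOffset) is reached.
    uint32_t fieldAlign(uint32_t structAlignment, uint32_t fieldOffset) {
        uint32_t align = 1u;
        while ((align & (structAlignment | fieldOffset)) == 0u)
            align <<= 1u;
        return align;
    }

    int main() {
        // structAlignment = 16, fieldOffset = 12: 16 | 12 = 0b11100, whose
        // lowest set bit is 4, which is also gcd(16, 12).
        assert(fieldAlign(16, 12) == 4u);
        // A field at offset 0 simply inherits the struct's own alignment.
        assert(fieldAlign(16, 0) == 16u);
        std::cout << fieldAlign(16, 12) << "\n"; // 4
    }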
+ SLANG_UNEXPECTED("'void' type emitted"); + return; + } + + case kIROp_FloatType: + m_writer->emit("f32"); + break; + case kIROp_DoubleType: + // There is no "f64" type in WGSL + SLANG_UNEXPECTED("'double' type emitted"); + break; + case kIROp_Int8Type: + case kIROp_UInt8Type: + // There is no "[i|u]8" type in WGSL + SLANG_UNEXPECTED("8 bit integer type emitted"); + break; + case kIROp_HalfType: + m_f16ExtensionEnabled = true; + m_writer->emit("f16"); + break; + case kIROp_BoolType: + m_writer->emit("bool"); + break; + case kIROp_IntType: + m_writer->emit("i32"); + break; + case kIROp_UIntType: + m_writer->emit("u32"); + break; + case kIROp_UInt64Type: + { + m_writer->emit(getDefaultBuiltinTypeName(type->getOp())); + return; + } + case kIROp_Int16Type: + case kIROp_UInt16Type: + SLANG_UNEXPECTED("16 bit integer value emitted"); + return; + case kIROp_Int64Type: + case kIROp_IntPtrType: + m_writer->emit("i64"); + return; + case kIROp_UIntPtrType: + m_writer->emit("u64"); + return; + case kIROp_StructType: + m_writer->emit(getName(type)); + return; + + case kIROp_VectorType: + { + auto vecType = (IRVectorType*)type; + emitVectorTypeNameImpl( + vecType->getElementType(), getIntVal(vecType->getElementCount()) + ); + return; + } + case kIROp_MatrixType: + { + auto matType = (IRMatrixType*)type; + // We map matrices in Slang to WGSL matrices that represent the transpose. + // (See note on "terminology reversal".) + const IRIntegerValue colCountWGSL = getIntVal(matType->getRowCount()); + const IRIntegerValue rowCountWGSL = getIntVal(matType->getColumnCount()); + emitMatrixType(matType->getElementType(), rowCountWGSL, colCountWGSL); + return; + } + case kIROp_SamplerStateType: + { + m_writer->emit("sampler"); + return; + } + + case kIROp_SamplerComparisonStateType: + { + m_writer->emit("sampler_comparison"); + return; + } + + case kIROp_PtrType: + case kIROp_InOutType: + case kIROp_OutType: + case kIROp_RefType: + case kIROp_ConstRefType: + { + auto ptrType = cast(type); + m_writer->emit("ptr<"); + emit((AddressSpace)ptrType->getAddressSpace()); + m_writer->emit(", "); + emitType((IRType*)ptrType->getValueType()); + m_writer->emit(">"); + return; + } + + case kIROp_ArrayType: + { + m_writer->emit("array<"); + emitType((IRType*)type->getOperand(0)); + m_writer->emit(", "); + emitVal(type->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(">"); + return; + } + default: + break; + + } + +} + +void WGSLSourceEmitter::emitLayoutQualifiersImpl(IRVarLayout* layout) +{ + + for (auto attr : layout->getOffsetAttrs()) + { + LayoutResourceKind kind = attr->getResourceKind(); + + // TODO: + // This is not correct. For the moment this is just here as a hack to make + // @binding and @group unique, so that we can pass WGSL compile tests. + // This will have to be revisited when we actually want to supply resources to + // shaders. 
+ if (kind == LayoutResourceKind::DescriptorTableSlot) + { + m_writer->emit("@binding("); + m_writer->emit(attr->getOffset()); + m_writer->emit(") "); + m_writer->emit("@group("); + m_writer->emit(attr->getSpace()); + m_writer->emit(") "); + + return; + } + } + +} + +void WGSLSourceEmitter::emitVarKeywordImpl(IRType * type, const bool isConstant) +{ + if (isConstant) + m_writer->emit("const"); + else + m_writer->emit("var"); + if (type->getOp() == kIROp_HLSLRWStructuredBufferType) + { + m_writer->emit("<"); + m_writer->emit("storage, read_write"); + m_writer->emit(">"); + } + else if (type->getOp() == kIROp_HLSLStructuredBufferType) + { + m_writer->emit("<"); + m_writer->emit("storage, read"); + m_writer->emit(">"); + } +} + +void WGSLSourceEmitter::_emitType(IRType* type, DeclaratorInfo* declarator) +{ + // C-like languages bake array-ness, pointer-ness and reference-ness into the + // declarator, which happens in the default _emitType implementation. + // WGSL on the other hand, don't have special syntax -- these are just types. + switch (type->getOp()) + { + case kIROp_ArrayType: + case kIROp_AttributedType: + case kIROp_UnsizedArrayType: + emitSimpleTypeAndDeclarator(type, declarator); + break; + default: + CLikeSourceEmitter::_emitType(type, declarator); + break; + } +} + +void WGSLSourceEmitter::emitDeclaratorImpl(DeclaratorInfo* declarator) +{ + if (!declarator) return; + + m_writer->emit(" "); + + switch (declarator->flavor) + { + case DeclaratorInfo::Flavor::Name: + { + auto nameDeclarator = (NameDeclaratorInfo*)declarator; + m_writer->emitName(*nameDeclarator->nameAndLoc); + } + break; + + case DeclaratorInfo::Flavor::SizedArray: + { + // Sized arrays are just types (array) in WGSL -- they are not + // supported at the syntax level + // https://www.w3.org/TR/WGSL/#array + SLANG_UNEXPECTED("Sized array declarator"); + } + break; + + case DeclaratorInfo::Flavor::UnsizedArray: + { + // Unsized arrays are just types (array) in WGSL -- they are not + // supported at the syntax level + // https://www.w3.org/TR/WGSL/#array + SLANG_UNEXPECTED("Unsized array declarator"); + } + break; + + case DeclaratorInfo::Flavor::Ptr: + { + // Pointers (ptr) are just types in WGSL -- they are not supported at + // the syntax level + // https://www.w3.org/TR/WGSL/#ref-ptr-types + SLANG_UNEXPECTED("Pointer declarator"); + } + break; + + case DeclaratorInfo::Flavor::Ref: + { + // References (ref) are just types in WGSL -- they are not supported + // at the syntax level + // https://www.w3.org/TR/WGSL/#ref-ptr-types + SLANG_UNEXPECTED("Reference declarator"); + } + break; + + case DeclaratorInfo::Flavor::LiteralSizedArray: + { + // Sized arrays are just types (array) in WGSL -- they are not supported + // at the syntax level + // https://www.w3.org/TR/WGSL/#array + SLANG_UNEXPECTED("Literal-sized array declarator"); + } + break; + + case DeclaratorInfo::Flavor::Attributed: + { + auto attributedDeclarator = (AttributedDeclaratorInfo*)declarator; + auto instWithAttributes = attributedDeclarator->instWithAttributes; + for (auto attr : instWithAttributes->getAllAttrs()) + { + _emitPostfixTypeAttr(attr); + } + emitDeclarator(attributedDeclarator->next); + } + break; + + default: + SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "unknown declarator flavor"); + break; + } +} + +void WGSLSourceEmitter::emitSimpleTypeAndDeclaratorImpl( + IRType* type, DeclaratorInfo* declarator + ) +{ + if (declarator) + { + emitDeclarator(declarator); + m_writer->emit(" : "); + } + emitSimpleType(type); +} + +void 
WGSLSourceEmitter::emitSimpleValueImpl(IRInst* inst) +{ + switch (inst->getOp()) + { + case kIROp_IntLit: + { + auto litInst = static_cast(inst); + + IRBasicType* type = as(inst->getDataType()); + if (type) + { + switch (type->getBaseType()) + { + default: + + case BaseType::Int8: + case BaseType::UInt8: + { + SLANG_UNEXPECTED("8 bit integer value emitted"); + break; + } + case BaseType::Int16: + case BaseType::UInt16: + { + SLANG_UNEXPECTED("16 bit integer value emitted"); + break; + } + case BaseType::Int: + { + m_writer->emit("i32("); + m_writer->emit(int32_t(litInst->value.intVal)); + m_writer->emit(")"); + return; + } + case BaseType::UInt: + { + m_writer->emit("u32("); + m_writer->emit(UInt(uint32_t(litInst->value.intVal))); + m_writer->emit(")"); + break; + } + case BaseType::Int64: + { + m_writer->emit("i64("); + m_writer->emitInt64(int64_t(litInst->value.intVal)); + m_writer->emit(")"); + break; + } + case BaseType::UInt64: + { + m_writer->emit("u64("); + SLANG_COMPILE_TIME_ASSERT( + sizeof(litInst->value.intVal) >= sizeof(uint64_t) + ); + m_writer->emitUInt64(uint64_t(litInst->value.intVal)); + m_writer->emit(")"); + break; + } + case BaseType::IntPtr: + { +#if SLANG_PTR_IS_64 + m_writer->emit("i64("); + m_writer->emitInt64(int64_t(litInst->value.intVal)); + m_writer->emit(")"); +#else + m_writer->emit("i32("); + m_writer->emit(int(litInst->value.intVal)); + m_writer->emit(")"); +#endif + break; + } + case BaseType::UIntPtr: + { +#if SLANG_PTR_IS_64 + m_writer->emit("u64("); + m_writer->emitUInt64(uint64_t(litInst->value.intVal)); + m_writer->emit(")"); +#else + m_writer->emit("u32("); + m_writer->emit(UInt(uint32_t(litInst->value.intVal))); + m_writer->emit(")"); +#endif + break; + } + + } + } + else + { + // If no type... just output what we have + m_writer->emit(litInst->value.intVal); + } + break; + } + + case kIROp_FloatLit: + { + auto litInst = static_cast(inst); + + IRBasicType* type = as(inst->getDataType()); + if (type) + { + switch (type->getBaseType()) + { + default: + + case BaseType::Half: + { + m_writer->emit(litInst->value.floatVal); + m_writer->emit("h"); + m_f16ExtensionEnabled = true; + } + break; + + case BaseType::Float: + { + m_writer->emit(litInst->value.floatVal); + m_writer->emit("f"); + } + break; + + case BaseType::Double: + { + // There is not "f64" in WGSL + SLANG_UNEXPECTED("'double' type emitted"); + } + break; + } + } + else + { + // If no type... just output what we have + m_writer->emit(litInst->value.floatVal); + } + } + break; + + case kIROp_BoolLit: + { + bool val = ((IRConstant*)inst)->value.intVal != 0; + m_writer->emit(val ? "true" : "false"); + } + break; + + default: + SLANG_UNIMPLEMENTED_X("val case for emit"); + break; + } + + +} + +void WGSLSourceEmitter::emitParamTypeImpl(IRType* type, const String& name) +{ + emitType(type, name); +} + +bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec) +{ + EmitOpInfo outerPrec = inOuterPrec; + + switch (inst->getOp()) + { + + case kIROp_MakeVectorFromScalar: + { + // In WGSL this is done by calling the vec* overloads listed in [1] + // [1] https://www.w3.org/TR/WGSL/#value-constructor-builtin-function + emitType(inst->getDataType()); + m_writer->emit("("); + auto prec = getInfo(EmitOp::Prefix); + emitOperand(inst->getOperand(0), rightSide(outerPrec, prec)); + m_writer->emit(")"); + return true; + } + break; + + case kIROp_BitCast: + { + // In WGSL there is a built-in bitcast function! 
+ // https://www.w3.org/TR/WGSL/#bitcast-builtin + m_writer->emit("bitcast"); + m_writer->emit("<"); + emitType(inst->getDataType()); + m_writer->emit(">"); + m_writer->emit("("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(")"); + return true; + } + break; + + case kIROp_MakeArray: + case kIROp_MakeStruct: + { + // It seems there are currently no designated initializers in WGSL. + // Similarly for array initializers. + // https://github.com/gpuweb/gpuweb/issues/4210 + + // There is a constructor named like the struct/array type itself + auto type = inst->getDataType(); + emitType(type); + m_writer->emit("( "); + UInt argCount = inst->getOperandCount(); + for (UInt aa = 0; aa < argCount; ++aa) + { + if (aa != 0) m_writer->emit(", "); + emitOperand(inst->getOperand(aa), getInfo(EmitOp::General)); + } + m_writer->emit(" )"); + + return true; + } + break; + + case kIROp_MakeArrayFromElement: + { + // It seems there are currently no array initializers in WGSL. + + // There is a constructor named like the array type itself + auto type = inst->getDataType(); + emitType(type); + m_writer->emit("("); + UInt argCount = + (UInt)cast( + cast(inst->getDataType())->getElementCount() + )->getValue(); + for (UInt aa = 0; aa < argCount; ++aa) + { + if (aa != 0) m_writer->emit(", "); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + } + m_writer->emit(")"); + return true; + } + break; + + case kIROp_StructuredBufferLoad: + case kIROp_RWStructuredBufferLoad: + { + // Structured buffers are just arrays in WGSL + auto base = inst->getOperand(0); + emitOperand(base, outerPrec); + m_writer->emit("["); + emitOperand(inst->getOperand(1), EmitOpInfo()); + m_writer->emit("]"); + return true; + } + break; + + case kIROp_Rsh: + case kIROp_Lsh: + { + // Shift amounts must be an unsigned type in WGSL + // https://www.w3.org/TR/WGSL/#bit-expr + IRInst *const shiftAmount = inst->getOperand(1); + IRType *const shiftAmountType = shiftAmount->getDataType(); + if (shiftAmountType->getOp() == kIROp_IntType) + { + // Dawn complains about "mixing '<<' and '|' requires parenthesis", so let's + // add parenthesis. + m_writer->emit("("); + + const auto emitOp = getEmitOpForOp(inst->getOp()); + const auto info = getInfo(emitOp); + + const bool needClose = maybeEmitParens(outerPrec, info); + emitOperand(inst->getOperand(0), leftSide(outerPrec, info)); + m_writer->emit(" "); + m_writer->emit(info.op); + m_writer->emit(" "); + m_writer->emit("bitcast("); + emitOperand(inst->getOperand(1), rightSide(outerPrec, info)); + m_writer->emit(")"); + maybeCloseParens(needClose); + + m_writer->emit(")"); + return true; + } + } + break; + + } + + return false; +} + +void WGSLSourceEmitter::emitVectorTypeNameImpl( + IRType* elementType, IRIntegerValue elementCount + ) +{ + + if (elementCount > 1) + { + m_writer->emit("vec"); + m_writer->emit(elementCount); + m_writer->emit("<"); + emitSimpleType(elementType); + m_writer->emit(">"); + } + else + { + emitSimpleType(elementType); + } +} + +void WGSLSourceEmitter::emitOperandImpl(IRInst* inst, const EmitOpInfo& outerPrec) +{ + // In WGSL, the structured buffer types are converted to ptr, AM> + // everywhere, except for the global parameter declaration. + // Thus, when these globals are used in expressions, we need an ampersand. 
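Pulling these pieces together: a structured buffer global parameter is declared as a storage-class array, but it is typed as a pointer to that array everywhere else, so uses of the parameter take its address and buffer loads turn into plain indexing. The strings below are an approximation of the resulting WGSL, inferred from the emitter code in this file rather than taken from actual compiler output:

    #include <iostream>

    // Approximate WGSL produced for an 'RWStructuredBuffer<float> buf' global
    // parameter and a 'buf[i]' load (see emitGlobalParamType,
    // emitVarKeywordImpl and emitOperandImpl above); illustrative only.
    int main() {
        const char* decl =
            "@binding(0) @group(0) var<storage, read_write> buf : array<f32>;\n";
        const char* use =
            "(&buf)[i]\n"; // the parameter is address-taken at each use site
        std::cout << decl << use;
    }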
+ + if (inst->getOp() == kIROp_GlobalParam) + { + switch (inst->getDataType()->getOp()) + { + case kIROp_HLSLStructuredBufferType: + case kIROp_HLSLRWStructuredBufferType: + + m_writer->emit("(&"); + CLikeSourceEmitter::emitOperandImpl(inst, outerPrec); + m_writer->emit(")"); + return; + } + } + + CLikeSourceEmitter::emitOperandImpl(inst, outerPrec); +} + +void WGSLSourceEmitter::emitGlobalParamType(IRType* type, const String& name) +{ + // In WGSL, the structured buffer types are converted to ptr, AM> + // everywhere, except for the global parameter declaration. + + switch (type->getOp()) + { + + case kIROp_HLSLStructuredBufferType: + case kIROp_HLSLRWStructuredBufferType: + { + StringSliceLoc nameAndLoc(name.getUnownedSlice()); + NameDeclaratorInfo nameDeclarator(&nameAndLoc); + emitDeclarator(&nameDeclarator); + m_writer->emit(" : "); + auto structuredBufferType = as(type); + m_writer->emit("array"); + m_writer->emit("<"); + emitType(structuredBufferType->getElementType()); + m_writer->emit(">"); + } + break; + + default: + + emitType(type, name); + break; + + } + +} + +void WGSLSourceEmitter::emitFrontMatterImpl(TargetRequest* /* targetReq */) +{ + if (m_f16ExtensionEnabled) + { + m_writer->emit("enable f16;\n"); + m_writer->emit("\n"); + } +} + +void WGSLSourceEmitter::emitIntrinsicCallExprImpl( + IRCall* inst, + UnownedStringSlice intrinsicDefinition, + IRInst* intrinsicInst, + EmitOpInfo const& inOuterPrec + ) +{ + // The f16 constructor is generated for f32tof16 + if (intrinsicDefinition.startsWith("f16")) + { + m_f16ExtensionEnabled = true; + } + + CLikeSourceEmitter::emitIntrinsicCallExprImpl( + inst, intrinsicDefinition, intrinsicInst, inOuterPrec + ); +} + +} // namespace Slang diff --git a/source/slang/slang-emit-wgsl.h b/source/slang/slang-emit-wgsl.h new file mode 100644 index 0000000000..d3cf19d91f --- /dev/null +++ b/source/slang/slang-emit-wgsl.h @@ -0,0 +1,78 @@ +#pragma once + +#include "slang-emit-c-like.h" + +namespace Slang +{ + +class WGSLSourceEmitter : public CLikeSourceEmitter +{ +public: + + WGSLSourceEmitter(const Desc& desc) + : CLikeSourceEmitter(desc) + {} + + virtual void emitParameterGroupImpl( + IRGlobalParam* varDecl, IRUniformParameterGroupType* type + ) SLANG_OVERRIDE; + virtual void emitEntryPointAttributesImpl( + IRFunc* irFunc, IREntryPointDecoration* entryPointDecor + ) SLANG_OVERRIDE; + virtual void emitSimpleTypeImpl(IRType* type) SLANG_OVERRIDE; + virtual void emitVectorTypeNameImpl( + IRType* elementType, IRIntegerValue elementCount + ) SLANG_OVERRIDE; + virtual void emitFuncHeaderImpl(IRFunc* func) SLANG_OVERRIDE; + virtual void emitSimpleValueImpl(IRInst* inst) SLANG_OVERRIDE; + virtual bool tryEmitInstExprImpl( + IRInst* inst, const EmitOpInfo& inOuterPrec + ) SLANG_OVERRIDE; + virtual void emitSwitchCaseSelectorsImpl( + IRBasicType *const switchCondition, + const SwitchRegion::Case *const currentCase, + const bool isDefault + ) SLANG_OVERRIDE; + virtual void emitSimpleTypeAndDeclaratorImpl( + IRType* type, DeclaratorInfo* declarator + ) SLANG_OVERRIDE; + virtual void emitVarKeywordImpl(IRType * type, const bool isConstant) SLANG_OVERRIDE; + virtual void emitDeclaratorImpl(DeclaratorInfo* declarator) SLANG_OVERRIDE; + virtual void emitStructDeclarationSeparatorImpl() SLANG_OVERRIDE; + virtual void emitLayoutQualifiersImpl(IRVarLayout* layout) SLANG_OVERRIDE; + virtual void emitSimpleFuncParamImpl(IRParam* param) SLANG_OVERRIDE; + virtual void emitParamTypeImpl(IRType* type, const String& name) SLANG_OVERRIDE; + virtual bool 
isPointerSyntaxRequiredImpl(IRInst* inst) SLANG_OVERRIDE; + virtual void _emitType(IRType* type, DeclaratorInfo* declarator) SLANG_OVERRIDE; + virtual void emitFrontMatterImpl(TargetRequest* targetReq) SLANG_OVERRIDE; + virtual void emitStructFieldAttributes( + IRStructType * structType, IRStructField * field + ) SLANG_OVERRIDE; + virtual void emitGlobalParamType(IRType* type, const String& name) SLANG_OVERRIDE; + virtual void emitOperandImpl( + IRInst* inst, const EmitOpInfo& outerPrec + ) SLANG_OVERRIDE; + + virtual void emitIntrinsicCallExprImpl( + IRCall* inst, + UnownedStringSlice intrinsicDefinition, + IRInst* intrinsicInst, + EmitOpInfo const& inOuterPrec + ) SLANG_OVERRIDE; + + void emit(const AddressSpace addressSpace); + +private: + + // Emit the matrix type with 'rowCountWGSL' WGSL-rows and 'colCountWGSL' WGSL-columns + void emitMatrixType( + IRType *const elementType, + const IRIntegerValue& rowCountWGSL, + const IRIntegerValue& colCountWGSL + ); + + bool m_f16ExtensionEnabled {false}; + +}; + +} // namespace Slang diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index ed9e904627..cdd2ca5b66 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -31,6 +31,7 @@ #include "slang-ir-glsl-legalize.h" #include "slang-ir-hlsl-legalize.h" #include "slang-ir-metal-legalize.h" +#include "slang-ir-wgsl-legalize.h" #include "slang-ir-insts.h" #include "slang-ir-inline.h" #include "slang-ir-legalize-array-return-type.h" @@ -101,6 +102,7 @@ #include "slang-emit-glsl.h" #include "slang-emit-hlsl.h" #include "slang-emit-metal.h" +#include "slang-emit-wgsl.h" #include "slang-emit-cpp.h" #include "slang-emit-cuda.h" #include "slang-emit-torch.h" @@ -839,6 +841,10 @@ Result linkAndOptimizeIR( { simplifyIR(targetProgram, irModule, fastIRSimplificationOptions, sink); } + else if (requiredLoweringPassSet.generics) + { + eliminateDeadCode(irModule, fastIRSimplificationOptions.deadCodeElimOptions); + } if (!ArtifactDescUtil::isCpuLikeTarget(artifactDesc) && targetProgram->getOptionSet().shouldRunNonEssentialValidation()) @@ -1234,6 +1240,12 @@ Result linkAndOptimizeIR( } break; + case CodeGenTarget::WGSL: + { + legalizeIRForWGSL(irModule, sink); + } + break; + default: break; } @@ -1535,15 +1547,28 @@ SlangResult CodeGenContext::emitEntryPointsSourceFromIR(ComPtr& outAr auto targetProgram = getTargetProgram(); auto lineDirectiveMode = targetProgram->getOptionSet().getEnumOption(CompilerOptionName::LineDirectiveMode); - // To try to make the default behavior reasonable, we will - // always use C-style line directives (to give the user - // good source locations on error messages from downstream - // compilers) *unless* they requested raw GLSL as the - // output (in which case we want to maximize compatibility - // with downstream tools). - if (lineDirectiveMode == LineDirectiveMode::Default && targetRequest->getTarget() == CodeGenTarget::GLSL) + // We will generally use C-style line directives in order to give the user good + // source locations on error messages from downstream compilers, but there are + // a few exceptions. + if (lineDirectiveMode == LineDirectiveMode::Default) { - lineDirectiveMode = LineDirectiveMode::GLSL; + + switch(targetRequest->getTarget()) + { + + case CodeGenTarget::GLSL: + // We want to maximize compatibility with downstream tools. + lineDirectiveMode = LineDirectiveMode::GLSL; + break; + + case CodeGenTarget::WGSL: + // WGSL doesn't support line directives. + // See https://github.com/gpuweb/gpuweb/issues/606. 
+ lineDirectiveMode = LineDirectiveMode::None; + break; + + } + } ComPtr> sourceMap; @@ -1610,6 +1635,11 @@ SlangResult CodeGenContext::emitEntryPointsSourceFromIR(ComPtr& outAr sourceEmitter = new MetalSourceEmitter(desc); break; } + case SourceLanguage::WGSL: + { + sourceEmitter = new WGSLSourceEmitter(desc); + break; + } default: break; } break; diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index afc09f4801..b526df3a92 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -772,6 +772,11 @@ INST_RANGE(BindingQuery, GetRegisterIndex, GetRegisterSpace) needs to be special cased for lookup. */ INST(TransitoryDecoration, transitory, 0, 0) + // The result witness table that the functon's return type is a subtype of an interface. + // This is used to keep track of the original witness table in a function that used to + // return an existential value but now returns a concrete type after specialization. + INST(ResultWitnessDecoration, ResultWitness, 1, 0) + INST(VulkanRayPayloadDecoration, vulkanRayPayload, 0, 0) INST(VulkanRayPayloadInDecoration, vulkanRayPayloadIn, 0, 0) INST(VulkanHitAttributesDecoration, vulkanHitAttributes, 0, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 1574c9e3de..f240e9ad8c 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -799,6 +799,17 @@ struct IRSequentialIDDecoration : IRDecoration IRIntegerValue getSequentialID() { return getSequentialIDOperand()->getValue(); } }; +struct IRResultWitnessDecoration : IRDecoration +{ + enum + { + kOp = kIROp_ResultWitnessDecoration + }; + IR_LEAF_ISA(ResultWitnessDecoration) + + IRInst* getWitness() { return getOperand(0); } +}; + struct IRDynamicDispatchWitnessDecoration : IRDecoration { IR_LEAF_ISA(DynamicDispatchWitnessDecoration) @@ -4541,6 +4552,11 @@ struct IRBuilder void addHighLevelDeclDecoration(IRInst* value, Decl* decl); + IRDecoration* addResultWitnessDecoration(IRInst* value, IRInst* witness) + { + return addDecoration(value, kIROp_ResultWitnessDecoration, witness); + } + IRDecoration* addTargetSystemValueDecoration(IRInst* value, UnownedStringSlice sysValName, UInt index = 0) { IRInst* operands[] = { getStringValue(sysValName), getIntValue(getIntType(), index)}; diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index 5865d5320d..01b1c20dec 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -1511,6 +1511,7 @@ static bool doesTargetAllowUnresolvedFuncSymbol(TargetRequest* req) case CodeGenTarget::Metal: case CodeGenTarget::MetalLib: case CodeGenTarget::MetalLibAssembly: + case CodeGenTarget::WGSL: case CodeGenTarget::DXIL: case CodeGenTarget::DXILAssembly: case CodeGenTarget::HostCPPSource: diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp index a480ae6737..d0ad7483a4 100644 --- a/source/slang/slang-ir-lower-buffer-element-type.cpp +++ b/source/slang/slang-ir-lower-buffer-element-type.cpp @@ -888,16 +888,19 @@ namespace Slang IRTypeLayoutRules* getTypeLayoutRuleForBuffer(TargetProgram* target, IRType* bufferType) { - if (!isKhronosTarget(target->getTargetReq())) - return IRTypeLayoutRules::getNatural(); + if (target->getTargetReq()->getTarget() != CodeGenTarget::WGSL) + { + if (!isKhronosTarget(target->getTargetReq())) + return IRTypeLayoutRules::getNatural(); - // If we are just emitting GLSL, we can just use the general layout rule. 
- if (!target->shouldEmitSPIRVDirectly()) - return IRTypeLayoutRules::getNatural(); + // If we are just emitting GLSL, we can just use the general layout rule. + if (!target->shouldEmitSPIRVDirectly()) + return IRTypeLayoutRules::getNatural(); - // If the user specified a scalar buffer layout, then just use that. - if (target->getOptionSet().shouldUseScalarLayout()) - return IRTypeLayoutRules::getNatural(); + // If the user specified a scalar buffer layout, then just use that. + if (target->getOptionSet().shouldUseScalarLayout()) + return IRTypeLayoutRules::getNatural(); + } if (target->getOptionSet().shouldUseDXLayout()) { diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index a56dae0256..519c4b2602 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -1204,6 +1204,38 @@ struct SpecializationContext return false; } + // Is the function's actual return type statically known? + // If so can we specialize the function even if it has no existential parameters. + // + bool isExistentialReturnTypeSpecializable(IRFunc* callee) + { + if (!as(callee->getResultType())) + return false; + + IRInst* witness = nullptr; + + for (auto block : callee->getBlocks()) + { + if (auto returnInst = as(block->getTerminator())) + { + if (auto makeExistential = as(returnInst->getVal())) + { + if (witness == nullptr) + witness = makeExistential->getWitnessTable(); + else if (witness != makeExistential->getWitnessTable()) + return false; + if (isChildInstOf(witness, callee)) + return false; + } + else + { + return false; + } + } + } + return true; + } + // Given a `call` instruction in the IR, we need to detect the case // where the callee has some interface-type parameter(s) and at the // call site it is statically clear what concrete type(s) the arguments @@ -1243,9 +1275,10 @@ struct SpecializationContext return false; // We shouldn't bother specializing unless the callee has at least - // one parameter that has an existential/interface type. + // one parameter/return type that has an existential/interface type. // - bool shouldSpecialize = false; + bool returnTypeNeedSpecialization = isExistentialReturnTypeSpecializable(calleeFunc); + bool argumentNeedSpecialization = false; UInt argCounter = 0; for (auto param : calleeFunc->getParams()) { @@ -1253,18 +1286,18 @@ struct SpecializationContext if (!isExistentialType(param->getDataType())) continue; - shouldSpecialize = true; - // We *cannot* specialize unless the argument value corresponding // to such a parameter is one we can specialize. // if (!canSpecializeExistentialArg(arg)) return false; + argumentNeedSpecialization = true; } - // If we never found a parameter worth specializing, we should bail out. + + // If we never found a parameter or return type worth specializing, we should bail out. // - if (!shouldSpecialize) + if (!returnTypeNeedSpecialization && !argumentNeedSpecialization) return false; // At this point, we believe we *should* and *can* specialize. 
@@ -1341,7 +1374,7 @@ struct SpecializationContext } else { - SLANG_UNEXPECTED("missing case for existential argument"); + SLANG_UNEXPECTED("unhandled existential argument"); } } @@ -1409,8 +1442,20 @@ struct SpecializationContext auto builder = &builderStorage; builder->setInsertBefore(inst); - auto newCall = builder->emitCallInst( - inst->getFullType(), specializedCallee, (UInt)newArgs.getCount(), newArgs.getArrayView().getBuffer()); + auto callResultType = specializedCallee->getResultType(); + IRInst* newCall = builder->emitCallInst( + callResultType, specializedCallee, (UInt)newArgs.getCount(), newArgs.getArrayView().getBuffer()); + + if (as(inst->getDataType())) + { + // If the result of the original call is specialized to a concrete type, + // we need to wrap it back into an existential type. + // + if (auto resultWitnessDecor = specializedCallee->findDecoration()) + { + newCall = builder->emitMakeExistential(inst->getDataType(), newCall, resultWitnessDecor->getWitness()); + } + } // We will completely replace the old `call` instruction with the // new one, and will go so far as to transfer any decorations @@ -1765,6 +1810,62 @@ struct SpecializationContext simplifyFunc(targetProgram, newFunc, IRSimplificationOptions::getFast(targetProgram)); + if (as(newFunc->getResultType())) + { + // If the result type is an interface type, and all return values are of the same + // concrete type, we can simplify the function to return the concrete type. + // We also need to mark the simplified function with a result witness decoration + // so we can rewrite all the callsites into IRMakeExistential using the witness. + // This is effectively pushing the MakeExistential to the call sites, so optimizations + // can happen across the function call boundaries. + IRInst* witnessTable = nullptr; + IRInst* concreteType = nullptr; + for (auto block : newFunc->getBlocks()) + { + if (auto returnInst = as(block->getTerminator())) + { + if (auto makeExistential = as(returnInst->getVal())) + { + if (!concreteType) + { + concreteType = makeExistential->getWrappedValue()->getDataType(); + witnessTable = makeExistential->getWitnessTable(); + } + else if (concreteType != makeExistential->getWrappedValue()->getDataType()) + { + concreteType = nullptr; + break; + } + if (isChildInstOf(witnessTable, newFunc)) + { + concreteType = nullptr; + break; + } + } + else + { + concreteType = nullptr; + break; + } + } + } + if (concreteType) + { + for (auto block : newFunc->getBlocks()) + { + if (auto returnInst = as(block->getTerminator())) + { + if (auto makeExistential = as(returnInst->getVal())) + { + returnInst->setOperand(0, makeExistential->getWrappedValue()); + } + } + } + builder->addResultWitnessDecoration(newFunc, witnessTable); + fixUpFuncType(newFunc, (IRType*)concreteType); + } + } + return newFunc; } @@ -2758,12 +2859,14 @@ void finalizeSpecialization(IRModule* module) break; case kIROp_StructKey: + case kIROp_Func: for (auto decor = inst->getFirstDecoration(); decor; ) { auto nextDecor = decor->getNextDecoration(); switch (decor->getOp()) { case kIROp_DispatchFuncDecoration: + case kIROp_ResultWitnessDecoration: decor->removeAndDeallocate(); break; default: diff --git a/source/slang/slang-ir-use-uninitialized-values.cpp b/source/slang/slang-ir-use-uninitialized-values.cpp index 202d44e9d4..b5ce05895f 100644 --- a/source/slang/slang-ir-use-uninitialized-values.cpp +++ b/source/slang/slang-ir-use-uninitialized-values.cpp @@ -442,6 +442,12 @@ namespace Slang IRInst* user = use->getUser(); if (as(user)) return true;
+ + // Loading from a Ptr type should be + // treated as an aliased path to any return + IRLoad *load = as(user); + if (load && isReturnedValue(load)) + return true; } return false; } diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index c4fc60bd22..f81cde30be 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -1430,11 +1430,15 @@ HashSet getParentBreakBlockSet(IRDominatorTree* dom, IRBlock* block) currBlock = dom->getImmediateDominator(currBlock)) { if (auto loopInst = as(currBlock->getTerminator())) + { if (!dom->dominates(loopInst->getBreakBlock(), block)) parentBreakBlocksSet.add(loopInst->getBreakBlock()); + } else if (auto switchInst = as(currBlock->getTerminator())) + { if (!dom->dominates(switchInst->getBreakLabel(), block)) parentBreakBlocksSet.add(switchInst->getBreakLabel()); + } } return parentBreakBlocksSet; diff --git a/source/slang/slang-ir-wgsl-legalize.cpp b/source/slang/slang-ir-wgsl-legalize.cpp new file mode 100644 index 0000000000..e05eba78c7 --- /dev/null +++ b/source/slang/slang-ir-wgsl-legalize.cpp @@ -0,0 +1,347 @@ +#include "slang-ir-wgsl-legalize.h" + +#include "slang-ir.h" +#include "slang-ir-insts.h" +#include "slang-ir-util.h" +#include "slang-parameter-binding.h" +#include "slang-ir-legalize-varying-params.h" + +namespace Slang +{ + + struct EntryPointInfo + { + IRFunc* entryPointFunc; + IREntryPointDecoration* entryPointDecor; + }; + + struct SystemValLegalizationWorkItem + { + IRInst* var; + String attrName; + UInt attrIndex; + }; + + struct WGSLSystemValueInfo + { + String wgslSystemValueName; + SystemValueSemanticName wgslSystemValueNameEnum; + ShortList permittedTypes; + bool isUnsupported = false; + }; + + struct LegalizeWGSLEntryPointContext + { + LegalizeWGSLEntryPointContext(DiagnosticSink* sink, IRModule* module) : + m_sink(sink), m_module(module) {} + + DiagnosticSink* m_sink; + IRModule* m_module; + + std::optional makeSystemValWorkItem(IRInst* var); + void legalizeSystemValue( + EntryPointInfo entryPoint, SystemValLegalizationWorkItem& workItem + ); + List collectSystemValFromEntryPoint( + EntryPointInfo entryPoint + ); + void legalizeSystemValueParameters(EntryPointInfo entryPoint); + void legalizeEntryPointForWGSL(EntryPointInfo entryPoint); + IRInst* tryConvertValue(IRBuilder& builder, IRInst* val, IRType* toType); + WGSLSystemValueInfo getSystemValueInfo( + String inSemanticName, String* optionalSemanticIndex, IRInst* parentVar + ); + }; + + IRInst* LegalizeWGSLEntryPointContext::tryConvertValue( + IRBuilder& builder, IRInst* val, IRType* toType + ) + { + auto fromType = val->getFullType(); + if (auto fromVector = as(fromType)) + { + if (auto toVector = as(toType)) + { + if (fromVector->getElementCount() != toVector->getElementCount()) + { + fromType = + builder.getVectorType( + fromVector->getElementType(), toVector->getElementCount() + ); + val = builder.emitVectorReshape(fromType, val); + } + } + else if (as(toType)) + { + UInt index = 0; + val = builder.emitSwizzle(fromVector->getElementType(), val, 1, &index); + if (toType->getOp() == kIROp_VoidType) + return nullptr; + } + } + else if (auto fromBasicType = as(fromType)) + { + if (fromBasicType->getOp() == kIROp_VoidType) + return nullptr; + if (!as(toType)) + return nullptr; + if (toType->getOp() == kIROp_VoidType) + return nullptr; + } + else + { + return nullptr; + } + return builder.emitCast(toType, val); + } + + + WGSLSystemValueInfo LegalizeWGSLEntryPointContext::getSystemValueInfo( + String inSemanticName, String* 
optionalSemanticIndex, IRInst* parentVar + ) + { + IRBuilder builder(m_module); + WGSLSystemValueInfo result = {}; + UnownedStringSlice semanticName; + UnownedStringSlice semanticIndex; + + auto hasExplicitIndex = + splitNameAndIndex( + inSemanticName.getUnownedSlice(), semanticName, semanticIndex + ); + if (!hasExplicitIndex && optionalSemanticIndex) + semanticIndex = optionalSemanticIndex->getUnownedSlice(); + + result.wgslSystemValueNameEnum = + convertSystemValueSemanticNameToEnum(semanticName); + + switch (result.wgslSystemValueNameEnum) + { + + case SystemValueSemanticName::DispatchThreadID: + { + result.wgslSystemValueName = toSlice("global_invocation_id"); + IRType *const vec3uType { + builder.getVectorType( + builder.getBasicType(BaseType::UInt), + builder.getIntValue(builder.getIntType(), 3) + ) + }; + result.permittedTypes.add(vec3uType); + } + break; + + case SystemValueSemanticName::GroupID: + { + result.wgslSystemValueName = toSlice("workgroup_id"); + result.permittedTypes.add( + builder.getVectorType( + builder.getBasicType(BaseType::UInt), + builder.getIntValue(builder.getIntType(), 3) + ) + ); + } + break; + + case SystemValueSemanticName::GroupThreadID: + { + result.wgslSystemValueName = toSlice("local_invocation_id"); + result.permittedTypes.add( + builder.getVectorType( + builder.getBasicType(BaseType::UInt), + builder.getIntValue(builder.getIntType(), 3) + ) + ); + } + break; + + case SystemValueSemanticName::GSInstanceID: + { + // No Geometry shaders in WGSL + result.isUnsupported = true; + } + break; + + default: + { + m_sink->diagnose( + parentVar, + Diagnostics::unimplementedSystemValueSemantic, semanticName + ); + return result; + } + + } + + return result; + } + + std::optional + LegalizeWGSLEntryPointContext::makeSystemValWorkItem(IRInst* var) + { + if (auto semanticDecoration = var->findDecoration()) + { + bool svPrefix = + semanticDecoration->getSemanticName().startsWithCaseInsensitive( + toSlice("sv_") + ); + if (svPrefix) + { + return + { + { + var, + String(semanticDecoration->getSemanticName()).toLower(), + (UInt)semanticDecoration->getSemanticIndex() + } + }; + } + } + + auto layoutDecor = var->findDecoration(); + if (!layoutDecor) + return {}; + auto sysValAttr = layoutDecor->findAttr(); + if (!sysValAttr) + return {}; + auto semanticName = String(sysValAttr->getName()); + auto sysAttrIndex = sysValAttr->getIndex(); + + return { { var, semanticName, sysAttrIndex } }; + } + + List + LegalizeWGSLEntryPointContext::collectSystemValFromEntryPoint( + EntryPointInfo entryPoint + ) + { + List systemValWorkItems; + for (auto param : entryPoint.entryPointFunc->getParams()) + { + auto maybeWorkItem = makeSystemValWorkItem(param); + if (maybeWorkItem.has_value()) + systemValWorkItems.add(std::move(maybeWorkItem.value())); + } + return systemValWorkItems; + } + + void + LegalizeWGSLEntryPointContext::legalizeSystemValue( + EntryPointInfo entryPoint, SystemValLegalizationWorkItem& workItem + ) + { + IRBuilder builder(entryPoint.entryPointFunc); + + auto var = workItem.var; + auto semanticName = workItem.attrName; + + auto indexAsString = String(workItem.attrIndex); + auto info = getSystemValueInfo(semanticName, &indexAsString, var); + + if (!info.permittedTypes.getCount()) + return; + + builder.addTargetSystemValueDecoration( + var, info.wgslSystemValueName.getUnownedSlice() + ); + + bool varTypeIsPermitted = false; + auto varType = var->getFullType(); + for (auto& permittedType : info.permittedTypes) + { + varTypeIsPermitted = varTypeIsPermitted || permittedType == 
varType; + } + + if (!varTypeIsPermitted) + { + // Note: we do not currently prefer any conversion + // example: + // * allowed types for semantic: `float4`, `uint4`, `int4` + // * user used, `float2` + // * Slang will equally prefer `float4` to `uint4` to `int4`. + // This means the type may lose data if slang selects `uint4` or `int4`. + bool foundAConversion = false; + for (auto permittedType : info.permittedTypes) + { + var->setFullType(permittedType); + builder.setInsertBefore( + entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst() + ); + + // get uses before we `tryConvertValue` since this creates a new use + List uses; + for (auto use = var->firstUse; use; use = use->nextUse) + uses.add(use); + + auto convertedValue = tryConvertValue(builder, var, varType); + if (convertedValue == nullptr) + continue; + + foundAConversion = true; + copyNameHintAndDebugDecorations(convertedValue, var); + + for (auto use : uses) + builder.replaceOperand(use, convertedValue); + } + if (!foundAConversion) + { + // If we can't convert the value, report an error. + for (auto permittedType : info.permittedTypes) + { + StringBuilder typeNameSB; + getTypeNameHint(typeNameSB, permittedType); + m_sink->diagnose( + var->sourceLoc, + Diagnostics::systemValueTypeIncompatible, + semanticName, + typeNameSB.produceString() + ); + } + } + } + } + + void LegalizeWGSLEntryPointContext::legalizeSystemValueParameters( + EntryPointInfo entryPoint + ) + { + List systemValWorkItems = + collectSystemValFromEntryPoint(entryPoint); + + for (auto index = 0; index < systemValWorkItems.getCount(); index++) + { + legalizeSystemValue(entryPoint, systemValWorkItems[index]); + } + } + + void LegalizeWGSLEntryPointContext::legalizeEntryPointForWGSL( + EntryPointInfo entryPoint + ) + { + legalizeSystemValueParameters(entryPoint); + } + + void legalizeIRForWGSL(IRModule* module, DiagnosticSink* sink) + { + List entryPoints; + for (auto inst : module->getGlobalInsts()) + { + IRFunc *const func {as(inst)}; + if (!func) + continue; + IREntryPointDecoration *const entryPointDecor = + func->findDecoration(); + if (!entryPointDecor) + continue; + EntryPointInfo info; + info.entryPointDecor = entryPointDecor; + info.entryPointFunc = func; + entryPoints.add(info); + } + + LegalizeWGSLEntryPointContext context(sink, module); + for (auto entryPoint : entryPoints) + context.legalizeEntryPointForWGSL(entryPoint); + } + +} diff --git a/source/slang/slang-ir-wgsl-legalize.h b/source/slang/slang-ir-wgsl-legalize.h new file mode 100644 index 0000000000..462f932044 --- /dev/null +++ b/source/slang/slang-ir-wgsl-legalize.h @@ -0,0 +1,10 @@ +#pragma once + +#include "slang-ir.h" + +namespace Slang +{ + class DiagnosticSink; + + void legalizeIRForWGSL(IRModule* module, DiagnosticSink* sink); +} diff --git a/source/slang/slang-profile.h b/source/slang/slang-profile.h index 04d4f5112c..178fbddd5e 100644 --- a/source/slang/slang-profile.h +++ b/source/slang/slang-profile.h @@ -19,6 +19,7 @@ namespace Slang CUDA = SLANG_SOURCE_LANGUAGE_CUDA, SPIRV = SLANG_SOURCE_LANGUAGE_SPIRV, Metal = SLANG_SOURCE_LANGUAGE_METAL, + WGSL = SLANG_SOURCE_LANGUAGE_WGSL, CountOf = SLANG_SOURCE_LANGUAGE_COUNT_OF, }; diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp index efa9a20a9e..38129babf5 100644 --- a/source/slang/slang-reflection-api.cpp +++ b/source/slang/slang-reflection-api.cpp @@ -797,9 +797,18 @@ SLANG_API SlangReflectionFunction* spReflection_FindFunctionByName(SlangReflecti 
programLayout->getTargetReq()->getLinkage()->getSourceManager(), Lexer::sourceLocationLexer); + auto astBuilder = program->getLinkage()->getASTBuilder(); try { auto result = program->findDeclFromString(name, &sink); + + if (auto genericDeclRef = result.as()) + { + auto innerDeclRef = substituteDeclRef( + SubstitutionSet(genericDeclRef), astBuilder, genericDeclRef.getDecl()->inner); + result = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, innerDeclRef); + } + if (auto funcDeclRef = result.as()) return convert(funcDeclRef); } @@ -924,7 +933,7 @@ SLANG_API bool spReflection_isSubType( } } -SlangReflectionGeneric* getInnermostGenericParent(DeclRef declRef) +DeclRef getInnermostGenericParent(DeclRef declRef) { auto decl = declRef.getDecl(); auto astBuilder = getModule(decl)->getLinkage()->getASTBuilder(); @@ -932,15 +941,14 @@ SlangReflectionGeneric* getInnermostGenericParent(DeclRef declRef) while(parentDecl) { if(parentDecl->parentDecl && as(parentDecl->parentDecl)) - return convertDeclToGeneric( - substituteDeclRef( + return substituteDeclRef( SubstitutionSet(declRef), astBuilder, - createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, DeclRef(parentDecl)))); + createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, DeclRef(parentDecl))); parentDecl = parentDecl->parentDecl; } - return nullptr; + return DeclRef(); } SLANG_API SlangReflectionGeneric* spReflectionType_GetGenericContainer(SlangReflectionType* type) @@ -948,11 +956,13 @@ SLANG_API SlangReflectionGeneric* spReflectionType_GetGenericContainer(SlangRefl auto slangType = convert(type); if (auto declRefType = as(slangType)) { - return getInnermostGenericParent(declRefType->getDeclRef()); + return convertDeclToGeneric( + getInnermostGenericParent(declRefType->getDeclRef())); } else if (auto genericDeclRefType = as(slangType)) { - return getInnermostGenericParent(genericDeclRefType->getDeclRef()); + return convertDeclToGeneric( + getInnermostGenericParent(genericDeclRefType->getDeclRef())); } return nullptr; @@ -2835,7 +2845,7 @@ SLANG_API bool spReflectionVariable_HasDefaultValue(SlangReflectionVariable* inV SLANG_API SlangReflectionGeneric* spReflectionVariable_GetGenericContainer(SlangReflectionVariable* var) { auto declRef = convert(var); - return getInnermostGenericParent(declRef); + return convertDeclToGeneric(getInnermostGenericParent(declRef)); } SLANG_API SlangReflectionVariable* spReflectionVariable_applySpecializations(SlangReflectionVariable* var, SlangReflectionGeneric* generic) @@ -3072,7 +3082,7 @@ SLANG_API SlangReflectionVariable* spReflectionFunction_GetParameter(SlangReflec SLANG_API SlangReflectionGeneric* spReflectionFunction_GetGenericContainer(SlangReflectionFunction* func) { auto declRef = convert(func); - return getInnermostGenericParent(declRef); + return convertDeclToGeneric(getInnermostGenericParent(declRef)); } SLANG_API SlangReflectionFunction* spReflectionFunction_applySpecializations(SlangReflectionFunction* func, SlangReflectionGeneric* generic) @@ -3088,6 +3098,36 @@ SLANG_API SlangReflectionFunction* spReflectionFunction_applySpecializations(Sla return convert(substDeclRef.as()); } +SLANG_API SlangReflectionFunction* spReflectionFunction_specializeWithArgTypes( + SlangReflectionFunction* func, + SlangInt argTypeCount, + SlangReflectionType* const* argTypes) +{ + auto declRef = convert(func); + if (!declRef) + return nullptr; + + + auto linkage = getModule(declRef.getDecl())->getLinkage(); + + List argTypeList; + for (SlangInt ii = 0; ii < argTypeCount; ++ii) + { + auto argType = 
convert(argTypes[ii]); + argTypeList.add(argType); + } + + try + { + DiagnosticSink sink(linkage->getSourceManager(), Lexer::sourceLocationLexer); + return convert(linkage->specializeWithArgTypes(declRef, argTypeList, &sink).as()); + } + catch (...) + { + return nullptr; + } +} + // Abstract decl reflection SLANG_API unsigned int spReflectionDecl_getChildrenCount(SlangReflectionDecl* parentDecl) @@ -3329,11 +3369,12 @@ SLANG_API SlangReflectionGeneric* spReflectionGeneric_GetOuterGenericContainer(S auto astBuilder = getModule(declRef.getDecl())->getLinkage()->getASTBuilder(); - return getInnermostGenericParent( - substituteDeclRef( - SubstitutionSet(declRef), - astBuilder, - createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, DeclRef(declRef.getDecl()->parentDecl)))); + return convertDeclToGeneric( + getInnermostGenericParent( + substituteDeclRef( + SubstitutionSet(declRef), + astBuilder, + createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, DeclRef(declRef.getDecl()->parentDecl))))); } SLANG_API SlangReflectionType* spReflectionGeneric_GetConcreteType(SlangReflectionGeneric* generic, SlangReflectionVariable* typeParam) diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp index f654135a14..2447f5787c 100644 --- a/source/slang/slang-type-layout.cpp +++ b/source/slang/slang-type-layout.cpp @@ -1831,6 +1831,7 @@ LayoutRulesFamilyImpl* getDefaultLayoutRulesFamilyForTarget(TargetRequest* targe case CodeGenTarget::GLSL: case CodeGenTarget::SPIRV: case CodeGenTarget::SPIRVAssembly: + case CodeGenTarget::WGSL: return &kGLSLLayoutRulesFamilyImpl; case CodeGenTarget::HostHostCallable: @@ -2141,6 +2142,10 @@ SourceLanguage getIntermediateSourceLanguageForTarget(TargetProgram* targetProgr { return SourceLanguage::Metal; } + case CodeGenTarget::WGSL: + { + return SourceLanguage::WGSL; + } case CodeGenTarget::CSource: { return SourceLanguage::C; diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 91ed3de5fd..6c152cdddc 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -28,7 +28,6 @@ #include "slang-type-layout.h" #include "slang-lookup.h" -# #include "slang-options.h" #include "slang-repro.h" @@ -1069,8 +1068,12 @@ Linkage::Linkage(Session* session, ASTBuilder* astBuilder, Linkage* builtinLinka for (const auto& nameToMod : builtinLinkage->mapNameToLoadedModules) mapNameToLoadedModules.add(nameToMod); } + + m_semanticsForReflection = new SharedSemanticsContext(this, nullptr, nullptr); } +SharedSemanticsContext* Linkage::getSemanticsForReflection() { return m_semanticsForReflection.get(); } + ISlangUnknown* Linkage::getInterface(const Guid& guid) { if(guid == ISlangUnknown::getTypeGuid() || guid == ISession::getTypeGuid()) @@ -1348,18 +1351,11 @@ SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL Linkage::specializeType( return asExternal(specializedType); } - -DeclRef Linkage::specializeGeneric( - DeclRef declRef, - List argExprs, - DiagnosticSink* sink) +DeclRef getGenericParentDeclRef( + ASTBuilder* astBuilder, + SemanticsVisitor* visitor, + DeclRef declRef) { - SLANG_AST_BUILDER_RAII(getASTBuilder()); - SLANG_ASSERT(declRef); - - SharedSemanticsContext sharedSemanticsContext(this, nullptr, sink); - SemanticsVisitor visitor(&sharedSemanticsContext); - // Create substituted parent decl ref. 
auto decl = declRef.getDecl(); @@ -1369,9 +1365,58 @@ DeclRef Linkage::specializeGeneric( } auto genericDecl = as(decl); - auto genericDeclRef = createDefaultSubstitutionsIfNeeded(getASTBuilder(), &visitor, DeclRef(genericDecl)).as(); - genericDeclRef = substituteDeclRef(SubstitutionSet(declRef), getASTBuilder(), genericDeclRef).as(); + auto genericDeclRef = createDefaultSubstitutionsIfNeeded(astBuilder, visitor, DeclRef(genericDecl)).as(); + return substituteDeclRef(SubstitutionSet(declRef), astBuilder, genericDeclRef).as(); +} + +DeclRef Linkage::specializeWithArgTypes( + DeclRef funcDeclRef, + List argTypes, + DiagnosticSink* sink) +{ + SemanticsVisitor visitor(getSemanticsForReflection()); + visitor = visitor.withSink(sink); + + ASTBuilder* astBuilder = getASTBuilder(); + + List argExprs; + for (SlangInt aa = 0; aa < argTypes.getCount(); ++aa) + { + auto argType = argTypes[aa]; + + // Create an 'empty' expr with the given type. Ideally, the expression itself should not matter + // only its checked type. + // + auto argExpr = astBuilder->create(); + argExpr->type = argType; + argExprs.add(argExpr); + } + // Construct invoke expr. + auto invokeExpr = astBuilder->create(); + auto declRefExpr = astBuilder->create(); + + declRefExpr->declRef = getGenericParentDeclRef(getASTBuilder(), &visitor, funcDeclRef); + invokeExpr->functionExpr = declRefExpr; + invokeExpr->arguments = argExprs; + + auto checkedInvokeExpr = visitor.CheckInvokeExprWithCheckedOperands(invokeExpr); + return as(as(checkedInvokeExpr)->functionExpr)->declRef; +} + + +DeclRef Linkage::specializeGeneric( + DeclRef declRef, + List argExprs, + DiagnosticSink* sink) +{ + SLANG_AST_BUILDER_RAII(getASTBuilder()); + SLANG_ASSERT(declRef); + + SemanticsVisitor visitor(getSemanticsForReflection()); + visitor = visitor.withSink(sink); + + auto genericDeclRef = getGenericParentDeclRef(getASTBuilder(), &visitor, declRef); DeclRefExpr* declRefExpr = getASTBuilder()->create(); declRefExpr->declRef = genericDeclRef; @@ -1561,8 +1606,9 @@ SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::createTypeConformanceComponentTy try { - SharedSemanticsContext sharedSemanticsContext(this, nullptr, &sink); - SemanticsVisitor visitor(&sharedSemanticsContext); + SemanticsVisitor visitor(getSemanticsForReflection()); + visitor = visitor.withSink(&sink); + auto witness = visitor.isSubtype((Slang::Type*)type, (Slang::Type*)interfaceType, IsSubTypeOptions::None); if (auto subtypeWitness = as(witness)) @@ -1838,6 +1884,10 @@ CapabilitySet TargetRequest::getTargetCaps() atoms.add(CapabilityName::metal); break; + case CodeGenTarget::WGSL: + atoms.add(CapabilityName::wgsl); + break; + default: break; } @@ -2314,12 +2364,8 @@ DeclRef ComponentType::findDeclFromString( Expr* expr = linkage->parseTermString(name, scope); - SharedSemanticsContext sharedSemanticsContext( - linkage, - nullptr, - sink); - SemanticsContext context(&sharedSemanticsContext); - context = context.allowStaticReferenceToNonStaticMember(); + SemanticsContext context(linkage->getSemanticsForReflection()); + context = context.allowStaticReferenceToNonStaticMember().withSink(sink); SemanticsVisitor visitor(context); @@ -2373,12 +2419,8 @@ DeclRef ComponentType::findDeclFromStringInType( Expr* expr = linkage->parseTermString(name, scope); - SharedSemanticsContext sharedSemanticsContext( - linkage, - nullptr, - sink); - SemanticsContext context(&sharedSemanticsContext); - context = context.allowStaticReferenceToNonStaticMember(); + SemanticsContext context(linkage->getSemanticsForReflection()); + 
context = context.allowStaticReferenceToNonStaticMember().withSink(sink); SemanticsVisitor visitor(context); @@ -2429,11 +2471,7 @@ DeclRef ComponentType::findDeclFromStringInType( bool ComponentType::isSubType(Type* subType, Type* superType) { - SharedSemanticsContext sharedSemanticsContext( - getLinkage(), - nullptr, - nullptr); - SemanticsContext context(&sharedSemanticsContext); + SemanticsContext context(getLinkage()->getSemanticsForReflection()); SemanticsVisitor visitor(context); return (visitor.isSubtype(subType, superType, IsSubTypeOptions::None) != nullptr); diff --git a/test-record-replay.sh b/test-record-replay.sh deleted file mode 100755 index 2faabe9549..0000000000 --- a/test-record-replay.sh +++ /dev/null @@ -1,147 +0,0 @@ -#!/usr/bin/env bash - -RED='\033[0;31m' -Green='\033[0;32m' -NC='\033[0m' -matchPattern="entrypoint: [0-9]+, target: [0-9]+, hash: [0-9a-fA-F]+" - -getHash() -{ - matchedLine=$1 - local -n outputVar=$2 - - entrypointIdx=$(echo $matchedLine | grep -oE "entrypoint: [0-9]+" | grep -oE "[0-9]+") - targetIdx=$(echo $matchedLine | grep -oE "target: [0-9]+" | grep -oE "[0-9]+") - hashCode=$(echo $matchedLine | grep -oE "hash: .*" | grep -oE ": [0-9a-fA-F]+" | grep -oE "[0-9a-fA-F]+") - - outputVar="$entrypointIdx-$targetIdx-$hashCode" -} - -log() -{ - msg=$1 - color=$2 - printf "${color}$1${NC}\n" -} - -parseStandardOutput() -{ - local -n resultArray=$1 - lines=$2 - - for line in "${lines[@]}" - do - matchLine=$(echo $line | grep -oE "$matchPattern") - - if [ -n "$matchLine" ]; then - result="" - getHash "$matchLine" result - - if [ -n "$result" ]; then - resultArray+=("$result") - fi - fi - done -} - -resultCheck() -{ - local -n inExpectedResults=$1 - local -n inReplayResults=$2 - local -n outFailedResults=$3 - - found="" - for expectedResult in ${inExpectedResults[@]}; do - - for replayResult in ${inReplayResults[@]}; do - if [ "$replayResult" == "$expectedResult" ]; then - found="1" - fi - done - - if [ -z "$found" ]; then - echo "$expectedResult is not Found in replay" - outFailedResults+=("$expectedResult") - else - echo "$expectedResult is Found in replay" - fi - done -} - -# TODO: Add more test commands here in this array -testCommands=("./build/Debug/bin/hello-world" "./build/Debug/bin/triangle") - -# Enable hash code generation for the test such that -# we can have something to compare with replaying the test -argsToEnableHashCode="--test-mode" - -declare -A testStats - -for ((i = 0; i < ${#testCommands[@]}; i++)) -do - testCommand=${testCommands[$i]} - echo "Start running test: $testCommand" - - # Run the test executable - export SLANG_RECORD_LAYER=1 - mapfile -t lines < <(${testCommand} ${argsToEnableHashCode}) - unset SLANG_RECORD_LAYER - - # parse the output from stdout - expectedResults=() - parseStandardOutput expectedResults "${lines[@]}" - - echo "Expected Results: ${expectedResults[@]}" - if [ ${#expectedResults[@]} -eq 0 ]; then - log "No expected results found" $RED - rm -rf ./slang-record/* - continue - fi - - # Replay the record file - export SLANG_RECORD_LOG_LEVEL=3 - replayTestCommand="./build/Debug/bin/slang-replay ./slang-record/*.cap" - echo "Start replaying the test ..." 
- mapfile -t lines < <(${replayTestCommand}) - unset SLANG_RECORD_LOG_LEVEL - - # parse the output from stdout - replayResults=() - parseStandardOutput replayResults "${lines[@]}" - - echo "Replay Results: ${replayResults[@]}" - if [ ${#replayResults[@]} -eq 0 ]; then - log "No replay results found" $RED - rm -rf ./slang-record/* - continue - fi - - # Check the results - failedResults=() - resultCheck expectedResults replayResults failedResults - - rm -rf ./slang-record/* - - if [ ${#failedResults[@]} -eq 0 ]; then - testStats[$testCommand]="PASSED" - else - testStats[$testCommand]="FAILED" - fi - printf "\n" -done - -for testName in "${!testStats[@]}" -do - if [ "${testStats[$testName]}" == "PASSED" ]; then - log "$testName: PASSED" $Green - else - log "$testName: FAILED" $RED - fi -done - -# Notify the CI if any of the tests failed -if [ ${#testStats[@]} -eq 0 ]; then - exit 0 -else - exit 1 -fi diff --git a/tests/bugs/gh-5026.slang b/tests/bugs/gh-5026.slang new file mode 100644 index 0000000000..1080d1180e --- /dev/null +++ b/tests/bugs/gh-5026.slang @@ -0,0 +1,23 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type + +// CHECK: type: int32_t +// CHECK-NEXT: 0 +// CHECK-NEXT: 0 +// CHECK-NEXT: 0 +// CHECK-NEXT: 0 + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +T myMod(T x, T y) +{ + return x % y; +} + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + int c = 1; + int d = 1; + outputBuffer[dispatchThreadID.x] = myMod(c,d); +} \ No newline at end of file diff --git a/tests/bugs/overload-ambiguous.slang b/tests/bugs/overload-ambiguous.slang new file mode 100644 index 0000000000..1b74cb68c2 --- /dev/null +++ b/tests/bugs/overload-ambiguous.slang @@ -0,0 +1,48 @@ +// https://github.com/shader-slang/slang/issues/4476 + +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-cuda -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-cpu -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-slang -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-slang -compute -dx12 -shaderobj +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-vk -compute -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + + +uint getData() +{ + return 1u; +} + +struct DataObtainer +{ + uint data; + uint getData() + { + return data; + } + + uint getValue() + { + return getData(); // will call DataObtainer::getData() + } + + uint getValue2() + { + return ::getData(); // will call global getData() + } +} + +RWStructuredBuffer output; + +[numthreads(1, 1, 1)] +[shader("compute")] +void computeMain(uint3 threadID: SV_DispatchThreadID) +{ + DataObtainer obtainer = {2u}; + outputBuffer[0] = obtainer.getValue(); + outputBuffer[1] = obtainer.getValue2(); + // BUF: 2 + // BUF-NEXT: 1 +} diff --git a/tests/diagnostics/uninitialized-struct-from-constructor.slang b/tests/diagnostics/uninitialized-struct-from-constructor.slang new file mode 100644 index 0000000000..e3c44dca17 --- /dev/null +++ b/tests/diagnostics/uninitialized-struct-from-constructor.slang @@ -0,0 +1,24 @@ +//TEST:SIMPLE(filecheck=CHK): -target spirv + +struct TangentSpace +{ + static const float3 localNormal = {0, 1, 0}; + + float4x4 tangentTransform; + float3 geometryNormal; + + __init(in float3 normal, in float3 inRay) + { + // Should not warn here + tangentTransform = 
getMatrix(normal, inRay); + geometryNormal = localNormal; + } + + float4x4 getMatrix(in float3 normal, in float3 inRay) + { + return float4x4(0.0f); + } +} + +//CHK-NOT: warning 41020 +//CHK-NOT: warning 41021 \ No newline at end of file diff --git a/tests/language-feature/generics/generic-return-type-requirement.slang b/tests/language-feature/generics/generic-return-type-requirement.slang new file mode 100644 index 0000000000..945afd808c --- /dev/null +++ b/tests/language-feature/generics/generic-return-type-requirement.slang @@ -0,0 +1,39 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -output-using-type +//TEST:SIMPLE(filecheck=HLSL): -target hlsl -profile cs_6_0 -entry computeMain + +// HLSL-NOT: AnyValue + +interface IStack +{ + __generic + IStack popN(); + + int get(); +} +struct StackImpl : IStack +{ + // We should be able to specialize the callsites of this function to use + // the concrete type instead of resorting to dynamic dispatch. + __generic + IStack popN() { return StackImpl(); } + + int get() { return D; } +} + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +int test>(S stack) +{ + return stack.popN<2>().get(); +} + +[numthreads(1, 1, 1)] +void computeMain(int3 dispatchThreadID : SV_DispatchThreadID) +{ + StackImpl<5> stack; + + // CHECK: 2 + outputBuffer[0] = test(stack); +} diff --git a/tests/language-feature/generics/nested-gen-value-param-inference-2.slang b/tests/language-feature/generics/nested-gen-value-param-inference-2.slang new file mode 100644 index 0000000000..038329e878 --- /dev/null +++ b/tests/language-feature/generics/nested-gen-value-param-inference-2.slang @@ -0,0 +1,41 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -output-using-type + +interface IValueGeneric {} +struct ValGenericImpl : IValueGeneric {} + +struct NestedValueGeneric> +{ + int x; +} + +void acceptor>(NestedValueGeneric x) +{ + outputBuffer[0] = D + x.x; +} + +extension> NestedValueGeneric +{ + void foo() + { + acceptor(this); + } +} +void test2(NestedValueGeneric<2, ValGenericImpl<2>> x) +{ + // 'foo' should be a member of 'NestedValueGeneric<2, ValGenericImpl<2>>' through + // the extension above. 
+ x.foo(); +} + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +[numthreads(1, 1, 1)] +void computeMain(int3 dispatchThreadID: SV_DispatchThreadID) +{ + NestedValueGeneric<2, ValGenericImpl<2>> x; + x.x = 1; + test2(x); + // CHECK: 3 +} \ No newline at end of file diff --git a/tests/language-feature/generics/nested-gen-value-param-inference.slang b/tests/language-feature/generics/nested-gen-value-param-inference.slang new file mode 100644 index 0000000000..cf7e5faf8d --- /dev/null +++ b/tests/language-feature/generics/nested-gen-value-param-inference.slang @@ -0,0 +1,32 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -output-using-type + +interface IValueGeneric {} +struct ValGenericImpl : IValueGeneric {} + +struct NestedValueGeneric> +{ + int x; +} + +void acceptor>(NestedValueGeneric x) +{ + outputBuffer[0] = D + x.x; +} +void test(NestedValueGeneric<2, ValGenericImpl<2>> x) +{ + // Test that we can correctly infer acceptor.D and acceptor.S from `x`. + acceptor(x); +} + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +[numthreads(1, 1, 1)] +void computeMain(int3 dispatchThreadID: SV_DispatchThreadID) +{ + NestedValueGeneric<2, ValGenericImpl<2>> x; + x.x = 1; + test(x); + // CHECK: 3 +} \ No newline at end of file diff --git a/tests/language-feature/interfaces/default-construct-conformance.slang b/tests/language-feature/interfaces/default-construct-conformance.slang index 00778a9f6a..4a6aed869a 100644 --- a/tests/language-feature/interfaces/default-construct-conformance.slang +++ b/tests/language-feature/interfaces/default-construct-conformance.slang @@ -36,22 +36,18 @@ struct TestAny : ITest uint getValue() { return value; } } -// CHECK: Tuple{{.*}} makeTest0{{.*}}() -// CHECK: Tuple{{.*}} = { uint2(0U, 0U), uint2(0U, 0U), packAnyValue4{{.*}} }; ITest makeTest0() { return Test0(); } -// CHECK: Tuple{{.*}} makeTest1{{.*}}() -// CHECK: Tuple{{.*}} = { uint2(0U, 0U), uint2(1U, 0U), packAnyValue4{{.*}} }; ITest makeTest1() { return Test1(); } -// CHECK: Tuple{{.*}} makeTestAny{{.*}}() -// CHECK: Tuple{{.*}} = { uint2(0U, 0U), uint2(2U, 0U), packAnyValue4{{.*}} }; +// CHECK: TestAny{{.*}} makeTestAny{{.*}}() +// CHECK: return TestAny_{{.*}}init{{.*}}() ITest makeTestAny() { return TestAny(); diff --git a/tests/language-feature/overloaded-subscript.slang b/tests/language-feature/overloaded-subscript.slang new file mode 100644 index 0000000000..f396f4f663 --- /dev/null +++ b/tests/language-feature/overloaded-subscript.slang @@ -0,0 +1,48 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -output-using-type + +// Test that we can disambiguiate subscript decls by prefering the candidate that contains a super set of +// accessors than the other candidates. 
+interface IBuf +{ + T read(int x); +} +interface IRWBuf : IBuf +{ + [mutating] + void write(int x, T v); +} + +extension> U +{ + __subscript(int x) -> T { get { return read(x); } } +} + +extension> U +{ + __subscript(int x)->T { get { return read(x); } set { write(x, newValue); } } +} + +struct MyArray : IRWBuf +{ + T data[4]; + T read(int x) { return data[x]; } + [mutating] + void write(int x, T v) { data[x] = v; } +} + + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +[numthreads(1,1,1)] +void computeMain() +{ + MyArray arr = {}; + arr[0] = 1; + arr[1] = 2; + // CHECK: 1 + // CHECK: 2 + outputBuffer[0] = arr[0]; + outputBuffer[1] = arr[1]; +} diff --git a/tests/wgsl/math.slang b/tests/wgsl/math.slang new file mode 100644 index 0000000000..d7f39ce43c --- /dev/null +++ b/tests/wgsl/math.slang @@ -0,0 +1,279 @@ +//TEST:SIMPLE(filecheck=WGSL): -stage compute -entry computeMain -target wgsl + +RWStructuredBuffer inputBuffer; +RWStructuredBuffer outputBuffer; + +__generic +bool Test_Scalar() +{ + // WGSL-LABEL: Test_Scalar + + const T zero = T(inputBuffer[0]); + const T one = T(inputBuffer[1]); + const int zeroInt = int(inputBuffer[0]); + + T outFloat1, outFloat2; + int outInt; + + return true + + // WGSL: acos( + && zero == acos(one) + + // WGSL: acosh( + && zero == acosh(one) + + // WGSL: asin( + && zero == asin(zero) + + // WGSL: asinh( + && zero == asinh(zero) + + // WGSL: atan( + && zero == atan(zero) + + // WGSL: atan2( + && zero == atan2(zero, zero) + + // WGSL: atanh( + && zero == atanh(zero) + + // WGSL: ceil( + && zero == ceil(zero) + + // WGSL: cos( + && one == cos(zero) + + // WGSL: cosh( + && one == cosh(zero) + + // WGSL: exp( + && one == exp(zero) + + // WGSL: exp2( + && one == exp2(zero) + + // WGSL: abs( + && zero == abs(zero) + + // WGSL: floor( + && zero == floor(zero) + + // WGSL: fma( + && zero == fma(zero, zero, zero) + + // WGSL: max( + && zero == max(zero, zero) + + // WGSL: min( + && zero == min(zero, zero) + + // WGSL: fract( + && zero == fract(zero) + + // WGSL: frexp( + && zero == frexp(zero, outInt) && zeroInt == outInt + + // WGSL: ldexp( + && zero == ldexp(zero, zeroInt) + + // WGSL: log( + && zero == log(one) + + // WGSL: log2( + && zero == log2(one) + + // WGSL: modf( + && zero == modf(zero, outFloat1) + + // WGSL: pow( + && zero == pow(zero, one) + + // WGSL: round( + && zero == round(zero) + + // WGSL: sin( + && zero == sin(zero) + + // WGSL: sinh( + && zero == sinh(zero) + + // WGSL: sqrt( + && zero == sqrt(zero) + + // WGSL: tan( + && zero == tan(zero) + + // WGSL: tanh( + && zero == tanh(zero) + + // WGSL: trunc( + && zero == trunc(zero) + ; +} + +__generic +bool Test_Vector() +{ + // WGSL-LABEL: Test_Vector_0 + const vector zero = T(inputBuffer[0]); + const vector one = T(inputBuffer[1]); + + const vector zeroInt = int(inputBuffer[0]); + + vector outFloat1, outFloat2; + vector outInt; + + return true + // WGSL: acos( + // WGSL-NOT: acos( + && zero == acos(one) + + // WGSL: acosh( + // WGSL-NOT: acosh( + && zero == acosh(one) + + // WGSL: asin( + // WGSL-NOT: asin( + && zero == asin(zero) + + // WGSL: asinh( + // WGSL-NOT: asinh( + && zero == asinh(zero) + + // WGSL: atan( + // WGSL-NOT: atan( + && zero == atan(zero) + + // WGSL: atan2( + // WGSL-NOT: atan2( + && zero == atan2(zero, zero) + + // WGSL: atanh( + // WGSL-NOT: atanh( + && zero == atanh(zero) + + // WGSL: ceil( + // WGSL-NOT: ceil( + && zero == ceil(zero) + + // WGSL: cos( + // WGSL-NOT: cos( + && one == cos(zero) + + // WGSL: cosh( + 
// WGSL-NOT: cosh( + && one == cosh(zero) + + // WGSL: exp( + // WGSL-NOT: exp( + && one == exp(zero) + + // WGSL: exp2( + // WGSL-NOT: exp2( + && one == exp2(zero) + + // WGSL: abs( + // WGSL-NOT: abs( + && zero == abs(zero) + + // WGSL: floor( + // WGSL-NOT: floor( + && zero == floor(zero) + + // WGSL: fma( + // WGSL-NOT: fma( + && zero == fma(zero, zero, zero) + + // WGSL: max( + // WGSL-NOT: max( + && zero == max(zero, zero) + + // WGSL: min( + // WGSL-NOT: min( + && zero == min(zero, zero) + + // WGSL: fract( + // WGSL-NOT: fract( + && zero == fract(zero) + + // WGSL: frexp( + // WGSL-NOT: frexp( + && zero == frexp(zero, outInt) && all(zeroInt == outInt) + + // WGSL: ldexp( + // WGSL-NOT: ldexp( + && zero == ldexp(zero, zeroInt) + + // WGSL: log( + // WGSL-NOT: log( + && zero == log(one) + + // WGSL: log2( + // WGSL-NOT: log2( + && zero == log2(one) + + // WGSL: modf( + // WGSL-NOT: modf( + && zero == modf(zero, outFloat1) + + // WGSL: pow( + // WGSL-NOT: pow( + && zero == pow(zero, one) + + // WGSL: round( + // WGSL-NOT: round( + && zero == round(zero) + + // WGSL: sin( + // WGSL-NOT: sin( + && zero == sin(zero) + + // WGSL: sinh( + // WGSL-NOT: sinh( + && zero == sinh(zero) + + // WGSL: sqrt( + // WGSL-NOT: sqrt( + && zero == sqrt(zero) + + // WGSL: tan( + // WGSL-NOT: tan( + && zero == tan(zero) + + // WGSL: tanh( + // WGSL-NOT: tanh( + && zero == tanh(zero) + + // WGSL: trunc( + // WGSL-NOT: trunc( + && zero == trunc(zero) + ; + + // WGSL-LABEL: Test_Vector_1 +} + +[numthreads(1,1,1)] +void computeMain() +{ + // GLSL: void main( + // GLSL_SPIRV: OpEntryPoint + // SPIR: OpEntryPoint + // HLSL: void computeMain( + // CUDA: void computeMain( + // CPP: void _computeMain( + + bool result = true + && Test_Scalar() + && Test_Vector() + && Test_Vector() + && Test_Vector() + && Test_Scalar() + && Test_Vector() + && Test_Vector() + && Test_Vector() + ; + + // BUF: 1 + outputBuffer[0] = int(result); +} diff --git a/tools/slang-test/slang-test-main.cpp b/tools/slang-test/slang-test-main.cpp index b49d7be73b..cac694f777 100644 --- a/tools/slang-test/slang-test-main.cpp +++ b/tools/slang-test/slang-test-main.cpp @@ -993,6 +993,11 @@ static PassThroughFlags _getPassThroughFlagsForTarget(SlangCompileTarget target) return PassThroughFlag::Metal; } + case SLANG_WGSL: + { + return PassThroughFlag::WGSL; + } + case SLANG_SHADER_HOST_CALLABLE: case SLANG_HOST_HOST_CALLABLE: diff --git a/tools/slang-test/test-context.h b/tools/slang-test/test-context.h index 28d39b064c..314ec2803a 100644 --- a/tools/slang-test/test-context.h +++ b/tools/slang-test/test-context.h @@ -37,7 +37,8 @@ struct PassThroughFlag Generic_C_CPP = 1 << int(SLANG_PASS_THROUGH_GENERIC_C_CPP), NVRTC = 1 << int(SLANG_PASS_THROUGH_NVRTC), LLVM = 1 << int(SLANG_PASS_THROUGH_LLVM), - Metal = 1 << int(SLANG_PASS_THROUGH_METAL) + Metal = 1 << int(SLANG_PASS_THROUGH_METAL), + WGSL = 1 << int(SLANG_PASS_THROUGH_WGSL) }; }; diff --git a/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp b/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp index a6121469ff..b580ff8fe3 100644 --- a/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp +++ b/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp @@ -64,12 +64,16 @@ SLANG_UNIT_TEST(declTreeReflection) struct MyGenericType { T z; + + __init(T _z) { z = _z; } + T g() { return z; } U h(U x, out T y) { y = z; return x; } T j(T x, out int o) { o = N; return x; } - } + U q(U x, T y) { return x; } + } namespace MyNamespace { @@ -79,6 +83,8 @@ SLANG_UNIT_TEST(declTreeReflection) } } + T 
foo(T t, U u) { return t; } + )"; auto moduleName = "moduleG" + String(Process::getId()); @@ -110,7 +116,7 @@ SLANG_UNIT_TEST(declTreeReflection) auto moduleDeclReflection = module->getModuleReflection(); SLANG_CHECK(moduleDeclReflection != nullptr); SLANG_CHECK(moduleDeclReflection->getKind() == slang::DeclReflection::Kind::Module); - SLANG_CHECK(moduleDeclReflection->getChildrenCount() == 8); + SLANG_CHECK(moduleDeclReflection->getChildrenCount() == 9); // First declaration should be a struct with 1 variable, 1 constructor (memberwise ctor), 1 funcDecl ($ZeroInit) auto firstDecl = moduleDeclReflection->getChild(0); @@ -379,6 +385,59 @@ SLANG_UNIT_TEST(declTreeReflection) SLANG_CHECK(compositeProgram->getLayout()->isSubType(uintType, diffType) == false); } + // Check specializeWithArgTypes() + { + auto unspecializedFoo = compositeProgram->getLayout()->findFunctionByName("foo"); + SLANG_CHECK(unspecializedFoo != nullptr); + + auto floatType = compositeProgram->getLayout()->findTypeByName("float"); + SLANG_CHECK(floatType != nullptr); + auto uintType = compositeProgram->getLayout()->findTypeByName("uint"); + SLANG_CHECK(uintType != nullptr); + + List argTypes; + argTypes.add(floatType); + argTypes.add(uintType); + + slang::FunctionReflection* specializedFoo = unspecializedFoo->specializeWithArgTypes(argTypes.getCount(), argTypes.getBuffer()); + SLANG_CHECK(specializedFoo != nullptr); + + SLANG_CHECK(getTypeFullName(specializedFoo->getReturnType()) == "float"); + SLANG_CHECK(specializedFoo->getParameterCount() == 2); + + SLANG_CHECK(UnownedStringSlice(specializedFoo->getParameterByIndex(0)->getName()) == "t"); + SLANG_CHECK(getTypeFullName(specializedFoo->getParameterByIndex(0)->getType()) == "float"); + + SLANG_CHECK(UnownedStringSlice(specializedFoo->getParameterByIndex(1)->getName()) == "u"); + SLANG_CHECK(getTypeFullName(specializedFoo->getParameterByIndex(1)->getType()) == "uint"); + } + + // Check specializeArgTypes on member method looked up through a specialized type + { + auto specializedType = compositeProgram->getLayout()->findTypeByName("MyGenericType"); + SLANG_CHECK(specializedType != nullptr); + + auto unspecializedMethod = compositeProgram->getLayout()->findFunctionByNameInType(specializedType, "h"); + SLANG_CHECK(unspecializedMethod != nullptr); + + // Specialize the method with float + auto floatType = compositeProgram->getLayout()->findTypeByName("float"); + SLANG_CHECK(floatType != nullptr); + + auto halfType = compositeProgram->getLayout()->findTypeByName("half"); + SLANG_CHECK(halfType != nullptr); + + List argTypes; + argTypes.add(floatType); + argTypes.add(halfType); + + auto specializedMethodWithFloat = unspecializedMethod->specializeWithArgTypes( + argTypes.getCount(), + argTypes.getBuffer()); + SLANG_CHECK(specializedMethodWithFloat != nullptr); + SLANG_CHECK(getTypeFullName(specializedMethodWithFloat->getReturnType()) == "float"); + } + // Check iterators { unsigned int count = 0; @@ -386,7 +445,7 @@ SLANG_UNIT_TEST(declTreeReflection) { count++; } - SLANG_CHECK(count == 8); + SLANG_CHECK(count == 9); count = 0; for (auto* child : moduleDeclReflection->getChildrenOfKind()) @@ -407,7 +466,7 @@ SLANG_UNIT_TEST(declTreeReflection) { count++; } - SLANG_CHECK(count == 1); + SLANG_CHECK(count == 2); count = 0; for (auto* child : moduleDeclReflection->getChildrenOfKind())