diff --git a/Project.toml b/Project.toml index 5632003c77..b569c5e2e1 100644 --- a/Project.toml +++ b/Project.toml @@ -28,6 +28,7 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" +Unrolled = "9602ed7d-8fef-5bc8-8597-8f21381861e8" [compat] Adapt = "3" @@ -50,6 +51,7 @@ RootSolvers = "0.3, 0.4" Static = "0.4, 0.5, 0.6, 0.7, 0.8" StaticArrays = "1" UnPack = "1" +Unrolled = "0.1" julia = "1.8" [extras] diff --git a/benchmarks/bickleyjet/Manifest.toml b/benchmarks/bickleyjet/Manifest.toml index ab4f51f250..394c5c061e 100644 --- a/benchmarks/bickleyjet/Manifest.toml +++ b/benchmarks/bickleyjet/Manifest.toml @@ -157,10 +157,10 @@ uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d" version = "0.5.5" [[deps.ClimaCore]] -deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack"] +deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack", "Unrolled"] path = "../.." uuid = "d414da3d-4745-48bb-8d80-42e94e092884" -version = "0.10.53" +version = "0.10.54" [[deps.CodecZlib]] deps = ["TranscodingStreams", "Zlib_jll"] @@ -1235,6 +1235,12 @@ git-tree-sha1 = "e2d817cc500e960fdbafcf988ac8436ba3208bfd" uuid = "45397f5d-5981-4c77-b2b3-fc36d6e9b728" version = "1.6.3" +[[deps.Unrolled]] +deps = ["MacroTools"] +git-tree-sha1 = "6cc9d682755680e0f0be87c56392b7651efc2c7b" +uuid = "9602ed7d-8fef-5bc8-8597-8f21381861e8" +version = "0.1.5" + [[deps.UnsafeAtomics]] git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278" uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f" diff --git a/docs/Manifest.toml b/docs/Manifest.toml index f0838997b9..2cbb4f3485 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -274,10 +274,10 @@ uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d" version = "0.5.5" [[deps.ClimaCore]] -deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack"] +deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack", "Unrolled"] path = ".." uuid = "d414da3d-4745-48bb-8d80-42e94e092884" -version = "0.10.53" +version = "0.10.54" [[deps.ClimaCoreMakie]] deps = ["ClimaCore", "Makie"] @@ -2480,6 +2480,12 @@ git-tree-sha1 = "e2d817cc500e960fdbafcf988ac8436ba3208bfd" uuid = "45397f5d-5981-4c77-b2b3-fc36d6e9b728" version = "1.6.3" +[[deps.Unrolled]] +deps = ["MacroTools"] +git-tree-sha1 = "6cc9d682755680e0f0be87c56392b7651efc2c7b" +uuid = "9602ed7d-8fef-5bc8-8597-8f21381861e8" +version = "0.1.5" + [[deps.UnsafeAtomics]] git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278" uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f" diff --git a/examples/Manifest.toml b/examples/Manifest.toml index 78ac4519bd..f57d12b48e 100644 --- a/examples/Manifest.toml +++ b/examples/Manifest.toml @@ -238,10 +238,10 @@ uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d" version = "0.5.5" [[deps.ClimaCore]] -deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack"] +deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack", "Unrolled"] path = ".." uuid = "d414da3d-4745-48bb-8d80-42e94e092884" -version = "0.10.53" +version = "0.10.54" [[deps.ClimaCorePlots]] deps = ["ClimaCore", "RecipesBase", "StaticArrays", "TriplotBase"] @@ -2019,6 +2019,12 @@ git-tree-sha1 = "e2d817cc500e960fdbafcf988ac8436ba3208bfd" uuid = "45397f5d-5981-4c77-b2b3-fc36d6e9b728" version = "1.6.3" +[[deps.Unrolled]] +deps = ["MacroTools"] +git-tree-sha1 = "6cc9d682755680e0f0be87c56392b7651efc2c7b" +uuid = "9602ed7d-8fef-5bc8-8597-8f21381861e8" +version = "0.1.5" + [[deps.UnsafeAtomics]] git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278" uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f" diff --git a/perf/Manifest.toml b/perf/Manifest.toml index e65f98ec5e..9e14fe2928 100644 --- a/perf/Manifest.toml +++ b/perf/Manifest.toml @@ -217,10 +217,10 @@ uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d" version = "0.5.5" [[deps.ClimaCore]] -deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack"] +deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack", "Unrolled"] path = ".." uuid = "d414da3d-4745-48bb-8d80-42e94e092884" -version = "0.10.53" +version = "0.10.54" [[deps.ClimaCorePlots]] deps = ["ClimaCore", "RecipesBase", "StaticArrays", "TriplotBase"] @@ -2077,6 +2077,12 @@ git-tree-sha1 = "e2d817cc500e960fdbafcf988ac8436ba3208bfd" uuid = "45397f5d-5981-4c77-b2b3-fc36d6e9b728" version = "1.6.3" +[[deps.Unrolled]] +deps = ["MacroTools"] +git-tree-sha1 = "6cc9d682755680e0f0be87c56392b7651efc2c7b" +uuid = "9602ed7d-8fef-5bc8-8597-8f21381861e8" +version = "0.1.5" + [[deps.UnsafeAtomics]] git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278" uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f" diff --git a/src/MatrixFields/MatrixFields.jl b/src/MatrixFields/MatrixFields.jl index 957c75de91..0f59f8d9cd 100644 --- a/src/MatrixFields/MatrixFields.jl +++ b/src/MatrixFields/MatrixFields.jl @@ -49,6 +49,9 @@ import StaticArrays: SMatrix, SVector import BandedMatrices: BandedMatrix, band, _BandedMatrix import ClimaComms +import Unrolled: unrolled_foreach, unrolled_map, unrolled_reduce +import Unrolled: unrolled_in, unrolled_any, unrolled_all, unrolled_filter + import ..Utilities: PlusHalf, half import ..RecursiveApply: rmap, rmaptype, rpromote_type, rzero, rconvert, radd, rsub, rmul, rdiv diff --git a/src/MatrixFields/field_name.jl b/src/MatrixFields/field_name.jl index f027b1c1e1..a9bfbee87f 100644 --- a/src/MatrixFields/field_name.jl +++ b/src/MatrixFields/field_name.jl @@ -118,41 +118,42 @@ struct FieldNameTreeNode{V <: FieldName, S <: NTuple{<:Any, FieldNameTree}} <: subtrees::S end -FieldNameTree(x) = make_subtree_at_name(x, @name()) -function make_subtree_at_name(x, name) +FieldNameTree(x) = subtree_at_name(x, @name()) +function subtree_at_name(x, name) internal_names = top_level_names(get_field(x, name)) - isempty(internal_names) && return FieldNameTreeLeaf(name) - subsubtrees = unrolled_map(internal_names) do internal_name - make_subtree_at_name(x, append_internal_name(name, internal_name)) + return if isempty(internal_names) + FieldNameTreeLeaf(name) + else + FieldNameTreeNode(name, subtrees_at_names(x, name, internal_names)) end - return FieldNameTreeNode(name, subsubtrees) end - -is_valid_name(name, tree::FieldNameTreeLeaf) = name == tree.name -is_valid_name(name, tree::FieldNameTreeNode) = +subtrees_at_names(x, name, internal_names) = + isempty(internal_names) ? () : + ( + subtree_at_name(x, append_internal_name(name, internal_names[1])), + subtrees_at_names(x, name, internal_names[2:end])..., + ) + +is_valid_name(name, tree) = name == tree.name || - is_child_name(name, tree.name) && - unrolled_any(subtree -> is_valid_name(name, subtree), tree.subtrees) + tree isa FieldNameTreeNode && is_valid_name(name, tree.subtrees...) +is_valid_name(name, subtree, subtrees...) = + is_valid_name(name, subtree) || is_valid_name(name, subtrees...) function child_names(name, tree) + is_valid_name(name, tree) || error("$name is not a valid name") subtree = get_subtree_at_name(name, tree) - subtree isa FieldNameTreeNode || - error("FieldNameTree does not contain any child names for $name") + subtree isa FieldNameTreeNode || error("$name does not have child names") return unrolled_map(subsubtree -> subsubtree.name, subtree.subtrees) end -get_subtree_at_name(name, tree::FieldNameTreeLeaf) = - name == tree.name ? tree : - error("FieldNameTree does not contain the name $name") -get_subtree_at_name(name, tree::FieldNameTreeNode) = +get_subtree_at_name(name, tree) = if name == tree.name tree - elseif is_valid_name(name, tree) - subtree_that_contains_name = unrolled_findonly(tree.subtrees) do subtree - is_child_name(name, subtree.name) - end - get_subtree_at_name(name, subtree_that_contains_name) else - error("FieldNameTree does not contain the name $name") + subtree = unrolled_findonly(tree.subtrees) do subtree + is_valid_name(name, subtree) + end + get_subtree_at_name(name, subtree) end ################################################################################ @@ -175,7 +176,10 @@ if hasfield(Method, :recursion_relation) for m in methods(wrapped_prop_names) m.recursion_relation = dont_limit end - for m in methods(make_subtree_at_name) + for m in methods(subtree_at_name) + m.recursion_relation = dont_limit + end + for m in methods(subtrees_at_names) m.recursion_relation = dont_limit end for m in methods(is_valid_name) diff --git a/src/MatrixFields/field_name_dict.jl b/src/MatrixFields/field_name_dict.jl index ed26b40822..4a31cc56ac 100644 --- a/src/MatrixFields/field_name_dict.jl +++ b/src/MatrixFields/field_name_dict.jl @@ -66,8 +66,8 @@ const FieldMatrixBroadcasted = FieldNameDict{ dict_type(::FieldNameDict{T1, T2}) where {T1, T2} = FieldNameDict{T1, T2} function Base.show(io::IO, dict::FieldNameDict) - strings = map((key, value) -> " $key => $value", pairs(dict)) - print(io, "$(dict_type(dict))($(join(strings, ",\n")))") + strings = map(pair -> "\n$(pair[1]) => $(pair[2])", pairs(dict)) + print(io, "$(dict_type(dict))($(join(strings, ","))\n)") end Base.keys(dict::FieldNameDict) = dict.keys @@ -75,9 +75,7 @@ Base.keys(dict::FieldNameDict) = dict.keys Base.values(dict::FieldNameDict) = dict.entries Base.pairs(dict::FieldNameDict) = - unrolled_map(unrolled_zip(keys(dict).values, values(dict))) do key_entry_tup - key_entry_tup[1] => key_entry_tup[2] - end + unrolled_map((key, value) -> key => value, keys(dict).values, values(dict)) Base.length(dict::FieldNameDict) = length(keys(dict)) @@ -118,18 +116,19 @@ function get_internal_entry( # See note above matrix_product_keys in field_name_set.jl for more details. T = eltype(eltype(entry)) if name_pair == (@name(), @name()) - # multiplication case 1, either argument entry - elseif broadcasted_has_field(T, name_pair[1]) && name_pair[2] == @name() + elseif name_pair[1] == name_pair[2] + # multiplication case 3 or 4, first argument + @assert T <: SingleValue && !broadcasted_has_field(T, name_pair[1]) + entry + elseif name_pair[2] == @name() # multiplication case 2 or 4, second argument + @assert broadcasted_has_field(T, name_pair[1]) Base.broadcasted(entry) do matrix_row map(matrix_row) do matrix_row_entry broadcasted_get_field(matrix_row_entry, name_pair[1]) end end # Note: This assumes that the entry is in a FieldMatrixBroadcasted. - elseif T <: SingleValue && name_pair[1] == name_pair[2] - # multiplication case 3 or 4, first argument - entry else unsupported_internal_entry_error(entry, name_pair) end diff --git a/src/MatrixFields/field_name_set.jl b/src/MatrixFields/field_name_set.jl index 1f28960bff..a7743d79b4 100644 --- a/src/MatrixFields/field_name_set.jl +++ b/src/MatrixFields/field_name_set.jl @@ -79,46 +79,36 @@ end Base.:(==)(set1::FieldNameSet, set2::FieldNameSet) = issubset(set1, set2) && issubset(set2, set1) -function Base.intersect(set1::FieldNameSet{T}, set2::FieldNameSet{T}) where {T} +function Base.union(set1::FieldNameSet{T}, set2::FieldNameSet{T}) where {T} name_tree = combine_name_trees(set1.name_tree, set2.name_tree) values1′, values2′ = set1.values, set2.values values1, values2 = non_overlapping_values(values1′, values2′, name_tree) - result_values = unrolled_filter(values2) do value - unrolled_any(isequal(value), values1) - end - return FieldNameSet{T}(result_values, name_tree) + return FieldNameSet{T}(unrolled_union(values1, values2), name_tree) end -function Base.union(set1::FieldNameSet{T}, set2::FieldNameSet{T}) where {T} +function Base.intersect(set1::FieldNameSet{T}, set2::FieldNameSet{T}) where {T} name_tree = combine_name_trees(set1.name_tree, set2.name_tree) values1′, values2′ = set1.values, set2.values values1, values2 = non_overlapping_values(values1′, values2′, name_tree) - values2_minus_values1 = unrolled_filter(values2) do value - !unrolled_any(isequal(value), values1) - end - result_values = (values1..., values2_minus_values1...) - return FieldNameSet{T}(result_values, name_tree) + return FieldNameSet{T}(unrolled_intersect(values1, values2), name_tree) end function Base.setdiff(set1::FieldNameSet{T}, set2::FieldNameSet{T}) where {T} name_tree = combine_name_trees(set1.name_tree, set2.name_tree) - set2_complement_values = set_complement_values(T, set2.values, name_tree) - set2_complement = FieldNameSet{T}(set2_complement_values, name_tree) - return intersect(set1, set2_complement) + values1′, values2′ = set1.values, set2.values + values1, values2 = non_overlapping_values(values1′, values2′, name_tree) + return FieldNameSet{T}(unrolled_setdiff(values1, values2), name_tree) end set_string(set) = length(set) == 2 ? join(set.values, " and ") : join(set.values, ", ", ", and ") +set_complement(set) = setdiff(universal_set(eltype(set)), set) + is_subset_that_covers_set(set1, set2) = issubset(set1, set2) && isempty(setdiff(set2, set1)) -function set_complement(set::FieldNameSet{T}) where {T} - result_values = set_complement_values(T, set.values, set.name_tree) - return FieldNameSet{T}(result_values, set.name_tree) -end - function corresponding_matrix_keys(set::FieldVectorKeys) result_values = unrolled_map(name -> (name, name), set.values) return FieldMatrixKeys(result_values, set.name_tree) @@ -150,9 +140,13 @@ function matrix_diagonal_keys(set::FieldMatrixKeys) names_are_overlapping(name_pair[1], name_pair[2]) end result_values = unrolled_map(result_values′) do name_pair - name_pair[1] == name_pair[2] ? name_pair : - is_child_value(name_pair[1], name_pair[2]) ? - (name_pair[1], name_pair[1]) : (name_pair[2], name_pair[2]) + if name_pair[1] == name_pair[2] + name_pair + elseif is_child_value(name_pair[1], name_pair[2]) + (name_pair[1], name_pair[1]) + else + (name_pair[2], name_pair[2]) + end end return FieldMatrixKeys(result_values, set.name_tree) end @@ -302,7 +296,7 @@ is_child_value(name_pair1::FieldNamePair, name_pair2::FieldNamePair) = is_child_name(name_pair1[2], name_pair2[2]) is_value_in_set(value, values, name_tree) = - if unrolled_any(isequal(value), values) + if unrolled_in(value, values) true elseif unrolled_any(value′ -> is_child_value(value, value′), values) isnothing(name_tree) && error( @@ -313,6 +307,11 @@ is_value_in_set(value, values, name_tree) = false end +universal_set(::Type{FieldName}) = FieldVectorKeys((@name(),)) +universal_set(::Type{FieldNamePair}) = FieldMatrixKeys(((@name(), @name()),)) + +# TODO: Simplify the following code. + function non_overlapping_values(values1, values2, name_tree) new_values1 = unrolled_mapflatten(values1) do value value_or_non_overlapping_children(value, values2, name_tree) @@ -381,62 +380,13 @@ function value_or_non_overlapping_children( end end -set_complement_values(_, _, ::Nothing) = - error("Cannot compute complement of a FieldNameSet without a FieldNameTree") -set_complement_values(::Type{<:FieldName}, names, name_tree::FieldNameTree) = - complement_values_in_subtree(names, name_tree) -set_complement_values( - ::Type{<:FieldNamePair}, - name_pairs, - name_tree::FieldNameTree, -) = complement_values_in_subtree_pair(name_pairs, (name_tree, name_tree)) - -function complement_values_in_subtree(names, subtree) - name = subtree.name - unrolled_all(name′ -> !is_child_value(name, name′), names) || return () - unrolled_any(name′ -> is_child_value(name′, name), names) || return (name,) - return unrolled_mapflatten(subtree.subtrees) do subsubtree - complement_values_in_subtree(names, subsubtree) - end -end - -function complement_values_in_subtree_pair(name_pairs, subtree_pair) - name_pair = (subtree_pair[1].name, subtree_pair[2].name) - is_name_pair_in_complement = unrolled_all(name_pairs) do name_pair′ - !is_child_value(name_pair, name_pair′) - end - is_name_pair_in_complement || return () - need_row_subsubtrees = unrolled_any(name_pairs) do name_pair′ - is_child_value(name_pair′, name_pair) && name_pair′[1] != name_pair[1] - end - need_col_subsubtrees = unrolled_any(name_pairs) do name_pair′ - is_child_value(name_pair′, name_pair) && name_pair′[2] != name_pair[2] - end - need_row_subsubtrees || need_col_subsubtrees || return (name_pair,) - row_subsubtrees = - need_row_subsubtrees ? subtree_pair[1].subtrees : (subtree_pair[1],) - col_subsubtrees = - need_col_subsubtrees ? subtree_pair[2].subtrees : (subtree_pair[2],) - return unrolled_mapflatten(row_subsubtrees) do row_subsubtree - unrolled_mapflatten(col_subsubtrees) do col_subsubtree - subsubtree_pair = (row_subsubtree, col_subsubtree) - complement_values_in_subtree_pair(name_pairs, subsubtree_pair) - end - end -end - -################################################################################ - # This is required for type-stability as of Julia 1.9. if hasfield(Method, :recursion_relation) dont_limit = (args...) -> true - for m in methods(value_or_non_overlapping_children) - m.recursion_relation = dont_limit - end - for m in methods(complement_values_in_subtree) + for m in methods(unrolled_mapflatten) m.recursion_relation = dont_limit end - for m in methods(complement_values_in_subtree_pair) + for m in methods(value_or_non_overlapping_children) m.recursion_relation = dont_limit end end diff --git a/src/MatrixFields/unrolled_functions.jl b/src/MatrixFields/unrolled_functions.jl index 947a45084d..533aaf7785 100644 --- a/src/MatrixFields/unrolled_functions.jl +++ b/src/MatrixFields/unrolled_functions.jl @@ -1,70 +1,17 @@ -@inline unrolled_zip(values1, values2) = - isempty(values1) || isempty(values2) ? () : - ( - (first(values1), first(values2)), - unrolled_zip(Base.tail(values1), Base.tail(values2))..., - ) - -@inline unrolled_map(f::F, values) where {F} = - isempty(values) ? () : - (f(first(values)), unrolled_map(f, Base.tail(values))...) - -unrolled_foldl(f::F, values) where {F} = - isempty(values) ? - error("unrolled_foldl requires init for an empty collection of values") : - _unrolled_foldl(f, first(values), Base.tail(values)) -unrolled_foldl(f::F, values, init) where {F} = _unrolled_foldl(f, init, values) -@inline _unrolled_foldl(f::F, result, values) where {F} = - isempty(values) ? result : - _unrolled_foldl(f, f(result, first(values)), Base.tail(values)) - -# The @inline annotations are needed to avoid allocations when there are a lot -# of values. - -# Using first and tail instead of [1] and [2:end] restricts us to Tuples, but it -# also results in less compilation time. - -# This is required to make the unrolled functions type-stable, as of Julia 1.9. -if hasfield(Method, :recursion_relation) - dont_limit = (args...) -> true - for m in methods(unrolled_zip) - m.recursion_relation = dont_limit - end - for m in methods(unrolled_map) - m.recursion_relation = dont_limit - end - for m in methods(_unrolled_foldl) - m.recursion_relation = dont_limit - end -end - -################################################################################ - -unrolled_foreach(f::F, values) where {F} = (unrolled_map(f, values); nothing) - -unrolled_any(f::F, values) where {F} = - unrolled_foldl(|, unrolled_map(f, values), false) - -unrolled_all(f::F, values) where {F} = - unrolled_foldl(&, unrolled_map(f, values), true) - -unrolled_filter(f::F, values) where {F} = - unrolled_foldl(values, ()) do filtered_values, value - f(value) ? (filtered_values..., value) : filtered_values - end +# These functions are also defined in Unrolled.jl, but those versions use "in" +# instead of "unrolled_in". +unrolled_union(values1, values2) = + (values1..., unrolled_setdiff(values2, values1)...) +unrolled_intersect(values1, values2) = + unrolled_filter(x -> unrolled_in(x, values2), values1) +unrolled_setdiff(values1, values2) = + unrolled_filter(x -> !unrolled_in(x, values2), values1) unrolled_unique(values) = - unrolled_foldl(values, ()) do unique_values, value - unrolled_any(isequal(value), unique_values) ? unique_values : - (unique_values..., value) - end + unrolled_reduce(unrolled_union, (), unrolled_map(tuple, values)) unrolled_flatten(values) = - unrolled_foldl(values, ()) do flattened_values, value - (flattened_values..., value...) - end - -# Non-standard functions: + unrolled_reduce((tuple1, tuple2) -> (tuple1..., tuple2...), (), values) unrolled_mapflatten(f::F, values) where {F} = unrolled_flatten(unrolled_map(f, values)) @@ -73,19 +20,12 @@ function unrolled_findonly(f::F, values) where {F} filtered_values = unrolled_filter(f, values) length(filtered_values) == 1 || error("unrolled_findonly requires that exactly one value makes f true") - return first(filtered_values) + return filtered_values[1] end -# This is required to make functions defined elsewhere type-stable, as of Julia -# 1.9. Specifically, if an unrolled function is used to implement the recursion -# of another function, it needs to have its recursion limit disabled in order -# for that other function to be type-stable. -if hasfield(Method, :recursion_relation) - dont_limit = (args...) -> true - for m in methods(unrolled_any) - m.recursion_relation = dont_limit - end # for is_valid_name - for m in methods(unrolled_mapflatten) - m.recursion_relation = dont_limit - end # for complement_values_in_subtree and value_or_non_overlapping_children -end +# The way unrolled_reduce is defined in Unrolled.jl essentially makes it a +# backwards unrolled_foldl. +unrolled_foldl(f::F, values) where {F} = + isempty(values) ? + error("unrolled_foldl requires a nonempty collection of values") : + unrolled_reduce((x, y) -> f(y, x), values[1], values[2:end]) diff --git a/test/MatrixFields/field_names.jl b/test/MatrixFields/field_names.jl index 6834a9eef1..7724951e07 100644 --- a/test/MatrixFields/field_names.jl +++ b/test/MatrixFields/field_names.jl @@ -94,11 +94,11 @@ end (@name(a.b), @name(a.c)) @test_all MatrixFields.child_names(@name(a.c), name_tree) == (@name(a.c.:(1)), @name(a.c.:(2)), @name(a.c.:(3))) - @test_throws "does not contain any child names" MatrixFields.child_names( + @test_throws "does not have child names" MatrixFields.child_names( @name(a.c.:(2)), name_tree, ) - @test_throws "does not contain the name" MatrixFields.child_names( + @test_throws "is not a valid name" MatrixFields.child_names( @name(foo.invalid_name), name_tree, ) @@ -244,9 +244,11 @@ end (v_set1, v_set2), (v_set1, v_set2_no_tree), (v_set1_no_tree, v_set2), + (v_set1_no_tree, v_set2_no_tree), (m_set1, m_set2), (m_set1, m_set2_no_tree), (m_set1_no_tree, m_set2), + (m_set1_no_tree, m_set2_no_tree), ) @test_all set1 != set2 @test_all !issubset(set1, set2) @@ -257,20 +259,6 @@ end @test_all !MatrixFields.is_subset_that_covers_set(set2, set1) end - for (set1, set2) in - ((v_set1_no_tree, v_set2_no_tree), (m_set1_no_tree, m_set2_no_tree)) - @test_all set1 != set2 - @test_all !issubset(set1, set2) - @test_all issubset(set2, set1) - @test_all intersect(set1, set2) == set2 - @test_all union(set1, set2) == set1 - @test_all !MatrixFields.is_subset_that_covers_set(set1, set2) - @test_throws "FieldNameTree" MatrixFields.is_subset_that_covers_set( - set2, - set1, - ) - end - for (set1, set3) in ( (v_set1, v_set3), (v_set1, v_set3_no_tree),