diff --git a/src/MatrixFields/field_name_dict.jl b/src/MatrixFields/field_name_dict.jl index 4a31cc56ac..f540f1a0e9 100644 --- a/src/MatrixFields/field_name_dict.jl +++ b/src/MatrixFields/field_name_dict.jl @@ -248,7 +248,7 @@ Base.Broadcast.broadcasted( arg3, args..., ) = - unrolled_foldl((arg1, arg2, arg3, args...)) do arg1′, arg2′ + foldl((arg1, arg2, arg3, args...)) do arg1′, arg2′ Base.Broadcast.broadcasted(f, arg1′, arg2′) end diff --git a/src/MatrixFields/field_name_set.jl b/src/MatrixFields/field_name_set.jl index 3dd16e52de..76841e9d80 100644 --- a/src/MatrixFields/field_name_set.jl +++ b/src/MatrixFields/field_name_set.jl @@ -106,7 +106,7 @@ end set_string(set) = values_string(set.values) -set_complement(set) = setdiff(universal_set(eltype(set)), set) +set_complement(set) = setdiff(universal_set(set.name_tree, eltype(set)), set) is_subset_that_covers_set(set1, set2) = issubset(set1, set2) && isempty(setdiff(set2, set1)) @@ -116,9 +116,11 @@ function corresponding_matrix_keys(set::FieldVectorKeys) return FieldMatrixKeys(result_values, set.name_tree) end -function cartesian_product(set1::FieldVectorKeys, set2::FieldVectorKeys) - name_tree = combine_name_trees(set1.name_tree, set2.name_tree) - result_values = unrolled_product((set1.values, set2.values)) +function cartesian_product(row_set::FieldVectorKeys, col_set::FieldVectorKeys) + name_tree = combine_name_trees(row_set.name_tree, col_set.name_tree) + result_values = unrolled_flatmap(row_set.values) do row_name + unrolled_map(col_name -> (row_name, col_name), col_set.values) + end return FieldMatrixKeys(result_values, name_tree) end @@ -278,8 +280,16 @@ combine_name_trees(name_tree1, name_tree2) = error("Mismatched FieldNameTrees: The ability to combine different \ FieldNameTrees has not been implemented") -universal_set(::Type{FieldName}) = FieldVectorKeys((@name(),)) -universal_set(::Type{FieldNamePair}) = FieldMatrixKeys(((@name(), @name()),)) +universal_set(::Nothing, ::Type{FieldName}) = error( + "Missing FieldNameTree: Cannot compute complement of FieldNameSet without \ + a FieldNameTree", +) +universal_set(name_tree, ::Type{FieldName}) = + FieldVectorKeys(child_names(@name(), name_tree), name_tree) +function universal_set(name_tree, ::Type{FieldNamePair}) + row_set = universal_set(name_tree, FieldName) + return cartesian_product(row_set, row_set) +end is_valid_value(name::FieldName, name_tree) = is_valid_name(name, name_tree) is_valid_value(name_pair::FieldNamePair, name_tree) = @@ -329,49 +339,32 @@ function union_values(values, new_values, name_tree) and $(values_string(overlapping_values)) without a FieldNameTree", ) - overlapping_values_that_are_children_of_value, other_overlapping_values = + overlapping_children_of_new_value, other_overlapping_values = unrolled_split(overlapping_values) do value is_child_value(value, new_value) end - possible_union_values = if isempty(other_overlapping_values) - possible_children_of_value = available_sets_of_child_values( - new_value, - overlapping_values, - name_tree, - ) - recursively_unrolled_map( - possible_children_of_value, - ) do children_of_value - union_values( - values, - (children_of_value..., more_new_values...), - name_tree, - ) - end + return if isempty(other_overlapping_values) + children_of_new_value = + expand_along_row_or_column(new_value, overlapping_values, name_tree) + expanded_new_values = (children_of_new_value..., more_new_values...) + union_values(values, expanded_new_values, name_tree) else - possible_children_of_other_overlapping_values = unrolled_map( - unrolled_flatten, - unrolled_prodmap(other_overlapping_values) do value - available_sets_of_child_values(value, (new_value,), name_tree) - end, + children_of_other_overlapping_values = + unrolled_flatmap(other_overlapping_values) do value + expand_along_row_or_column(value, (new_value,), name_tree) + end + expanded_values = ( + non_overlapping_values..., + overlapping_children_of_new_value..., + children_of_other_overlapping_values..., ) - recursively_unrolled_map( - possible_children_of_other_overlapping_values, - ) do children_of_other_overlapping_values - values_and_children_of_values = ( - non_overlapping_values..., - overlapping_values_that_are_children_of_value..., - children_of_other_overlapping_values..., - ) - union_values(values_and_children_of_values, new_values, name_tree) - end + union_values(expanded_values, new_values, name_tree) end - return unrolled_argmin(length, possible_union_values) end -available_sets_of_child_values(name::FieldName, _, name_tree) = - (child_names(name, name_tree),) -function available_sets_of_child_values( +expand_along_row_or_column(name::FieldName, _, name_tree) = + child_names(name, name_tree) +function expand_along_row_or_column( name_pair::FieldNamePair, overlapping_name_pairs, name_tree, @@ -397,19 +390,19 @@ function available_sets_of_child_values( n_col_children = col_children_needed ? length(col_children) : 0 # We are guaranteed that either n_row_children > 0 or n_col_children > 0. return if n_row_children == n_col_children == 1 - only_child_of_name_pair = (row_children[1][1], col_children[1][2]) - ((only_child_of_name_pair,),) + ((row_children[1][1], col_children[1][2]),) elseif n_row_children > 0 && n_col_children <= 1 - (row_children,) + row_children elseif n_col_children > 0 && n_row_children <= 1 - (col_children,) - else - @assert n_row_children > 1 && n_col_children > 1 - # If multiple row and column children are available, we might get - # different results depending on whether we expand the row or the column - # first, so we need to try both and pick the one that corresponds to the - # result with the shortest length. - (row_children, col_children) + col_children + else # n_row_children > 1 && n_col_children > 1 + # If multiple row and column children are needed, the result of + # union_values only depends on whether we expand the row or the column + # first when name_pair == (@name(), @name()). In this function, we make + # the arbitrary choice to expand the row first. Since we do not expect + # users to construct matrices that contain (@name(), @name()), this + # choice should not have any noticeable effect. + row_children end end diff --git a/src/MatrixFields/unrolled_functions.jl b/src/MatrixFields/unrolled_functions.jl index 5442856008..9f01989168 100644 --- a/src/MatrixFields/unrolled_functions.jl +++ b/src/MatrixFields/unrolled_functions.jl @@ -1,8 +1,13 @@ -# The following functions are extensions to the generated functions defined in -# Unrolled.jl. +# The following functions build upon the generated functions in Unrolled.jl. + +unrolled_flatten(values) = + unrolled_reduce((tuple1, tuple2) -> (tuple1..., tuple2...), (), values) + +unrolled_flatmap(f::F, values) where {F} = + unrolled_flatten(unrolled_map(f, values)) unrolled_split(f::F, values) where {F} = - unrolled_filter(f, values), unrolled_filter(!f, values) + (unrolled_filter(f, values), unrolled_filter(!f, values)) function unrolled_findonly(f::F, values) where {F} filtered_values = unrolled_filter(f, values) @@ -10,49 +15,6 @@ function unrolled_findonly(f::F, values) where {F} error("unrolled_findonly requires that exactly 1 value makes f true") end -# The implementation of unrolled_reduce in Unrolled.jl makes it roughly the same -# as foldl, but with the order of arguments in every call to f flipped. -unrolled_foldl(f::F, values; init = nothing) where {F} = - if isnothing(init) - isempty(values) ? - error("unrolled_foldl requires init for empty collections of values") : - unrolled_reduce((x, y) -> f(y, x), values[1], values[2:end]) - else - unrolled_reduce((x, y) -> f(y, x), init, values) - end - -function unrolled_argmin(f::F, values) where {F} - values_and_fs = unrolled_map(value -> (value, f(value)), values) - min_value_and_f = - unrolled_foldl(values_and_fs) do (min_value, min_f), (new_value, new_f) - new_f < min_f ? (new_value, new_f) : (min_value, min_f) - end - return min_value_and_f[1] -end - -# This needs to use unrolled_reduce instead of unrolled_foldl in order for -# unrolled_product to be type-stable (otherwise, calling unrolled_flatmap inside -# of unrolled_foldl causes the compiler to hit a recursion limit). -unrolled_flatten(values) = - unrolled_reduce((), values) do value, flattened_values - (flattened_values..., value...) - end - -unrolled_flatmap(f::F, values) where {F} = - unrolled_flatten(unrolled_map(f, values)) - -unrolled_product(values) = - unrolled_foldl(values; init = ((),)) do product_values, value - unrolled_flatmap(product_values) do sub_values - unrolled_map(value) do sub_value - (sub_values..., sub_value) - end - end - end - -unrolled_prodmap(f::F, values) where {F} = - unrolled_product(unrolled_map(f, values)) - # The following functions are recursion-based alternatives to the generated # functions in Unrolled.jl. These should be used instead of their generated # counterparts when implementing recursion in other functions. For example, if a @@ -64,8 +26,9 @@ unrolled_prodmap(f::F, values) where {F} = isempty(values) ? () : (f(values[1]), recursively_unrolled_map(f, values[2:end])...) -recursively_unrolled_any(f::F, values) where {F} = - unrolled_any(identity, recursively_unrolled_map(f, values)) +@inline recursively_unrolled_any(f::F, values) where {F} = + isempty(values) ? false : + f(values[1]) || recursively_unrolled_any(f, values[2:end]) # This is required for type-stability as of Julia 1.9. if hasfield(Method, :recursion_relation) diff --git a/test/MatrixFields/field_names.jl b/test/MatrixFields/field_names.jl index 46b56ee069..c18ef0bbbb 100644 --- a/test/MatrixFields/field_names.jl +++ b/test/MatrixFields/field_names.jl @@ -75,9 +75,9 @@ const x = (; foo = Foo(0), a = (; b = 1, c = ((; d = 2), (;), ((), nothing)))) (@name(1), @name(2), @name(3)) end -@testset "FieldNameTree Unit Tests" begin - name_tree = MatrixFields.FieldNameTree(x) +const name_tree = MatrixFields.FieldNameTree(x) +@testset "FieldNameTree Unit Tests" begin @test_all MatrixFields.FieldNameTree(x) == name_tree @test_all MatrixFields.is_valid_name(@name(), name_tree) @@ -104,16 +104,18 @@ end ) end -@testset "FieldNameSet Unit Tests" begin - name_tree = MatrixFields.FieldNameTree(x) - vector_keys(names...) = MatrixFields.FieldVectorKeys(names, name_tree) - matrix_keys(name_pairs...) = - MatrixFields.FieldMatrixKeys(name_pairs, name_tree) +vector_keys(names...) = MatrixFields.FieldVectorKeys(names, name_tree) +matrix_keys(name_pairs...) = MatrixFields.FieldMatrixKeys(name_pairs, name_tree) - vector_keys_no_tree(names...) = MatrixFields.FieldVectorKeys(names) - matrix_keys_no_tree(name_pairs...) = - MatrixFields.FieldMatrixKeys(name_pairs) +vector_keys_no_tree(names...) = MatrixFields.FieldVectorKeys(names) +matrix_keys_no_tree(name_pairs...) = MatrixFields.FieldMatrixKeys(name_pairs) +drop_tree(set::MatrixFields.FieldVectorKeys) = + MatrixFields.FieldVectorKeys(set.values) +drop_tree(set::MatrixFields.FieldMatrixKeys) = + MatrixFields.FieldMatrixKeys(set.values) + +@testset "FieldNameSet Unit Tests" begin @testset "FieldNameSet Constructors" begin @test_throws "Invalid FieldNameSet value" vector_keys( @name(foo.invalid_name), @@ -146,16 +148,11 @@ end end v_set1 = vector_keys(@name(foo), @name(a.c)) - v_set1_no_tree = vector_keys_no_tree(@name(foo), @name(a.c)) m_set1 = matrix_keys((@name(foo), @name(a.c)), (@name(a.b), @name(foo))) - m_set1_no_tree = - matrix_keys_no_tree((@name(foo), @name(a.c)), (@name(a.b), @name(foo))) # Proper subsets of v_set1 and m_set1. v_set2 = vector_keys(@name(foo)) - v_set2_no_tree = vector_keys_no_tree(@name(foo)) m_set2 = matrix_keys((@name(foo), @name(a.c))) - m_set2_no_tree = matrix_keys_no_tree((@name(foo), @name(a.c))) # Subsets that cover v_set1 and m_set1. v_set3 = vector_keys( @@ -164,32 +161,18 @@ end @name(a.c.:(2)), @name(a.c.:(3)), ) - v_set3_no_tree = vector_keys_no_tree( - @name(foo.value), - @name(a.c.:(1)), - @name(a.c.:(2)), - @name(a.c.:(3)), - ) m_set3 = matrix_keys( (@name(foo.value), @name(a.c.:(1))), (@name(foo), @name(a.c.:(2))), (@name(foo), @name(a.c.:(3))), (@name(a.b), @name(foo.value)), ) - m_set3_no_tree = matrix_keys_no_tree( - (@name(foo.value), @name(a.c.:(1))), - (@name(foo), @name(a.c.:(2))), - (@name(foo), @name(a.c.:(3))), - (@name(a.b), @name(foo.value)), - ) # Sets that overlap with v_set1 and m_set1, but are neither subsets nor # supersets of those sets. Some of the values in m_set4 overlap with # those in m_set1, but are neither children nor parents of those values # (this is only possible with matrix keys). v_set4 = vector_keys(@name(a.b), @name(a.c.:(1)), @name(a.c.:(2))) - v_set4_no_tree = - vector_keys_no_tree(@name(a.b), @name(a.c.:(1)), @name(a.c.:(2))) m_set4 = matrix_keys( (@name(), @name(a.c.:(1))), (@name(foo.value), @name(foo)), @@ -197,22 +180,15 @@ end (@name(a), @name(foo.value)), (@name(a.c.:(3)), @name(a.c.:(3))), ) - m_set4_no_tree = matrix_keys_no_tree( - (@name(), @name(a.c.:(1))), - (@name(foo.value), @name(foo)), - (@name(foo.value), @name(a.c.:(2))), - (@name(a), @name(foo.value)), - (@name(a.c.:(3)), @name(a.c.:(3))), - ) @testset "FieldNameSet Basic Operations" begin @test string(v_set1) == "FieldVectorKeys(@name(foo), @name(a.c); )" - @test string(v_set1_no_tree) == + @test string(drop_tree(v_set1)) == "FieldVectorKeys(@name(foo), @name(a.c))" @test string(m_set1) == "FieldMatrixKeys((@name(foo), @name(a.c)), \ (@name(a.b), @name(foo)); )" - @test string(m_set1_no_tree) == "FieldMatrixKeys((@name(foo), \ + @test string(drop_tree(m_set1)) == "FieldMatrixKeys((@name(foo), \ @name(a.c)), (@name(a.b), @name(foo)))" @test_all map(name -> (name, name), v_set1) == @@ -223,12 +199,12 @@ end @test_all isnothing(foreach(name -> (name, name), v_set1)) @test_all isnothing(foreach(name_pair -> name_pair[1], m_set1)) - for set in (v_set1, v_set1_no_tree) + for set in (v_set1, drop_tree(v_set1)) @test_all @name(foo) in set @test_all !(@name(a.b) in set) @test_all !(@name(invalid_name) in set) end - for set in (m_set1, m_set1_no_tree) + for set in (m_set1, drop_tree(m_set1)) @test_all (@name(foo), @name(a.c)) in set @test_all !((@name(foo), @name(a.b)) in set) @test_all !((@name(foo), @name(invalid_name)) in set) @@ -236,62 +212,71 @@ end @test_all @name(foo.value) in v_set1 @test_all !(@name(foo.invalid_name) in v_set1) - @test_throws "FieldNameTree" @name(foo.value) in v_set1_no_tree - @test_throws "FieldNameTree" @name(foo.invalid_name) in v_set1_no_tree + @test_throws "FieldNameTree" @name(foo.value) in drop_tree(v_set1) + @test_throws "FieldNameTree" @name(foo.invalid_name) in + drop_tree(v_set1) @test_all (@name(foo.value), @name(a.c)) in m_set1 @test_all !((@name(foo.invalid_name), @name(a.c)) in m_set1) @test_throws "FieldNameTree" (@name(foo.value), @name(a.c)) in - m_set1_no_tree + drop_tree(m_set1) @test_throws "FieldNameTree" (@name(foo.invalid_name), @name(a.c)) in - m_set1_no_tree + drop_tree(m_set1) + end + @testset "FieldNameSet Complement Sets" begin @test_all MatrixFields.set_complement(v_set1) == vector_keys_no_tree(@name(a.b)) - @test_throws "FieldNameTree" MatrixFields.set_complement(v_set1_no_tree) + @test_all MatrixFields.set_complement(v_set2) == + vector_keys_no_tree(@name(a)) + @test_all MatrixFields.set_complement(v_set3) == + vector_keys_no_tree(@name(a.b)) + @test_all MatrixFields.set_complement(v_set4) == + vector_keys_no_tree(@name(foo), @name(a.c.:(3))) @test_all MatrixFields.set_complement(m_set1) == matrix_keys_no_tree( (@name(foo), @name(foo)), (@name(foo), @name(a.b)), - (@name(a.c), @name()), - (@name(a.b), @name(a)), + (@name(a), @name(a)), + (@name(a.c), @name(foo)), + ) + @test_all MatrixFields.set_complement(m_set2) == matrix_keys_no_tree( + (@name(foo), @name(foo)), + (@name(foo), @name(a.b)), + (@name(a), @name(foo)), + (@name(a), @name(a)), + ) + @test_all MatrixFields.set_complement(m_set3) == matrix_keys_no_tree( + (@name(foo), @name(foo)), + (@name(foo), @name(a.b)), + (@name(a), @name(a)), + (@name(a.c), @name(foo)), ) - @test_throws "FieldNameTree" MatrixFields.set_complement(m_set1_no_tree) - - @test_all MatrixFields.set_complement(v_set4) == - vector_keys_no_tree(@name(foo), @name(a.c.:(3))) - @test_throws "FieldNameTree" MatrixFields.set_complement(v_set4_no_tree) - @test_all MatrixFields.set_complement(m_set4) == matrix_keys_no_tree( - (@name(), @name(a.b)), + (@name(foo), @name(a.b)), (@name(foo), @name(a.c.:(3))), + (@name(a), @name(a.b)), (@name(a), @name(a.c.:(2))), (@name(a.b), @name(a.c.:(3))), (@name(a.c.:(1)), @name(a.c.:(3))), (@name(a.c.:(2)), @name(a.c.:(3))), ) - @test_throws "FieldNameTree" MatrixFields.set_complement(m_set4_no_tree) - end - @testset "FieldNameSet Binary Set Operations" begin - for set1 in (v_set1, v_set1_no_tree, m_set1, m_set1_no_tree) - @test_all set1 == set1 - @test_all issubset(set1, set1) - @test_all is_subset_that_covers_set(set1, set1) - @test_all intersect(set1, set1) == set1 - @test_all union(set1, set1) == set1 - @test_all isempty(setdiff(set1, set1)) + for set in (drop_tree(v_set1), drop_tree(m_set1)) + @test_throws "FieldNameTree" MatrixFields.set_complement(set) end + end + @testset "FieldNameSet Binary Set Operations" begin for (set1, set2) in ( (v_set1, v_set2), - (v_set1, v_set2_no_tree), - (v_set1_no_tree, v_set2), - (v_set1_no_tree, v_set2_no_tree), + (v_set1, drop_tree(v_set2)), + (drop_tree(v_set1), v_set2), + (drop_tree(v_set1), drop_tree(v_set2)), (m_set1, m_set2), - (m_set1, m_set2_no_tree), - (m_set1_no_tree, m_set2), - (m_set1_no_tree, m_set2_no_tree), + (m_set1, drop_tree(m_set2)), + (drop_tree(m_set1), m_set2), + (drop_tree(m_set1), drop_tree(m_set2)), ) @test_all set1 != set2 @test_all !issubset(set1, set2) && issubset(set2, set1) @@ -310,11 +295,11 @@ end for (set1, set3) in ( (v_set1, v_set3), - (v_set1, v_set3_no_tree), - (v_set1_no_tree, v_set3), + (v_set1, drop_tree(v_set3)), + (drop_tree(v_set1), v_set3), (m_set1, m_set3), - (m_set1, m_set3_no_tree), - (m_set1_no_tree, m_set3), + (m_set1, drop_tree(m_set3)), + (drop_tree(m_set1), m_set3), ) @test_all set1 != set3 @test_all !issubset(set1, set3) && issubset(set3, set1) @@ -326,8 +311,10 @@ end isempty(setdiff(set3, set1)) end - for (set1, set3) in - ((v_set1_no_tree, v_set3_no_tree), (m_set1_no_tree, m_set3_no_tree)) + for (set1, set3) in ( + (drop_tree(v_set1), drop_tree(v_set3)), + (drop_tree(m_set1), drop_tree(m_set3)), + ) @test_all set1 != set3 @test_all !issubset(set1, set3) @test_all !is_subset_that_covers_set(set1, set3) @@ -343,11 +330,11 @@ end for (set1, set4) in ( (v_set1, v_set4), - (v_set1, v_set4_no_tree), - (v_set1_no_tree, v_set4), + (v_set1, drop_tree(v_set4)), + (drop_tree(v_set1), v_set4), (m_set1, m_set4), - (m_set1, m_set4_no_tree), - (m_set1_no_tree, m_set4), + (m_set1, drop_tree(m_set4)), + (drop_tree(m_set1), m_set4), ) @test_all set1 != set4 @test_all !issubset(set1, set4) && !issubset(set4, set1) @@ -400,8 +387,10 @@ end end end - for (set1, set4) in - ((v_set1_no_tree, v_set4_no_tree), (m_set1_no_tree, m_set4_no_tree)) + for (set1, set4) in ( + (drop_tree(v_set1), drop_tree(v_set4)), + (drop_tree(m_set1), drop_tree(m_set4)), + ) @test_all set1 != set4 @test_all !issubset(set1, set4) && !issubset(set4, set1) @test_all !is_subset_that_covers_set(set1, set4) && @@ -607,15 +596,15 @@ end # With one exception, none of the following operations require a # FieldNameTree. - @test_all MatrixFields.corresponding_matrix_keys(v_set1_no_tree) == + @test_all MatrixFields.corresponding_matrix_keys(drop_tree(v_set1)) == matrix_keys_no_tree( (@name(foo), @name(foo)), (@name(a.c), @name(a.c)), ) @test_all MatrixFields.cartesian_product( - v_set1_no_tree, - v_set4_no_tree, + drop_tree(v_set1), + drop_tree(v_set4), ) == matrix_keys_no_tree( (@name(foo), @name(a.b)), (@name(foo), @name(a.c.:(1))), @@ -625,7 +614,7 @@ end (@name(a.c), @name(a.c.:(2))), ) - @test_all MatrixFields.matrix_row_keys(m_set1_no_tree) == + @test_all MatrixFields.matrix_row_keys(drop_tree(m_set1)) == vector_keys_no_tree(@name(foo), @name(a.b)) @test_all MatrixFields.matrix_row_keys(m_set4) == vector_keys_no_tree( @@ -636,10 +625,10 @@ end @name(a.c.:(3)) ) @test_throws "FieldNameTree" MatrixFields.matrix_row_keys( - m_set4_no_tree, + drop_tree(m_set4), ) - @test_all MatrixFields.matrix_off_diagonal_keys(m_set4_no_tree) == + @test_all MatrixFields.matrix_off_diagonal_keys(drop_tree(m_set4)) == matrix_keys_no_tree( (@name(), @name(a.c.:(1))), (@name(foo.value), @name(foo)), @@ -647,7 +636,7 @@ end (@name(a), @name(foo.value)), ) - @test_all MatrixFields.matrix_diagonal_keys(m_set4_no_tree) == + @test_all MatrixFields.matrix_diagonal_keys(drop_tree(m_set4)) == matrix_keys_no_tree( (@name(foo.value), @name(foo.value)), (@name(a.c.:(1)), @name(a.c.:(1))),