Skip to content

Commit

Permalink
Simplify set complements
Browse files Browse the repository at this point in the history
  • Loading branch information
dennisYatunin committed Nov 1, 2023
1 parent ddd1d8a commit bf6c14a
Show file tree
Hide file tree
Showing 5 changed files with 236 additions and 240 deletions.
8 changes: 5 additions & 3 deletions src/MatrixFields/field_name.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ is_child_name(
::FieldName{parent_name_chain},
) where {child_name_chain, parent_name_chain} =
length(child_name_chain) >= length(parent_name_chain) &&
child_name_chain[1:length(parent_name_chain)] == parent_name_chain
unrolled_take(child_name_chain, Val(length(parent_name_chain))) ==
parent_name_chain

is_overlapping_name(name1, name2) =
is_child_name(name1, name2) || is_child_name(name2, name1)
Expand All @@ -83,8 +84,9 @@ extract_internal_name(
parent_name::FieldName{parent_name_chain},
) where {child_name_chain, parent_name_chain} =
is_child_name(child_name, parent_name) ?
FieldName(child_name_chain[(length(parent_name_chain) + 1):end]...) :
error("$child_name is not a child name of $parent_name")
FieldName(
unrolled_drop(child_name_chain, Val(length(parent_name_chain)))...,
) : error("$child_name is not a child name of $parent_name")

append_internal_name(
::FieldName{name_chain},
Expand Down
2 changes: 1 addition & 1 deletion src/MatrixFields/field_name_dict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ Base.Broadcast.broadcasted(
arg3,
args...,
) =
unrolled_foldl((arg1, arg2, arg3, args...)) do arg1′, arg2′
foldl((arg1, arg2, arg3, args...)) do arg1′, arg2′
Base.Broadcast.broadcasted(f, arg1′, arg2′)
end

Expand Down
221 changes: 121 additions & 100 deletions src/MatrixFields/field_name_set.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ end

set_string(set) = values_string(set.values)

set_complement(set) = setdiff(universal_set(eltype(set)), set)
set_complement(set) = setdiff(universal_set(set.name_tree, eltype(set)), set)

is_subset_that_covers_set(set1, set2) =
issubset(set1, set2) && isempty(setdiff(set2, set1))
Expand All @@ -116,9 +116,9 @@ function corresponding_matrix_keys(set::FieldVectorKeys)
return FieldMatrixKeys(result_values, set.name_tree)
end

function cartesian_product(set1::FieldVectorKeys, set2::FieldVectorKeys)
name_tree = combine_name_trees(set1.name_tree, set2.name_tree)
result_values = unrolled_product((set1.values, set2.values))
function cartesian_product(row_set::FieldVectorKeys, col_set::FieldVectorKeys)
name_tree = combine_name_trees(row_set.name_tree, col_set.name_tree)
result_values = unrolled_product(row_set.values, col_set.values)
return FieldMatrixKeys(result_values, name_tree)
end

Expand Down Expand Up @@ -253,10 +253,10 @@ check_values(values, name_tree) =
(isnothing(name_tree) || is_valid_value(value, name_tree)) || error(
"Invalid FieldNameSet value: $value is incompatible with name_tree",
)
duplicate_values = unrolled_filter(isequal(value), values)
length(duplicate_values) == 1 || error(
"Duplicate FieldNameSet values: $(length(duplicate_values)) copies \
of $value have been passed to a FieldNameSet constructor",
num_duplicate_values = length(unrolled_filter(isequal(value), values))
num_duplicate_values == 1 || error(
"Duplicate FieldNameSet values: $num_duplicate_values copies of \
$value have been passed to a FieldNameSet constructor",
)
overlapping_values = unrolled_filter(values) do value′
value′ != value && is_overlapping_value(value, value′)
Expand All @@ -278,8 +278,16 @@ combine_name_trees(name_tree1, name_tree2) =
error("Mismatched FieldNameTrees: The ability to combine different \
FieldNameTrees has not been implemented")

universal_set(::Type{FieldName}) = FieldVectorKeys((@name(),))
universal_set(::Type{FieldNamePair}) = FieldMatrixKeys(((@name(), @name()),))
universal_set(::Nothing, ::Type{FieldName}) = error(
"Missing FieldNameTree: Cannot compute complement of FieldNameSet without \
a FieldNameTree",
)
universal_set(name_tree, ::Type{FieldName}) =
FieldVectorKeys(child_names(@name(), name_tree), name_tree)
function universal_set(name_tree, ::Type{FieldNamePair})
row_set = universal_set(name_tree, FieldName)
return cartesian_product(row_set, row_set)
end

is_valid_value(name::FieldName, name_tree) = is_valid_name(name, name_tree)
is_valid_value(name_pair::FieldNamePair, name_tree) =
Expand Down Expand Up @@ -307,115 +315,128 @@ is_value_in_set(value, values, name_tree) =
) : is_valid_value(value, name_tree)
)

remove_duplicates_and_overlaps(values, name_tree) =
union_values((), values, name_tree)

function union_values(values, new_values, name_tree)
isempty(new_values) && return values

new_value = new_values[1]
more_new_values = new_values[2:end]
unrolled_in(new_value, values) &&
return union_values(values, more_new_values, name_tree)

overlapping_values, non_overlapping_values = unrolled_split(values) do value
is_overlapping_value(new_value, value)
function remove_duplicates_and_overlaps(values, name_tree)
unique_values = unrolled_unique(values)
overlapping_values, non_overlapping_values =
unrolled_split(unique_values) do value
unrolled_any(unique_values) do value′
value != value′ && is_overlapping_value(value, value′)
end
end
isempty(overlapping_values) && return unique_values
isnothing(name_tree) &&
error("Missing FieldNameTree: Cannot eliminate overlaps among \
$(values_string(overlapping_values)) without a FieldNameTree")
expanded_overlapping_values = unrolled_flatmap(overlapping_values) do value
values_overlapping_with_value =
unrolled_filter(overlapping_values) do value′
value != value′ && is_overlapping_value(value, value′)
end
expand_child_values(value, values_overlapping_with_value, name_tree)
end
isempty(overlapping_values) &&
return union_values((values..., new_value), more_new_values, name_tree)
no_longer_overlapping_values =
remove_duplicates_and_overlaps(expanded_overlapping_values, name_tree)
return (non_overlapping_values..., no_longer_overlapping_values...)
end

# The function union_values(values1, values2, name_tree) generates the same
# result as remove_duplicates_and_overlaps((values1..., values2...), name_tree),
# but it is slightly more efficient (and faster to compile) because it makes use
# of the fact that values1 == remove_duplicates_and_overlaps(values1, name_tree)
# and values2 == remove_duplicates_and_overlaps(values2, name_tree).
function union_values(values1, values2, name_tree)
unique_values2 =
unrolled_filter(value2 -> !unrolled_in(value2, values1), values2)
overlapping_values1, non_overlapping_values1 =
unrolled_split(values1) do value1
unrolled_any(unique_values2) do value2
is_overlapping_value(value1, value2)
end
end
isempty(overlapping_values1) && return (values1..., unique_values2...)
overlapping_values2, non_overlapping_values2 =
unrolled_split(unique_values2) do value2
unrolled_any(values1) do value1
is_overlapping_value(value1, value2)
end
end
isnothing(name_tree) && error(
"Missing FieldNameTree: Cannot eliminate overlaps between $new_value \
and $(values_string(overlapping_values)) without a FieldNameTree",
"Missing FieldNameTree: Cannot eliminate overlaps between \
$overlapping_values1 and $overlapping_values2 without a FieldNameTree",
)

overlapping_values_that_are_children_of_value, other_overlapping_values =
unrolled_split(overlapping_values) do value
is_child_value(value, new_value)
end
possible_union_values = if isempty(other_overlapping_values)
possible_children_of_value = available_sets_of_child_values(
new_value,
overlapping_values,
name_tree,
)
recursively_unrolled_map(
possible_children_of_value,
) do children_of_value
union_values(
values,
(children_of_value..., more_new_values...),
name_tree,
)
expanded_overlapping_values1 =
unrolled_flatmap(overlapping_values1) do value1
values2_overlapping_value1 =
unrolled_filter(overlapping_values2) do value2
is_overlapping_value(value1, value2)
end
expand_child_values(value1, values2_overlapping_value1, name_tree)
end
else
possible_children_of_other_overlapping_values = unrolled_map(
unrolled_flatten,
unrolled_prodmap(other_overlapping_values) do value
available_sets_of_child_values(value, (new_value,), name_tree)
end,
)
recursively_unrolled_map(
possible_children_of_other_overlapping_values,
) do children_of_other_overlapping_values
values_and_children_of_values = (
non_overlapping_values...,
overlapping_values_that_are_children_of_value...,
children_of_other_overlapping_values...,
)
union_values(values_and_children_of_values, new_values, name_tree)
expanded_overlapping_values2 =
unrolled_flatmap(overlapping_values2) do value2
values1_overlapping_value2 =
unrolled_filter(overlapping_values1) do value1
is_overlapping_value(value1, value2)
end
expand_child_values(value2, values1_overlapping_value2, name_tree)
end
end
return unrolled_argmin(length, possible_union_values)
union_of_overlapping_values = union_values(
expanded_overlapping_values1,
expanded_overlapping_values2,
name_tree,
)
return (
non_overlapping_values1...,
non_overlapping_values2...,
union_of_overlapping_values...,
)
end

available_sets_of_child_values(name::FieldName, _, name_tree) =
(child_names(name, name_tree),)
function available_sets_of_child_values(
expand_child_values(name::FieldName, overlapping_names, name_tree) =
unrolled_all(overlapping_names) do name′
name′ != name && is_child_name(name′, name)
end ? child_names(name, name_tree) : (name,)
function expand_child_values(
name_pair::FieldNamePair,
overlapping_name_pairs,
name_tree,
)
row_name, col_name = name_pair
row_children_needed = unrolled_any(overlapping_name_pairs) do name_pair′
name_pair′[1] != row_name && is_child_name(name_pair′[1], row_name)
end
col_children_needed = unrolled_any(overlapping_name_pairs) do name_pair′
name_pair′[2] != col_name && is_child_name(name_pair′[2], col_name)
end
row_children =
row_children_needed ?
unrolled_map(child_names(row_name, name_tree)) do child_of_row_name
(child_of_row_name, col_name)
end : nothing
col_children =
col_children_needed ?
unrolled_map(child_names(col_name, name_tree)) do child_of_col_name
(row_name, child_of_col_name)
end : nothing
n_row_children = row_children_needed ? length(row_children) : 0
n_col_children = col_children_needed ? length(col_children) : 0
# We are guaranteed that either n_row_children > 0 or n_col_children > 0.
return if n_row_children == n_col_children == 1
only_child_of_name_pair = (row_children[1][1], col_children[1][2])
((only_child_of_name_pair,),)
elseif n_row_children > 0 && n_col_children <= 1
(row_children,)
elseif n_col_children > 0 && n_row_children <= 1
(col_children,)
else
@assert n_row_children > 1 && n_col_children > 1
# If multiple row and column children are available, we might get
# different results depending on whether we expand the row or the column
# first, so we need to try both and pick the one that corresponds to the
# result with the shortest length.
(row_children, col_children)
row_name_children_needed =
unrolled_all(overlapping_name_pairs) do name_pair′
name_pair′[1] != row_name && is_child_name(name_pair′[1], row_name)
end
col_name_children_needed =
unrolled_all(overlapping_name_pairs) do name_pair′
name_pair′[2] != col_name && is_child_name(name_pair′[2], col_name)
end
row_name_children =
row_name_children_needed ? child_names(row_name, name_tree) : ()
col_name_children =
col_name_children_needed ? child_names(col_name, name_tree) : ()
# Note: We need special cases for when either row_name or col_name only has
# one child name, since automatically expanding that name can generate
# results with more expansions than are necessary. For example, it can lead
# to a situation like issubset(set1, set2) && union(set1, set2) != set2,
# where union(set1, set2) has too many expanded values.
return if length(row_name_children) > 1 && length(col_name_children) > 1 ||
length(row_name_children) == 1 && length(col_name_children) == 1
unrolled_product(row_name_children, col_name_children)
elseif length(row_name_children) > 0 && length(col_name_children) <= 1
unrolled_product(row_name_children, (col_name,))
elseif length(row_name_children) <= 1 && length(col_name_children) > 0
unrolled_product((row_name,), col_name_children)
else # length(row_name_children) == 0 && length(col_name_children) == 0
(name_pair,)
end
end

# This is required for type-stability as of Julia 1.9.
if hasfield(Method, :recursion_relation)
dont_limit = (args...) -> true
for m in methods(remove_duplicates_and_overlaps)
m.recursion_relation = dont_limit
end
for m in methods(union_values)
m.recursion_relation = dont_limit
end
Expand Down
Loading

0 comments on commit bf6c14a

Please sign in to comment.