diff --git a/src/variable.jl b/src/variable.jl index 1760724ed..a4ccc8e7e 100644 --- a/src/variable.jl +++ b/src/variable.jl @@ -606,7 +606,7 @@ function fast_substitute(expr, subs; operator = Nothing) args = let canfold = canfold map(args) do x x′ = fast_substitute(x, subs; operator) - canfold[] = canfold[] && !(x′ isa Symbolic) + canfold[] = canfold[] && (symbolic_type(x′) == NotSymbolic() && !is_array_of_symbolics(x′)) x′ end end @@ -633,7 +633,7 @@ function fast_substitute(expr, pair::Pair; operator = Nothing) args = let canfold = canfold map(args) do x x′ = fast_substitute(x, pair; operator) - canfold[] = canfold[] && !(x′ isa Symbolic) + canfold[] = canfold[] && (symbolic_type(x′) == NotSymbolic() && !is_array_of_symbolics(x′)) x′ end end @@ -645,6 +645,13 @@ function fast_substitute(expr, pair::Pair; operator = Nothing) metadata(expr)) end +function is_array_of_symbolics(x) + symbolic_type(x) == ArraySymbolic() && return true + symbolic_type(x) == ScalarSymbolic() && return false + x isa AbstractArray && + any(y -> symbolic_type(y) != NotSymbolic() || is_array_of_symbolics(y), x) +end + function getparent(x, val=_fail) maybe_parent = getmetadata(x, Symbolics.GetindexParent, nothing) if maybe_parent !== nothing