diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 96bb5c4b2d8f..7778f2db126d 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -551,6 +551,10 @@ fn get_valid_types( // and their default type is double precision if logical_data_type == NativeType::Null { valid_type = DataType::Float64; + } else if !logical_data_type.is_numeric() { + return plan_err!( + "The signature expected NativeType::Numeric but received {logical_data_type}" + ); } vec![vec![valid_type; *number]] @@ -939,6 +943,7 @@ mod tests { use super::*; use arrow::datatypes::Field; + use datafusion_common::assert_contains; #[test] fn test_string_conversion() { @@ -1003,6 +1008,67 @@ mod tests { } } + #[test] + fn test_get_valid_types_numeric() -> Result<()> { + let get_valid_types_flatten = + |signature: &TypeSignature, current_types: &[DataType]| { + get_valid_types(signature, current_types) + .unwrap() + .into_iter() + .flatten() + .collect::>() + }; + + // Trivial case. + let got = get_valid_types_flatten(&TypeSignature::Numeric(1), &[DataType::Int32]); + assert_eq!(got, [DataType::Int32]); + + // Args are coerced into a common numeric type. + let got = get_valid_types_flatten( + &TypeSignature::Numeric(2), + &[DataType::Int32, DataType::Int64], + ); + assert_eq!(got, [DataType::Int64, DataType::Int64]); + + // Args are coerced into a common numeric type, specifically, int would be coerced to float. + let got = get_valid_types_flatten( + &TypeSignature::Numeric(3), + &[DataType::Int32, DataType::Int64, DataType::Float64], + ); + assert_eq!( + got, + [DataType::Float64, DataType::Float64, DataType::Float64] + ); + + // Cannot coerce args to a common numeric type. + let got = get_valid_types( + &TypeSignature::Numeric(2), + &[DataType::Int32, DataType::Utf8], + ) + .unwrap_err(); + assert_contains!( + got.to_string(), + "The signature expected NativeType::Numeric but received NativeType::String" + ); + + // Fallbacks to float64 if the arg is of type null. + let got = get_valid_types_flatten(&TypeSignature::Numeric(1), &[DataType::Null]); + assert_eq!(got, [DataType::Float64]); + + // Rejects non-numeric arg. + let got = get_valid_types( + &TypeSignature::Numeric(1), + &[DataType::Timestamp(TimeUnit::Second, None)], + ) + .unwrap_err(); + assert_contains!( + got.to_string(), + "The signature expected NativeType::Numeric but received NativeType::Timestamp(Second, None)" + ); + + Ok(()) + } + #[test] fn test_get_valid_types_one_of() -> Result<()> { let signature = diff --git a/datafusion/sqllogictest/test_files/math.slt b/datafusion/sqllogictest/test_files/math.slt index e86d78a62353..37b5a378fc02 100644 --- a/datafusion/sqllogictest/test_files/math.slt +++ b/datafusion/sqllogictest/test_files/math.slt @@ -126,15 +126,15 @@ statement error SELECT abs(1, 2); # abs: unsupported argument type -query error This feature is not implemented: Unsupported data type Utf8 for function abs +query error DataFusion error: Error during planning: The signature expected NativeType::Numeric but received NativeType::String SELECT abs('foo'); # abs: numeric string # TODO: In Postgres, '-1.2' is unknown type and interpreted to float8 so they don't fail on this query -query error DataFusion error: This feature is not implemented: Unsupported data type Utf8 for function abs +query error DataFusion error: Error during planning: The signature expected NativeType::Numeric but received NativeType::String select abs('-1.2'); -query error DataFusion error: This feature is not implemented: Unsupported data type Utf8 for function abs +query error DataFusion error: Error during planning: The signature expected NativeType::Numeric but received NativeType::String select abs(arrow_cast('-1.2', 'Utf8')); statement ok