Skip to content

Commit

Permalink
fix get_valid_types with TypeSignature::Numeric
Browse files Browse the repository at this point in the history
  • Loading branch information
niebayes committed Jan 10, 2025
1 parent 300c2af commit 9afab65
Showing 1 changed file with 66 additions and 0 deletions.
66 changes: 66 additions & 0 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,10 @@ fn get_valid_types(
// and their default type is double precision
if logical_data_type == NativeType::Null {
valid_type = DataType::Float64;
} else if !logical_data_type.is_numeric() {
return plan_err!(
"The signature expected NativeType::Numeric but received {logical_data_type}"
);
}

vec![vec![valid_type; *number]]
Expand Down Expand Up @@ -939,6 +943,7 @@ mod tests {

use super::*;
use arrow::datatypes::Field;
use datafusion_common::assert_contains;

#[test]
fn test_string_conversion() {
Expand Down Expand Up @@ -1003,6 +1008,67 @@ mod tests {
}
}

#[test]
fn test_get_valid_types_numeric() -> Result<()> {
let get_valid_types_flatten =
|signature: &TypeSignature, current_types: &[DataType]| {
get_valid_types(signature, current_types)
.unwrap()
.into_iter()
.flatten()
.collect::<Vec<_>>()
};

// Trivial case.
let got = get_valid_types_flatten(&TypeSignature::Numeric(1), &[DataType::Int32]);
assert_eq!(got, [DataType::Int32]);

// Args are coerced into a common numeric type.
let got = get_valid_types_flatten(
&TypeSignature::Numeric(2),
&[DataType::Int32, DataType::Int64],
);
assert_eq!(got, [DataType::Int64, DataType::Int64]);

// Args are coerced into a common numeric type, specifically, int would be coerced to float.
let got = get_valid_types_flatten(
&TypeSignature::Numeric(3),
&[DataType::Int32, DataType::Int64, DataType::Float64],
);
assert_eq!(
got,
[DataType::Float64, DataType::Float64, DataType::Float64]
);

// Cannot coerce args to a common numeric type.
let got = get_valid_types(
&TypeSignature::Numeric(2),
&[DataType::Int32, DataType::Utf8],
)
.unwrap_err();
assert_contains!(
got.to_string(),
"The signature expected NativeType::Numeric but received NativeType::String"
);

// Fallbacks to float64 if the arg is of type null.
let got = get_valid_types_flatten(&TypeSignature::Numeric(1), &[DataType::Null]);
assert_eq!(got, [DataType::Float64]);

// Rejects non-numeric arg.
let got = get_valid_types(
&TypeSignature::Numeric(1),
&[DataType::Timestamp(TimeUnit::Second, None)],
)
.unwrap_err();
assert_contains!(
got.to_string(),
"The signature expected NativeType::Numeric but received NativeType::Timestamp(Second, None)"
);

Ok(())
}

#[test]
fn test_get_valid_types_one_of() -> Result<()> {
let signature =
Expand Down

0 comments on commit 9afab65

Please sign in to comment.