From aa2e77b1bb7d2d6352a14909b4b6bd2ff619fca0 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Sun, 5 May 2024 22:06:47 +0200 Subject: [PATCH] fix: Fix casting decimal to decimal for high precision (#16049) --- .../polars-arrow/src/compute/cast/decimal_to.rs | 8 +++----- py-polars/tests/unit/operations/test_cast.py | 12 ++++++++++++ py-polars/tests/unit/sql/test_numeric.py | 16 +++++++--------- 3 files changed, 22 insertions(+), 14 deletions(-) diff --git a/crates/polars-arrow/src/compute/cast/decimal_to.rs b/crates/polars-arrow/src/compute/cast/decimal_to.rs index 449bc8dd68cf..dd2f29e1a443 100644 --- a/crates/polars-arrow/src/compute/cast/decimal_to.rs +++ b/crates/polars-arrow/src/compute/cast/decimal_to.rs @@ -12,15 +12,13 @@ fn decimal_to_decimal_impl<F: Fn(i128) -> Option<i128>>( to_precision: usize, to_scale: usize, ) -> PrimitiveArray<i128> { - let min_for_precision = 9_i128 - .saturating_pow(1 + to_precision as u32) - .saturating_neg(); - let max_for_precision = 9_i128.saturating_pow(1 + to_precision as u32); + let upper_bound_for_precision = 10_i128.saturating_pow(to_precision as u32); + let lower_bound_for_precision = upper_bound_for_precision.saturating_neg(); let values = from.iter().map(|x| { x.and_then(|x| { op(*x).and_then(|x| { - if x > max_for_precision || x < min_for_precision { + if x >= upper_bound_for_precision || x <= lower_bound_for_precision { None } else { Some(x) diff --git a/py-polars/tests/unit/operations/test_cast.py b/py-polars/tests/unit/operations/test_cast.py index 73722c9660f5..7d5804473eb6 100644 --- a/py-polars/tests/unit/operations/test_cast.py +++ b/py-polars/tests/unit/operations/test_cast.py @@ -640,3 +640,15 @@ def test_cast_array_to_different_width() -> None: pl.InvalidOperationError, match="cannot cast Array to a different width" ): s.cast(pl.Array(pl.Int16, 3)) + + +def test_cast_decimal_to_decimal_high_precision() -> None: + precision = 22 + values = [Decimal("9" * precision)] + s = pl.Series(values, dtype=pl.Decimal(None, 0)) + 
target_dtype = pl.Decimal(precision, 0) + result = s.cast(target_dtype) + + assert result.dtype == target_dtype + assert result.to_list() == values diff --git a/py-polars/tests/unit/sql/test_numeric.py b/py-polars/tests/unit/sql/test_numeric.py index ca54df80be98..5e23189c79f7 100644 --- a/py-polars/tests/unit/sql/test_numeric.py +++ b/py-polars/tests/unit/sql/test_numeric.py @@ -50,7 +50,7 @@ def test_modulo() -> None: ("value", "sqltype", "prec_scale", "expected_value", "expected_dtype"), [ (64.5, "numeric", "(3,1)", D("64.5"), pl.Decimal(3, 1)), - (512.5, "decimal", "(3,1)", D("512.5"), pl.Decimal(3, 1)), + (512.5, "decimal", "(4,1)", D("512.5"), pl.Decimal(4, 1)), (512.5, "numeric", "(4,0)", D("512"), pl.Decimal(4, 0)), (-1024.75, "decimal", "(10,0)", D("-1024"), pl.Decimal(10, 0)), (-1024.75, "numeric", "(10)", D("-1024"), pl.Decimal(10, 0)), @@ -67,18 +67,16 @@ def test_numeric_decimal_type( with pl.Config(activate_decimals=True): df = pl.DataFrame({"n": [value]}) with pl.SQLContext(df=df) as ctx: - out = ctx.execute( + result = ctx.execute( f""" SELECT n::{sqltype}{prec_scale} AS "dec" FROM df """ ) - assert_frame_equal( - out.collect(), - pl.DataFrame( - data={"dec": [expected_value]}, - schema={"dec": expected_dtype}, - ), - ) + expected = pl.LazyFrame( + data={"dec": [expected_value]}, + schema={"dec": expected_dtype}, + ) + assert_frame_equal(result, expected) @pytest.mark.parametrize(