Skip to content

Commit

Permalink
fix: Fix casting decimal to decimal for high precision (#16049)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego authored May 5, 2024
1 parent db77896 commit aa2e77b
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 14 deletions.
8 changes: 3 additions & 5 deletions crates/polars-arrow/src/compute/cast/decimal_to.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,13 @@ fn decimal_to_decimal_impl<F: Fn(i128) -> Option<i128>>(
to_precision: usize,
to_scale: usize,
) -> PrimitiveArray<i128> {
let min_for_precision = 9_i128
.saturating_pow(1 + to_precision as u32)
.saturating_neg();
let max_for_precision = 9_i128.saturating_pow(1 + to_precision as u32);
let upper_bound_for_precision = 10_i128.saturating_pow(to_precision as u32);
let lower_bound_for_precision = upper_bound_for_precision.saturating_neg();

let values = from.iter().map(|x| {
x.and_then(|x| {
op(*x).and_then(|x| {
if x > max_for_precision || x < min_for_precision {
if x >= upper_bound_for_precision || x <= lower_bound_for_precision {
None
} else {
Some(x)
Expand Down
12 changes: 12 additions & 0 deletions py-polars/tests/unit/operations/test_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,3 +640,15 @@ def test_cast_array_to_different_width() -> None:
pl.InvalidOperationError, match="cannot cast Array to a different width"
):
s.cast(pl.Array(pl.Int16, 3))


def test_cast_decimal_to_decimal_high_precision() -> None:
precision = 22
values = [Decimal("9" * precision)]
s = pl.Series(values, dtype=pl.Decimal(None, 0))

target_dtype = pl.Decimal(precision, 0)
result = s.cast(target_dtype)

assert result.dtype == target_dtype
assert result.to_list() == values
16 changes: 7 additions & 9 deletions py-polars/tests/unit/sql/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_modulo() -> None:
("value", "sqltype", "prec_scale", "expected_value", "expected_dtype"),
[
(64.5, "numeric", "(3,1)", D("64.5"), pl.Decimal(3, 1)),
(512.5, "decimal", "(3,1)", D("512.5"), pl.Decimal(3, 1)),
(512.5, "decimal", "(4,1)", D("512.5"), pl.Decimal(4, 1)),
(512.5, "numeric", "(4,0)", D("512"), pl.Decimal(4, 0)),
(-1024.75, "decimal", "(10,0)", D("-1024"), pl.Decimal(10, 0)),
(-1024.75, "numeric", "(10)", D("-1024"), pl.Decimal(10, 0)),
Expand All @@ -67,18 +67,16 @@ def test_numeric_decimal_type(
with pl.Config(activate_decimals=True):
df = pl.DataFrame({"n": [value]})
with pl.SQLContext(df=df) as ctx:
out = ctx.execute(
result = ctx.execute(
f"""
SELECT n::{sqltype}{prec_scale} AS "dec" FROM df
"""
)
assert_frame_equal(
out.collect(),
pl.DataFrame(
data={"dec": [expected_value]},
schema={"dec": expected_dtype},
),
)
expected = pl.LazyFrame(
data={"dec": [expected_value]},
schema={"dec": expected_dtype},
)
assert_frame_equal(result, expected)


@pytest.mark.parametrize(
Expand Down

0 comments on commit aa2e77b

Please sign in to comment.