Handle pandas timestamp with nanosecs precision (ray-project#49370)
## Why are these changes needed?
Handle pandas timestamps with nanosecond precision.
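
As a minimal repro sketch (illustrative only, not part of the commit message): a nanosecond-precision `pd.Timestamp` previously lost its sub-microsecond digits when Ray Data built Arrow blocks, because precision detection topped out at microseconds.

```python
import pandas as pd
import ray

# A Timestamp carrying nanosecond precision (.123456789).
ds = ray.data.from_items(
    [{"ts": pd.Timestamp("2023-01-01T00:00:00.123456789")}]
)

# With this change, the underlying Arrow schema keeps timestamp[ns],
# so the trailing nanoseconds survive the round trip.
print(ds.take(1))
```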

## Related issue number

"Closes ray-project#49297"

---------

Signed-off-by: Srinath Krishnamachari <[email protected]>
Signed-off-by: Roshan Kathawate <[email protected]>
srinathk10 authored and roshankathawate committed Jan 9, 2025
1 parent 6027a5d commit e801e49
Showing 5 changed files with 331 additions and 16 deletions.
92 changes: 79 additions & 13 deletions python/ray/data/_internal/numpy_support.py
@@ -45,28 +45,94 @@ def validate_numpy_batch(batch: Union[Dict[str, np.ndarray], Dict[str, list]]) -


def _detect_highest_datetime_precision(datetime_list: List[datetime]) -> str:
    """Detect the highest precision for a list of datetime objects.

    Args:
        datetime_list: List of datetime objects.

    Returns:
        A string representing the highest precision among the datetime objects
        ('D', 's', 'ms', 'us', 'ns').
    """
    # Define the precision hierarchy, from lowest to highest.
    precision_hierarchy = ["D", "s", "ms", "us", "ns"]
    highest_precision_index = 0  # Start with the lowest precision ("D").

    for dt in datetime_list:
        # Safely get the nanosecond value using getattr for backward
        # compatibility (plain datetime objects lack the attribute).
        nanosecond = getattr(dt, "nanosecond", 0)
        if nanosecond != 0:
            current_precision = "ns"
        elif dt.microsecond != 0:
            # Check whether the microsecond value is an exact millisecond.
            if dt.microsecond % 1000 == 0:
                current_precision = "ms"
            else:
                current_precision = "us"
        elif dt.second != 0 or dt.minute != 0 or dt.hour != 0:
            # pyarrow does not support h or m, so use s for those cases too.
            current_precision = "s"
        else:
            current_precision = "D"

        # Update highest_precision_index based on the hierarchy.
        current_index = precision_hierarchy.index(current_precision)
        highest_precision_index = max(highest_precision_index, current_index)

        # Stop early if the highest possible precision is reached.
        if highest_precision_index == len(precision_hierarchy) - 1:
            break

    return precision_hierarchy[highest_precision_index]
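
For intuition, a few hedged examples of what the detector should return, derived from the branch logic above (plain `datetime` objects have no `nanosecond` attribute, so `getattr` falls back to 0):

```python
from datetime import datetime

import pandas as pd

# Expected results inferred from the hierarchy logic; illustrative only.
_detect_highest_datetime_precision([datetime(2024, 1, 1)])  # "D"
_detect_highest_datetime_precision([datetime(2024, 1, 1, 12, 30)])  # "s"
_detect_highest_datetime_precision([datetime(2024, 1, 1, 0, 0, 0, 5000)])  # "ms"
_detect_highest_datetime_precision(
    [pd.Timestamp("2024-01-01 00:00:00.123456789")]
)  # "ns"
```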


def _convert_to_datetime64(dt: datetime, precision: str) -> np.datetime64:
    """Convert a datetime object to a numpy datetime64 object with the
    specified precision.

    Args:
        dt: A datetime object to be converted.
        precision: The desired precision for the datetime64 conversion. Possible
            values are 'D', 's', 'ms', 'us', 'ns'.

    Returns:
        np.datetime64: A numpy datetime64 object with the specified precision.
    """
    if precision == "ns":
        # Calculate nanoseconds from the microsecond and nanosecond components.
        microseconds_as_ns = dt.microsecond * 1000
        # Use getattr for backward compatibility where the nanosecond attribute
        # may not exist.
        nanoseconds = getattr(dt, "nanosecond", 0)
        total_nanoseconds = microseconds_as_ns + nanoseconds
        # Create a datetime64 from the base datetime with microsecond precision.
        base_dt = np.datetime64(dt, "us")
        # Add the remaining nanoseconds as a timedelta.
        return base_dt + np.timedelta64(total_nanoseconds - microseconds_as_ns, "ns")
    else:
        return np.datetime64(dt).astype(f"datetime64[{precision}]")
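
A short walk-through of the `ns` branch, assuming a pandas `Timestamp` input (which exposes a `nanosecond` attribute): the value is split into a microsecond-precision base plus a nanosecond remainder, then recombined.

```python
import numpy as np
import pandas as pd

ts = pd.Timestamp("2024-01-01 00:00:00.123456789")
# base_dt truncates to microseconds (.123456); the remaining 789 ns are
# added back as a timedelta, yielding a datetime64[ns] value.
out = _convert_to_datetime64(ts, "ns")
assert out == np.datetime64("2024-01-01T00:00:00.123456789", "ns")
```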


def _convert_datetime_list_to_array(datetime_list: List[datetime]) -> np.ndarray:
    """Convert a list of datetime objects to a NumPy array of datetime64 with
    proper precision.

    Args:
        datetime_list (List[datetime]): A list of `datetime` objects to be
            converted. Each `datetime` object represents a specific point in time.

    Returns:
        np.ndarray: A NumPy array containing the `datetime64` values of the
            datetime objects from the input list, with the appropriate precision
            (e.g., nanoseconds, microseconds, milliseconds, etc.).
    """
    # Detect the highest precision among the datetime objects.
    precision = _detect_highest_datetime_precision(datetime_list)

    # Convert each datetime to the corresponding numpy datetime64 with the
    # appropriate precision.
    return np.array([_convert_to_datetime64(dt, precision) for dt in datetime_list])
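
Usage sketch: a single nanosecond-precision element should drive the whole array to `datetime64[ns]`, since the detected precision applies to every element.

```python
import numpy as np
import pandas as pd

arr = _convert_datetime_list_to_array(
    [
        pd.Timestamp("2024-01-01 00:00:00.000001"),  # microsecond precision
        pd.Timestamp("2024-01-01 00:00:00.123456789"),  # nanosecond precision
    ]
)
assert arr.dtype == np.dtype("datetime64[ns]")
```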


def convert_to_numpy(column_values: Any) -> np.ndarray:
27 changes: 27 additions & 0 deletions python/ray/data/tests/test_arrow_block.py
@@ -6,6 +6,7 @@
from typing import Union

import numpy as np
import pandas as pd
import pyarrow as pa
import pytest
from pyarrow import parquet as pq
@@ -355,6 +356,32 @@ def test_build_block_with_null_column(ray_start_regular_shared):
assert np.array_equal(rows[1]["array"], np.zeros((2, 2)))


def test_arrow_block_timestamp_ns(ray_start_regular_shared):
# Input data with nanosecond precision timestamps
data_rows = [
{"col1": 1, "col2": pd.Timestamp("2023-01-01T00:00:00.123456789")},
{"col1": 2, "col2": pd.Timestamp("2023-01-01T01:15:30.987654321")},
{"col1": 3, "col2": pd.Timestamp("2023-01-01T02:30:15.111111111")},
{"col1": 4, "col2": pd.Timestamp("2023-01-01T03:45:45.222222222")},
{"col1": 5, "col2": pd.Timestamp("2023-01-01T05:00:00.333333333")},
]

# Initialize ArrowBlockBuilder
arrow_builder = ArrowBlockBuilder()
for row in data_rows:
arrow_builder.add(row)
arrow_block = arrow_builder.build()

assert arrow_block.schema.field("col2").type == pa.timestamp("ns")
for i, row in enumerate(data_rows):
result_timestamp = arrow_block["col2"][i].as_py()
# Convert both values to pandas Timestamp to preserve nanosecond precision for
# comparison.
assert pd.Timestamp(row["col2"]) == pd.Timestamp(
result_timestamp
), f"Timestamp mismatch at row {i} in ArrowBlockBuilder output"


def test_arrow_nan_element():
ds = ray.data.from_items(
[
138 changes: 138 additions & 0 deletions python/ray/data/tests/test_map.py
@@ -330,6 +330,144 @@ def map_generator(item: dict) -> Iterator[int]:
]


# Helper function to process timestamp data in nanoseconds
def process_timestamp_data(row):
# Convert numpy.datetime64 to pd.Timestamp if needed
if isinstance(row["timestamp"], np.datetime64):
row["timestamp"] = pd.Timestamp(row["timestamp"])

# Add 1ns to timestamp
row["timestamp"] = row["timestamp"] + pd.Timedelta(1, "ns")

# Ensure the timestamp column is in the expected dtype (datetime64[ns])
row["timestamp"] = pd.to_datetime(row["timestamp"], errors="raise")

return row


def process_timestamp_data_batch_arrow(batch: pa.Table) -> pa.Table:
# Convert pyarrow Table to pandas DataFrame to process the timestamp column
df = batch.to_pandas()

df["timestamp"] = df["timestamp"].apply(
lambda x: pd.Timestamp(x) if isinstance(x, np.datetime64) else x
)

# Add 1ns to timestamp
df["timestamp"] = df["timestamp"] + pd.Timedelta(1, "ns")

# Convert back to pyarrow Table
return pa.table(df)


def process_timestamp_data_batch_pandas(batch: pd.DataFrame) -> pd.DataFrame:
# Add 1ns to timestamp column
batch["timestamp"] = batch["timestamp"] + pd.Timedelta(1, "ns")
return batch


@pytest.mark.parametrize(
"df, expected_df",
[
pytest.param(
pd.DataFrame(
{
"id": [1, 2, 3],
"timestamp": pd.to_datetime(
[
"2024-01-01 00:00:00.123456789",
"2024-01-02 00:00:00.987654321",
"2024-01-03 00:00:00.111222333",
]
),
"value": [10.123456789, 20.987654321, 30.111222333],
}
),
pd.DataFrame(
{
"id": [1, 2, 3],
"timestamp": pd.to_datetime(
[
"2024-01-01 00:00:00.123456790",
"2024-01-02 00:00:00.987654322",
"2024-01-03 00:00:00.111222334",
]
),
"value": [10.123456789, 20.987654321, 30.111222333],
}
),
id="nanoseconds_increment",
)
],
)
def test_map_batches_timestamp_nanosecs(df, expected_df, ray_start_regular_shared):
"""Verify handling timestamp with nanosecs in map_batches"""
ray_data = ray.data.from_pandas(df)

# Using pyarrow format
result_arrow = ray_data.map_batches(
process_timestamp_data_batch_arrow, batch_format="pyarrow"
)
processed_df_arrow = result_arrow.to_pandas()
processed_df_arrow["timestamp"] = processed_df_arrow["timestamp"].astype(
"datetime64[ns]"
)
pd.testing.assert_frame_equal(processed_df_arrow, expected_df)

# Using pandas format
result_pandas = ray_data.map_batches(
process_timestamp_data_batch_pandas, batch_format="pandas"
)
processed_df_pandas = result_pandas.to_pandas()
processed_df_pandas["timestamp"] = processed_df_pandas["timestamp"].astype(
"datetime64[ns]"
)
pd.testing.assert_frame_equal(processed_df_pandas, expected_df)


@pytest.mark.parametrize(
"df, expected_df",
[
pytest.param(
pd.DataFrame(
{
"id": [1, 2, 3],
"timestamp": pd.to_datetime(
[
"2024-01-01 00:00:00.123456789",
"2024-01-02 00:00:00.987654321",
"2024-01-03 00:00:00.111222333",
]
),
"value": [10.123456789, 20.987654321, 30.111222333],
}
),
pd.DataFrame(
{
"id": [1, 2, 3],
"timestamp": pd.to_datetime(
[
"2024-01-01 00:00:00.123456790",
"2024-01-02 00:00:00.987654322",
"2024-01-03 00:00:00.111222334",
]
),
"value": [10.123456789, 20.987654321, 30.111222333],
}
),
id="nanoseconds_increment_map",
)
],
)
def test_map_timestamp_nanosecs(df, expected_df, ray_start_regular_shared):
"""Verify handling timestamp with nanosecs in map"""
ray_data = ray.data.from_pandas(df)
result = ray_data.map(process_timestamp_data)
processed_df = result.to_pandas()
processed_df["timestamp"] = processed_df["timestamp"].astype("datetime64[ns]")
pd.testing.assert_frame_equal(processed_df, expected_df)


def test_add_column(ray_start_regular_shared):
"""Tests the add column API."""

