Skip to content

Commit

Permalink
Reduce number of passes through spatialized_chunk_predictions + update metadata field name (#43)
Browse files Browse the repository at this point in the history
  • Loading branch information
skeshive authored Jul 25, 2024
1 parent 75380aa commit b10c769
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 57 deletions.
84 changes: 50 additions & 34 deletions cloud_functions/climateiq_spatialize_chunk_predictions_cf/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def spatialize_chunk_predictions(cloud_event: http.CloudEvent) -> None:
except ValueError as ve:
# Any raised ValueErrors are non-retriable so return instead of throwing an
# exception (which would trigger retries)
print(ve)
print(f"Error for {object_name}: {ve}")
return

storage_client = gcs_client.Client()
Expand Down Expand Up @@ -178,12 +178,12 @@ def _get_study_area_metadata(
not study_area_metadata
or "cell_size" not in study_area_metadata
or "crs" not in study_area_metadata
or "row_count" not in study_area_metadata
or "col_count" not in study_area_metadata
or "chunk_x_count" not in study_area_metadata
or "chunk_y_count" not in study_area_metadata
):
raise ValueError(
f'Study area "{study_area_name}" is missing one or more required '
"fields: cell_size, crs, row_count, col_count"
"fields: cell_size, crs, chunk_x_count, chunk_y_count"
)

return study_area_metadata, chunks_ref
Expand Down Expand Up @@ -289,28 +289,53 @@ def _build_spatialized_model_predictions(
)


def _add_h3_index_details(cell: pd.Series, chunk_boundary: Any) -> pd.Series:
    """Projects the cell centroid to a H3 index and adds H3 details.

    Args:
        cell: A cell row containing the lat and lon of the cell centroid.
        chunk_boundary: A shapely.Polygon representing the chunk.

    Returns:
        A Series with keys h3_index, h3_centroid_lat, h3_centroid_lon,
        h3_boundary and is_boundary_cell. When the projected H3 centroid
        falls outside the chunk boundary every value is None (and
        is_boundary_cell is False) so the caller can filter those rows
        out with dropna().
    """
    h3_index = h3.geo_to_h3(cell["lat"], cell["lon"], H3_LEVEL)
    centroid_lat, centroid_lon = h3.h3_to_geo(h3_index)
    boundary_xy = geometry.Polygon(h3.h3_to_geo_boundary(h3_index, True))
    # A cell is a "boundary cell" when its H3 hexagon is not fully contained
    # within the chunk; predictions for such cells must later be aggregated
    # across neighboring chunks.
    is_boundary_cell = not boundary_xy.within(chunk_boundary)

    # Null out any rows where the projected H3 centroid falls outside of the
    # chunk boundary so the caller can drop them.
    if not chunk_boundary.contains(geometry.Point(centroid_lon, centroid_lat)):
        h3_index = None
        centroid_lat = None
        centroid_lon = None
        boundary_xy = None
        is_boundary_cell = False

    return pd.Series(
        {
            "h3_index": h3_index,
            "h3_centroid_lat": centroid_lat,
            "h3_centroid_lon": centroid_lon,
            "h3_boundary": boundary_xy,
            "is_boundary_cell": is_boundary_cell,
        }
    )


def _add_h3_index(cell: pd.Series) -> pd.Series:
    """Projects the cell centroid to a H3 index.

    Args:
        cell: A cell row containing the lat and lon of the cell centroid.

    Returns:
        A Series containing the H3 index of the projected cell centroid.
    """
    projected_index = h3.geo_to_h3(cell["lat"], cell["lon"], H3_LEVEL)
    return pd.Series({"h3_index": projected_index})


def _get_chunk_boundary(study_area_metadata: dict, chunk_metadata: dict):
"""Calculates the boundary points of the chunk.
Expand Down Expand Up @@ -376,29 +401,24 @@ def _calculate_h3_indexes(
missing required fields.
"""
# Calculate H3 information for each cell.
spatialized_predictions[
["h3_index", "h3_centroid_lat", "h3_centroid_lon", "h3_boundary"]
] = spatialized_predictions.apply(_add_h3_index_details, axis=1)

# Filter out any rows where the projected H3 centroid falls outside of the
# chunk boundary.
chunk_boundary = _get_chunk_boundary(study_area_metadata, chunk_metadata)
spatialized_predictions = spatialized_predictions[
spatialized_predictions.apply(
lambda row: chunk_boundary.contains(
geometry.Point(row["h3_centroid_lon"], row["h3_centroid_lat"])
),
axis=1,
)
]
spatialized_predictions[
[
"h3_index",
"h3_centroid_lat",
"h3_centroid_lon",
"h3_boundary",
"is_boundary_cell",
]
] = spatialized_predictions.apply(
lambda row: _add_h3_index_details(row, chunk_boundary), axis=1
)
spatialized_predictions = spatialized_predictions.dropna(how="any")

# Extract rows where the projected H3 cell is not fully contained within the chunk
# so we can aggregate prediction values across chunk boundaries.
boundary_h3_cells = spatialized_predictions[
spatialized_predictions.apply(
lambda row: not row["h3_boundary"].within(chunk_boundary),
axis=1,
)
spatialized_predictions["is_boundary_cell"]
]["h3_boundary"].unique()

return _aggregate_h3_predictions(
Expand Down Expand Up @@ -470,8 +490,8 @@ def _aggregate_h3_predictions(
if (
neighbor_x < 0
or neighbor_y < 0
or neighbor_x >= study_area_metadata["col_count"]
or neighbor_y >= study_area_metadata["row_count"]
or neighbor_x >= study_area_metadata["chunk_x_count"]
or neighbor_y >= study_area_metadata["chunk_y_count"]
):
# Chunk is outside the study area boundary.
continue
Expand Down Expand Up @@ -524,12 +544,8 @@ def _aggregate_h3_predictions(
neighbor_chunk_predictions,
)
)
# TODO: Optionally only calculate the h3_index if calculating other
# metadata is expensive
neighbor_chunk_spatialized_predictions[
["h3_index", "h3_centroid_lat", "h3_centroid_lon", "h3_boundary"]
] = neighbor_chunk_spatialized_predictions.apply(
_add_h3_index_details, axis=1
neighbor_chunk_spatialized_predictions[["h3_index"]] = (
neighbor_chunk_spatialized_predictions.apply(_add_h3_index, axis=1)
)
neighbor_chunk_spatialized_predictions = (
neighbor_chunk_spatialized_predictions[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ def test_spatialize_chunk_predictions_invalid_study_area(
study_area_metadata: Dict[str, Any] = {
"name": "study_area_name",
"crs": "EPSG:32618",
"row_count": 2,
"col_count": 3,
"chunk_y_count": 2,
"chunk_x_count": 3,
} # Missing "cell_size" required field
chunks_metadata: List[Dict[str, Any]] = [
{
Expand All @@ -164,7 +164,7 @@ def test_spatialize_chunk_predictions_invalid_study_area(

assert (
'Study area "study-area-name" is missing one or more required '
"fields: cell_size, crs, row_count, col_count" in output.getvalue()
"fields: cell_size, crs, chunk_x_count, chunk_y_count" in output.getvalue()
)


Expand Down Expand Up @@ -200,8 +200,8 @@ def test_spatialize_chunk_predictions_missing_chunk(
"name": "study_area_name",
"cell_size": 10,
"crs": "EPSG:32618",
"row_count": 2,
"col_count": 3,
"chunk_y_count": 2,
"chunk_x_count": 3,
}
chunks_metadata: List[Dict[str, Any]] = [
{
Expand Down Expand Up @@ -261,8 +261,8 @@ def test_spatialize_chunk_predictions_invalid_chunk(
"name": "study_area_name",
"cell_size": 10,
"crs": "EPSG:32618",
"row_count": 2,
"col_count": 3,
"chunk_y_count": 2,
"chunk_x_count": 3,
}
chunks_metadata: List[Dict[str, Any]] = [
{
Expand Down Expand Up @@ -321,8 +321,8 @@ def test_spatialize_chunk_predictions_missing_predictions(
"name": "study_area_name",
"cell_size": 10,
"crs": "EPSG:32618",
"row_count": 2,
"col_count": 3,
"chunk_y_count": 2,
"chunk_x_count": 3,
}
chunks_metadata: List[Dict[str, Any]] = [
{
Expand Down Expand Up @@ -387,8 +387,8 @@ def test_spatialize_chunk_predictions_too_many_predictions(
"name": "study_area_name",
"cell_size": 10,
"crs": "EPSG:32618",
"row_count": 2,
"col_count": 3,
"chunk_y_count": 2,
"chunk_x_count": 3,
}
chunks_metadata: List[Dict[str, Any]] = [
{
Expand Down Expand Up @@ -449,8 +449,8 @@ def test_spatialize_chunk_predictions_missing_expected_neighbor_chunk(
"name": "study_area_name",
"cell_size": 10,
"crs": "EPSG:32618",
"row_count": 2,
"col_count": 3,
"chunk_y_count": 2,
"chunk_x_count": 3,
}
chunks_metadata: List[Dict[str, Any]] = [
{
Expand Down Expand Up @@ -514,8 +514,8 @@ def test_spatialize_chunk_predictions_invalid_neighbor_chunk(
"name": "study_area_name",
"cell_size": 10,
"crs": "EPSG:32618",
"row_count": 2,
"col_count": 3,
"chunk_y_count": 2,
"chunk_x_count": 3,
}
chunks_metadata: List[Dict[str, Any]] = [
{
Expand Down Expand Up @@ -588,8 +588,8 @@ def test_spatialize_chunk_predictions_neighbor_chunk_missing_predictions(
"name": "study_area_name",
"cell_size": 10,
"crs": "EPSG:32618",
"row_count": 2,
"col_count": 3,
"chunk_y_count": 2,
"chunk_x_count": 3,
}
chunks_metadata: List[Dict[str, Any]] = [
{
Expand Down Expand Up @@ -680,8 +680,8 @@ def test_spatialize_chunk_predictions_h3_centroids_within_chunk(
"name": "study_area_name",
"cell_size": 10,
"crs": "EPSG:32618",
"row_count": 2,
"col_count": 3,
"chunk_y_count": 2,
"chunk_x_count": 3,
}
chunks_metadata: List[Dict[str, Any]] = [
{
Expand Down Expand Up @@ -788,8 +788,8 @@ def test_spatialize_chunk_predictions_h3_centroids_outside_chunk(
"name": "study_area_name",
"cell_size": 5,
"crs": "EPSG:32618",
"row_count": 2,
"col_count": 3,
"chunk_y_count": 2,
"chunk_x_count": 3,
}
chunks_metadata: List[Dict[str, Any]] = [
{
Expand Down Expand Up @@ -950,8 +950,8 @@ def test_spatialize_chunk_predictions_overlapping_neighbors(
"name": "study_area_name",
"cell_size": 3,
"crs": "EPSG:32618",
"row_count": 2,
"col_count": 3,
"chunk_y_count": 2,
"chunk_x_count": 3,
}
chunks_metadata: List[Dict[str, Any]] = [
{
Expand Down

0 comments on commit b10c769

Please sign in to comment.