Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce number of passes through spatialized_chunk_predictions + update metadata field name #43

Merged
merged 5 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 50 additions & 34 deletions cloud_functions/climateiq_spatialize_chunk_predictions_cf/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def spatialize_chunk_predictions(cloud_event: http.CloudEvent) -> None:
except ValueError as ve:
# Any raised ValueErrors are non-retriable so return instead of throwing an
# exception (which would trigger retries)
print(ve)
print(f"Error for {object_name}: {ve}")
return

storage_client = gcs_client.Client()
Expand Down Expand Up @@ -178,12 +178,12 @@ def _get_study_area_metadata(
not study_area_metadata
or "cell_size" not in study_area_metadata
or "crs" not in study_area_metadata
or "row_count" not in study_area_metadata
or "col_count" not in study_area_metadata
or "chunk_x_count" not in study_area_metadata
or "chunk_y_count" not in study_area_metadata
):
raise ValueError(
f'Study area "{study_area_name}" is missing one or more required '
"fields: cell_size, crs, row_count, col_count"
"fields: cell_size, crs, chunk_x_count, chunk_y_count"
)

return study_area_metadata, chunks_ref
Expand Down Expand Up @@ -289,28 +289,53 @@ def _build_spatialized_model_predictions(
)


def _add_h3_index_details(
    cell: pd.Series, chunk_boundary: geometry.Polygon
) -> pd.Series:
    """Projects the cell centroid to a H3 index and adds H3 details.

    Args:
        cell: A cell row containing the lat and lon of the cell centroid.
        chunk_boundary: A shapely.Polygon representing the chunk.

    Returns:
        A Series containing H3 information for the projected cell centroid.
        All H3 fields are None (and is_boundary_cell is False) when the
        projected centroid falls outside the chunk boundary, so callers can
        drop those rows with dropna.
    """
    h3_index = h3.geo_to_h3(cell["lat"], cell["lon"], H3_LEVEL)
    centroid_lat, centroid_lon = h3.h3_to_geo(h3_index)
    boundary_xy = geometry.Polygon(h3.h3_to_geo_boundary(h3_index, True))
    # A cell whose H3 hexagon is not fully inside the chunk straddles a chunk
    # boundary; its predictions must be aggregated with neighboring chunks.
    is_boundary_cell = not boundary_xy.within(chunk_boundary)

    # Filter out any rows where the projected H3 centroid falls outside of the
    # chunk boundary.
    if not chunk_boundary.contains(geometry.Point(centroid_lon, centroid_lat)):
        h3_index = None
        centroid_lat = None
        centroid_lon = None
        boundary_xy = None
        is_boundary_cell = False

    return pd.Series(
        {
            "h3_index": h3_index,
            "h3_centroid_lat": centroid_lat,
            "h3_centroid_lon": centroid_lon,
            "h3_boundary": boundary_xy,
            "is_boundary_cell": is_boundary_cell,
        }
    )


def _add_h3_index(cell: pd.Series) -> pd.Series:
    """Projects the cell centroid to a H3 index.

    Args:
        cell: A cell row containing the lat and lon of the cell centroid.

    Returns:
        A Series containing the H3 index of the projected cell centroid.
    """
    # Only the index itself is needed here; callers that require centroid or
    # boundary details use _add_h3_index_details instead.
    index = h3.geo_to_h3(cell["lat"], cell["lon"], H3_LEVEL)
    return pd.Series({"h3_index": index})


def _get_chunk_boundary(study_area_metadata: dict, chunk_metadata: dict):
"""Calculates the boundary points of the chunk.

Expand Down Expand Up @@ -376,29 +401,24 @@ def _calculate_h3_indexes(
missing required fields.
"""
# Calculate H3 information for each cell.
spatialized_predictions[
["h3_index", "h3_centroid_lat", "h3_centroid_lon", "h3_boundary"]
] = spatialized_predictions.apply(_add_h3_index_details, axis=1)

# Filter out any rows where the projected H3 centroid falls outside of the
# chunk boundary.
chunk_boundary = _get_chunk_boundary(study_area_metadata, chunk_metadata)
spatialized_predictions = spatialized_predictions[
spatialized_predictions.apply(
lambda row: chunk_boundary.contains(
geometry.Point(row["h3_centroid_lon"], row["h3_centroid_lat"])
),
axis=1,
)
]
spatialized_predictions[
[
"h3_index",
"h3_centroid_lat",
"h3_centroid_lon",
"h3_boundary",
"is_boundary_cell",
]
] = spatialized_predictions.apply(
lambda row: _add_h3_index_details(row, chunk_boundary), axis=1
)
spatialized_predictions = spatialized_predictions.dropna(how="any")

# Extract rows where the projected H3 cell is not fully contained within the chunk
# so we can aggregate prediction values across chunk boundaries.
boundary_h3_cells = spatialized_predictions[
spatialized_predictions.apply(
lambda row: not row["h3_boundary"].within(chunk_boundary),
axis=1,
)
spatialized_predictions["is_boundary_cell"]
]["h3_boundary"].unique()

return _aggregate_h3_predictions(
Expand Down Expand Up @@ -470,8 +490,8 @@ def _aggregate_h3_predictions(
if (
neighbor_x < 0
or neighbor_y < 0
or neighbor_x >= study_area_metadata["col_count"]
or neighbor_y >= study_area_metadata["row_count"]
or neighbor_x >= study_area_metadata["chunk_x_count"]
or neighbor_y >= study_area_metadata["chunk_y_count"]
):
# Chunk is outside the study area boundary.
continue
Expand Down Expand Up @@ -524,12 +544,8 @@ def _aggregate_h3_predictions(
neighbor_chunk_predictions,
)
)
# TODO: Optionally only calculate the h3_index if calculating other
# metadata is expensive
neighbor_chunk_spatialized_predictions[
["h3_index", "h3_centroid_lat", "h3_centroid_lon", "h3_boundary"]
] = neighbor_chunk_spatialized_predictions.apply(
_add_h3_index_details, axis=1
neighbor_chunk_spatialized_predictions[["h3_index"]] = (
neighbor_chunk_spatialized_predictions.apply(_add_h3_index, axis=1)
)
neighbor_chunk_spatialized_predictions = (
neighbor_chunk_spatialized_predictions[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ def test_spatialize_chunk_predictions_invalid_study_area(
study_area_metadata: Dict[str, Any] = {
"name": "study_area_name",
"crs": "EPSG:32618",
"row_count": 2,
"col_count": 3,
"chunk_y_count": 2,
"chunk_x_count": 3,
} # Missing "cell_size" required field
chunks_metadata: List[Dict[str, Any]] = [
{
Expand All @@ -164,7 +164,7 @@ def test_spatialize_chunk_predictions_invalid_study_area(

assert (
'Study area "study-area-name" is missing one or more required '
"fields: cell_size, crs, row_count, col_count" in output.getvalue()
"fields: cell_size, crs, chunk_x_count, chunk_y_count" in output.getvalue()
)


Expand Down Expand Up @@ -200,8 +200,8 @@ def test_spatialize_chunk_predictions_missing_chunk(
"name": "study_area_name",
"cell_size": 10,
"crs": "EPSG:32618",
"row_count": 2,
"col_count": 3,
"chunk_y_count": 2,
"chunk_x_count": 3,
}
chunks_metadata: List[Dict[str, Any]] = [
{
Expand Down Expand Up @@ -261,8 +261,8 @@ def test_spatialize_chunk_predictions_invalid_chunk(
"name": "study_area_name",
"cell_size": 10,
"crs": "EPSG:32618",
"row_count": 2,
"col_count": 3,
"chunk_y_count": 2,
"chunk_x_count": 3,
}
chunks_metadata: List[Dict[str, Any]] = [
{
Expand Down Expand Up @@ -321,8 +321,8 @@ def test_spatialize_chunk_predictions_missing_predictions(
"name": "study_area_name",
"cell_size": 10,
"crs": "EPSG:32618",
"row_count": 2,
"col_count": 3,
"chunk_y_count": 2,
"chunk_x_count": 3,
}
chunks_metadata: List[Dict[str, Any]] = [
{
Expand Down Expand Up @@ -387,8 +387,8 @@ def test_spatialize_chunk_predictions_too_many_predictions(
"name": "study_area_name",
"cell_size": 10,
"crs": "EPSG:32618",
"row_count": 2,
"col_count": 3,
"chunk_y_count": 2,
"chunk_x_count": 3,
}
chunks_metadata: List[Dict[str, Any]] = [
{
Expand Down Expand Up @@ -449,8 +449,8 @@ def test_spatialize_chunk_predictions_missing_expected_neighbor_chunk(
"name": "study_area_name",
"cell_size": 10,
"crs": "EPSG:32618",
"row_count": 2,
"col_count": 3,
"chunk_y_count": 2,
"chunk_x_count": 3,
}
chunks_metadata: List[Dict[str, Any]] = [
{
Expand Down Expand Up @@ -514,8 +514,8 @@ def test_spatialize_chunk_predictions_invalid_neighbor_chunk(
"name": "study_area_name",
"cell_size": 10,
"crs": "EPSG:32618",
"row_count": 2,
"col_count": 3,
"chunk_y_count": 2,
"chunk_x_count": 3,
}
chunks_metadata: List[Dict[str, Any]] = [
{
Expand Down Expand Up @@ -588,8 +588,8 @@ def test_spatialize_chunk_predictions_neighbor_chunk_missing_predictions(
"name": "study_area_name",
"cell_size": 10,
"crs": "EPSG:32618",
"row_count": 2,
"col_count": 3,
"chunk_y_count": 2,
"chunk_x_count": 3,
}
chunks_metadata: List[Dict[str, Any]] = [
{
Expand Down Expand Up @@ -680,8 +680,8 @@ def test_spatialize_chunk_predictions_h3_centroids_within_chunk(
"name": "study_area_name",
"cell_size": 10,
"crs": "EPSG:32618",
"row_count": 2,
"col_count": 3,
"chunk_y_count": 2,
"chunk_x_count": 3,
}
chunks_metadata: List[Dict[str, Any]] = [
{
Expand Down Expand Up @@ -788,8 +788,8 @@ def test_spatialize_chunk_predictions_h3_centroids_outside_chunk(
"name": "study_area_name",
"cell_size": 5,
"crs": "EPSG:32618",
"row_count": 2,
"col_count": 3,
"chunk_y_count": 2,
"chunk_x_count": 3,
}
chunks_metadata: List[Dict[str, Any]] = [
{
Expand Down Expand Up @@ -950,8 +950,8 @@ def test_spatialize_chunk_predictions_overlapping_neighbors(
"name": "study_area_name",
"cell_size": 3,
"crs": "EPSG:32618",
"row_count": 2,
"col_count": 3,
"chunk_y_count": 2,
"chunk_x_count": 3,
}
chunks_metadata: List[Dict[str, Any]] = [
{
Expand Down