Modify climateiq_spatialize_chunk_predictions_cf based on live testing (
skeshive authored Jul 16, 2024
1 parent 770c4ac commit 98ed555
Showing 2 changed files with 399 additions and 322 deletions.
112 changes: 62 additions & 50 deletions cloud_functions/climateiq_spatialize_chunk_predictions_cf/main.py
@@ -9,7 +9,8 @@

from typing import Any
from cloudevents import http
from google.cloud import firestore, storage
from google.cloud import firestore_v1
from google.cloud.storage import client as gcs_client
from shapely import geometry
from h3 import h3

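A minimal sketch, not part of the commit, of how the re-aliased imports above are used to build clients (both yield the same client objects as the old `firestore.Client()` and `storage.Client()`; Google Cloud credentials are assumed to be configured):

    from google.cloud import firestore_v1
    from google.cloud.storage import client as gcs_client

    db = firestore_v1.Client()             # Firestore client, as used in _get_study_area_metadata
    storage_client = gcs_client.Client()   # GCS client, as used for reading and writing blobs
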
@@ -28,60 +29,64 @@
# Triggered from a message on the "climateiq-spatialize-and-export-predictions"
# Pub/Sub topic.
@functions_framework.cloud_event
def subscribe(cloud_event: http.CloudEvent) -> None:
"""This function spatializes model predictions for a single chunk and outputs a
CSV file to GCS containing H3 indexes along with associated predictions.
def spatialize_chunk_predictions(cloud_event: http.CloudEvent) -> None:
"""This function spatializes model predictions for a single chunk.
Spatialized model predictions are outputted to a CSV file in GCS,
containing H3 indexes along with associated predictions.
Args:
cloud_event: The CloudEvent representing the Pub/Sub message.
Raises:
ValueError: If the object name format, study area metadata, chunk / neighbor
chunk metadata or predictions file format is invalid.
"""
object_name = base64.b64decode(cloud_event.data["message"]["data"]).decode()

# Extract components from the object name.
path = pathlib.PurePosixPath(object_name)
if len(path.parts) != 6:
raise ValueError(
"Invalid object name format. Expected format: '<id>/<prediction_type>/"
"<model_id>/<study_area_name>/<scenario_id>/<chunk_id>'"
print(
f"Invalid object name format. Expected format: '<id>/<prediction_type>/"
"<model_id>/<study_area_name>/<scenario_id>/<chunk_id>'\n"
f"Actual name: '{object_name}'"
)
return

id, prediction_type, model_id, study_area_name, scenario_id, chunk_id = path.parts
try:
predictions = _read_chunk_predictions(object_name)
study_area_metadata, chunks_ref = _get_study_area_metadata(study_area_name)
chunk_metadata = _get_chunk_metadata(chunks_ref, chunk_id)

predictions = _read_chunk_predictions(object_name)
study_area_metadata, chunks_ref = _get_study_area_metadata(study_area_name)
chunk_metadata = _get_chunk_metadata(study_area_metadata, chunk_id)

spatialized_predictions = _build_spatialized_model_predictions(
study_area_metadata, chunk_metadata, predictions
)
spatialized_predictions = _build_spatialized_model_predictions(
study_area_metadata, chunk_metadata, predictions
)

h3_predictions = _calculate_h3_indexes(
study_area_metadata,
chunk_metadata,
spatialized_predictions,
object_name,
chunks_ref,
)
h3_predictions = _calculate_h3_indexes(
study_area_metadata,
chunk_metadata,
spatialized_predictions,
object_name,
chunks_ref,
)
except ValueError as ve:
# Any raised ValueErrors are non-retriable so return instead of throwing an
# exception (which would trigger retries)
print(ve)
return

storage_client = storage.Client()
storage_client = gcs_client.Client()
bucket = storage_client.bucket(OUTPUT_BUCKET_NAME)
blob = bucket.blob(
f"{id}/{prediction_type}/{model_id}/{study_area_name}/{scenario_id}/{chunk_id}"
".csv"
)
with blob.open("w+") as fd:
with blob.open("w") as fd:
h3_predictions.to_csv(fd)


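To illustrate the function's I/O contract end to end, here is a hedged, self-contained sketch (not part of the commit): the Pub/Sub message data is the base64-encoded GCS object name, and the output is a CSV written through the blob's file-like interface. All concrete names and values below are assumed examples.

    import base64
    import pathlib

    import pandas as pd
    from google.cloud.storage import client as gcs_client

    # Hypothetical Pub/Sub payload: the message data is the base64-encoded object name.
    object_name = "run123/flood/model-v1/nyc/scenario-a/chunk_0_0"  # assumed example
    event_data = {"message": {"data": base64.b64encode(object_name.encode()).decode()}}

    decoded = base64.b64decode(event_data["message"]["data"]).decode()
    parts = pathlib.PurePosixPath(decoded).parts
    assert len(parts) == 6  # <id>/<prediction_type>/<model_id>/<study_area_name>/<scenario_id>/<chunk_id>

    # Output mirrors the end of the function: open the blob for writing, stream CSV out.
    h3_predictions = pd.DataFrame({"h3_index": ["8f2a1072b59ffff"], "prediction": [0.42]})
    storage_client = gcs_client.Client()
    bucket = storage_client.bucket("climateiq-spatialized-predictions")  # assumed OUTPUT_BUCKET_NAME
    blob = bucket.blob(decoded + ".csv")
    with blob.open("w") as fd:
        h3_predictions.to_csv(fd)
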
# TODO: Modify this logic once CNN output schema is confirmed. Also update to
# account for errors and special values.
def _read_chunk_predictions(object_name: str) -> np.ndarray:
"""Reads model predictions for a given chunk from GCS and outputs
these predictions in a 2D array.
"""Reads model predictions for a given chunk from GCS.
Args:
object_name: The name of the chunk object to read.
@@ -92,7 +97,7 @@ def _read_chunk_predictions(object_name: str) -> np.ndarray:
Raises:
ValueError: If the predictions file format is invalid.
"""
storage_client = storage.Client()
storage_client = gcs_client.Client()
bucket = storage_client.bucket(INPUT_BUCKET_NAME)
blob = bucket.blob(object_name)

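The rest of this hunk (elided in this view) parses the blob's contents; a minimal sketch of the fetch step only, since the TODO above notes the CNN output schema is not yet confirmed (bucket and object names are assumed examples):

    from google.cloud.storage import client as gcs_client

    storage_client = gcs_client.Client()
    bucket = storage_client.bucket("climateiq-chunk-predictions")  # assumed INPUT_BUCKET_NAME
    blob = bucket.blob("run123/flood/model-v1/nyc/scenario-a/chunk_0_0")  # assumed object
    contents = blob.download_as_text()  # parsing into a 2D np.ndarray is schema-dependent
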
@@ -114,8 +119,7 @@
def _read_neighbor_chunk_predictions(
object_name: str, neighbor_chunk_id: str
) -> np.ndarray:
"""Reads model predictions for a neighbor chunk from GCS and outputs
these predictions in a 2D array.
"""Reads model predictions for a neighbor chunk from GCS.
Args:
object_name: The name of the chunk object this cloud function is currently
@@ -131,8 +135,9 @@ def _read_neighbor_chunk_predictions(
path = pathlib.PurePosixPath(object_name)
if len(path.parts) != 6:
raise ValueError(
"Invalid object name format. Expected format: '<id>/<prediction_type>/"
"<model_id>/<study_area_name>/<scenario_id>/<chunk_id>"
f"Invalid object name format. Expected format: '<id>/<prediction_type>/"
"<model_id>/<study_area_name>/<scenario_id>/<chunk_id>'\n"
f"Actual name: '{object_name}'"
)
*prefix, current_chunk_id = path.parts
neighbor_object_name = pathlib.PurePosixPath(*prefix, neighbor_chunk_id)
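Deriving the neighbor's object name is pure path manipulation, so it can be sketched as runnable standalone Python (the object and chunk ids are hypothetical):

    import pathlib

    object_name = "run123/flood/model-v1/nyc/scenario-a/chunk_0_0"  # assumed example
    *prefix, current_chunk_id = pathlib.PurePosixPath(object_name).parts
    neighbor_object_name = pathlib.PurePosixPath(*prefix, "chunk_0_1")  # hypothetical neighbor id
    print(neighbor_object_name)  # run123/flood/model-v1/nyc/scenario-a/chunk_0_1
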
@@ -156,7 +161,7 @@ def _get_study_area_metadata(
missing required fields.
"""
# TODO: Consider refactoring this to use library from climateiq-cnn repo.
db = firestore.Client()
db = firestore_v1.Client()

study_area_ref = db.collection(STUDY_AREAS_ID).document(study_area_name)
chunks_ref = study_area_ref.collection(CHUNKS_ID)
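For reference, a minimal sketch of the Firestore lookup pattern used here ("study_areas" and "chunks" are assumed values of the STUDY_AREAS_ID and CHUNKS_ID constants; the document name is a hypothetical study area):

    from google.cloud import firestore_v1

    db = firestore_v1.Client()
    study_area_ref = db.collection("study_areas").document("nyc")  # assumed ids
    chunks_ref = study_area_ref.collection("chunks")
    study_area_doc = study_area_ref.get()
    if study_area_doc.exists:
        study_area_metadata = study_area_doc.to_dict()
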
@@ -165,17 +170,20 @@
if not study_area_doc.exists:
raise ValueError(f'Study area "{study_area_name}" does not exist')

if len(chunks_ref.get()) == 0:
raise ValueError(f'Study area "{study_area_name}" is missing chunks')

study_area_metadata = study_area_doc.to_dict()
if (
"cell_size" not in study_area_metadata
not study_area_metadata
or "cell_size" not in study_area_metadata
or "crs" not in study_area_metadata
or "chunks" not in study_area_metadata
or "row_count" not in study_area_metadata
or "col_count" not in study_area_metadata
):
raise ValueError(
f'Study area "{study_area_name}" is missing one or more required '
"fields: cell_size, crs, chunks, row_count, col_count"
"fields: cell_size, crs, row_count, col_count"
)

return study_area_metadata, chunks_ref
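Based on the validation above, the study-area document is assumed to look roughly like the dict below after this change (values are illustrative only; per the diff, "chunks" now lives in a subcollection rather than as a field on this document):

    study_area_metadata = {
        "cell_size": 10,      # illustrative; cell edge length in the study area's CRS units
        "crs": "EPSG:32618",  # illustrative CRS string
        "row_count": 1000,    # illustrative
        "col_count": 1000,    # illustrative
    }
    required = {"cell_size", "crs", "row_count", "col_count"}
    assert required <= study_area_metadata.keys()
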
@@ -200,12 +208,11 @@ def _chunk_metadata_fields_valid(chunk_metadata: dict) -> bool:
)


def _get_chunk_metadata(study_area_metadata: dict, chunk_id: str) -> dict:
def _get_chunk_metadata(chunks_ref: Any, chunk_id: str) -> dict:
"""Retrieves metadata for a specific chunk within a study area.
Args:
study_area_metadata: A dictionary containing metadata for the
study area.
chunks_ref: A reference to the chunks collection in Firestore
chunk_id: The id of the chunk to retrieve metadata for.
Returns:
@@ -215,12 +222,13 @@ def _get_chunk_metadata(study_area_metadata: dict, chunk_id: str) -> dict:
ValueError: If the specified chunk does not exist or its metadata is
missing required fields.
"""
chunks = study_area_metadata["chunks"]
chunk_metadata = chunks.get(chunk_id)
chunk_metadata = chunks_ref.document(chunk_id).get()

if chunk_metadata is None:
if not chunk_metadata.exists:
raise ValueError(f'Chunk "{chunk_id}" does not exist')

chunk_metadata = chunk_metadata.to_dict()

if not _chunk_metadata_fields_valid(chunk_metadata):
raise ValueError(
f'Chunk "{chunk_id}" is missing one or more required fields: '
@@ -233,8 +241,7 @@ def _get_chunk_metadata(study_area_metadata: dict, chunk_id: str) -> dict:
def _build_spatialized_model_predictions(
study_area_metadata: dict, chunk_metadata: dict, predictions: np.ndarray
) -> pd.DataFrame:
"""Builds a DataFrame containing the lat/lon coordinates of each cell's
center point.
"""Builds a DF containing the lat/lon coordinates of each cell's center point.
Args:
study_area_metadata: A dictionary containing metadata for the study
@@ -469,17 +476,22 @@ def _aggregate_h3_predictions(
# Chunk is outside the study area boundary.
continue
query = (
chunks_ref.where("x_index", "==", neighbor_x)
.where("y_index", "==", neighbor_y)
chunks_ref.where(
filter=firestore_v1.base_query.FieldFilter("x_index", "==", neighbor_x)
)
.where(
filter=firestore_v1.base_query.FieldFilter("y_index", "==", neighbor_y)
)
.limit(1)
)
chunk_doc = query.get()
if not chunk_doc.exists:
chunk_docs = query.get()
if len(chunk_docs) == 0:
raise ValueError(
f"Neighbor chunk at index {neighbor_x, neighbor_y} is missing from the "
"study area."
)

chunk_doc = chunk_docs[0]
neighbor_chunk_id = chunk_doc.id
neighbor_chunk_metadata = chunk_doc.to_dict()
if not _chunk_metadata_fields_valid(neighbor_chunk_metadata):
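The query rewrite in the hunk above reflects two live-testing findings: the positional where() form gives way to FieldFilter, and get() on a query returns a list of snapshots rather than a single document. A minimal sketch of the new pattern, not part of the commit (field names come from the diff; the collection path and index values are assumed):

    from google.cloud import firestore_v1

    db = firestore_v1.Client()
    chunks_ref = (
        db.collection("study_areas").document("nyc").collection("chunks")  # assumed path
    )
    query = (
        chunks_ref.where(filter=firestore_v1.base_query.FieldFilter("x_index", "==", 1))
        .where(filter=firestore_v1.base_query.FieldFilter("y_index", "==", 2))
        .limit(1)
    )
    chunk_docs = query.get()  # a list of snapshots, not a single document
    if chunk_docs:
        neighbor_chunk_metadata = chunk_docs[0].to_dict()
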
