Commit 645f216
Update merge function based on live testing (#31)
skeshive authored Jul 16, 2024
1 parent 98ed555 commit 645f216
Showing 2 changed files with 24 additions and 18 deletions.
32 changes: 19 additions & 13 deletions cloud_functions/climateiq_merge_scenario_predictions_cf/main.py
@@ -19,8 +19,8 @@
)
# File name pattern for the CSVs for each scenario and chunk.
CHUNK_FILE_NAME_PATTERN = (
r"(?P<batch_id>\w+)/(?P<prediction_type>\w+)/(?P<model_id>\w+)/"
r"(?P<study_area_name>\w+)/(?P<scenario_id>\w+)/(?P<chunk_id>\w+)\.csv"
r"(?P<batch_id>[^/]+)/(?P<prediction_type>[^/]+)/(?P<model_id>[^/]+)/"
r"(?P<study_area_name>[^/]+)/(?P<scenario_id>[^/]+)/(?P<chunk_id>[^/]+)\.csv"
)
# ID for the Study Areas collection in Firestore.
STUDY_AREAS_COLLECTION_ID = "study_areas"
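The regex change above widens each path segment from \w+ (letters, digits, and underscores only) to [^/]+ (anything except the path separator), so object names whose segments contain hyphens, periods, or other punctuation now match instead of being silently skipped, which is presumably what live testing surfaced. A minimal sketch of the difference, using a hypothetical object name with a hyphenated study area:

import re

OLD_PATTERN = (
    r"(?P<batch_id>\w+)/(?P<prediction_type>\w+)/(?P<model_id>\w+)/"
    r"(?P<study_area_name>\w+)/(?P<scenario_id>\w+)/(?P<chunk_id>\w+)\.csv"
)
NEW_PATTERN = (
    r"(?P<batch_id>[^/]+)/(?P<prediction_type>[^/]+)/(?P<model_id>[^/]+)/"
    r"(?P<study_area_name>[^/]+)/(?P<scenario_id>[^/]+)/(?P<chunk_id>[^/]+)\.csv"
)

# Hypothetical object name; "new-york" contains a hyphen, which \w+ rejects.
name = "batch1/flood/model1/new-york/scenario0/chunk0.csv"
print(re.match(OLD_PATTERN, name))                           # None
print(re.match(NEW_PATTERN, name).group("study_area_name"))  # new-york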
@@ -48,7 +48,12 @@ def merge_scenario_predictions(cloud_event: http.CloudEvent):
object_name = data["name"]
match = re.match(CHUNK_FILE_NAME_PATTERN, object_name)
# Ignore files that don't match the pattern.
-if not match:
+if match is None:
+    print(
+        f"Invalid object name format. Expected format: '<id>/<prediction_type>/"
+        f"<model_id>/<study_area_name>/<scenario_id>/<chunk_id>'\n"
+        f"Actual name: '{object_name}'"
+    )
return

batch_id, prediction_type, model_id, study_area_name = (
@@ -68,7 +73,8 @@ def merge_scenario_predictions(cloud_event: http.CloudEvent):
storage_client = storage.Client()
input_bucket = storage_client.bucket(INPUT_BUCKET_NAME)
blobs = storage_client.list_blobs(
-INPUT_BUCKET_NAME, f"{batch_id}/{prediction_type}/{model_id}/{study_area_name}"
+INPUT_BUCKET_NAME,
+prefix=f"{batch_id}/{prediction_type}/{model_id}/{study_area_name}",
)
chunk_ids_by_scenario_id = _get_chunk_ids_to_scenario_id(blobs)
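Passing the prefix by keyword is the substantive part of this hunk: in the google-cloud-storage client, the parameter that follows the bucket name in Client.list_blobs is max_results, so a bare positional second argument is not interpreted as a prefix. A short sketch of the corrected call, with a hypothetical bucket name standing in for INPUT_BUCKET_NAME:

from google.cloud import storage

INPUT_BUCKET_NAME = "climateiq-predictions"  # hypothetical value for illustration

storage_client = storage.Client()
# prefix must be passed by name; the second positional parameter of
# list_blobs is max_results, not prefix.
blobs = storage_client.list_blobs(
    INPUT_BUCKET_NAME,
    prefix="batch1/flood/model1/new-york",  # hypothetical path
)
for blob in blobs:
    print(blob.name)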

@@ -111,10 +117,10 @@ def merge_scenario_predictions(cloud_event: http.CloudEvent):
blob_to_write = output_bucket.blob(output_file_name)
with blob_to_write.open("w") as fd:
# Open the blob and start writing a CSV file with the headers
-# h3_index,scenario_0,scenario_1...
-writer = csv.DictWriter(fd, fieldnames=["h3_index"] + scenario_ids)
+# cell_code,scenario_0,scenario_1...
+writer = csv.DictWriter(fd, fieldnames=["cell_code"] + scenario_ids)
writer.writeheader()
-predictions_by_h3_index: dict[str, dict] = collections.defaultdict(dict)
+predictions_by_cell_code: dict[str, dict] = collections.defaultdict(dict)
for scenario_id in scenario_ids:
object_name = (
f"{batch_id}/{prediction_type}/{model_id}/"
@@ -126,19 +132,19 @@ def merge_scenario_predictions(cloud_event: http.CloudEvent):
print(f"Not found: {error}")
return
for row in rows:
-predictions_by_h3_index[row["h3_index"]][scenario_id] = row[
+predictions_by_cell_code[row["h3_index"]][scenario_id] = row[
"prediction"
]
-for h3_index, predictions in predictions_by_h3_index.items():
+for cell_code, predictions in predictions_by_cell_code.items():
missing_scenario_ids = set(scenario_ids) - set(predictions.keys())
if missing_scenario_ids:
print(
f"Not found: Missing predictions for {h3_index} for "
f"Not found: Missing predictions for {cell_code} for "
f"{', '.join(missing_scenario_ids)}."
)
return
predictions["h3_index"] = h3_index
# Output CSV will have the headers: h3_index,scenario_0,scenario_1...
predictions["cell_code"] = cell_code
# Output CSV will have the headers: cell_code,scenario_0,scenario_1...
writer.writerow(predictions)
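With the renames in place, the merge step reads as: collect each scenario's prediction per cell code into a nested dict, confirm every scenario contributed a value for every cell, then write one wide CSV row per cell. A self-contained sketch of that shape, using in-memory rows instead of GCS blobs (hypothetical values; note the chunk rows are still keyed by the h3_index column, as in the hunk above):

import collections
import csv
import io

# Hypothetical stand-ins for the per-scenario chunk CSV rows.
scenario_rows = {
    "scenario0": [{"h3_index": "h300", "prediction": "0.0"}],
    "scenario1": [{"h3_index": "h300", "prediction": "1.0"}],
}
scenario_ids = sorted(scenario_rows)

predictions_by_cell_code: dict[str, dict] = collections.defaultdict(dict)
for scenario_id, rows in scenario_rows.items():
    for row in rows:
        predictions_by_cell_code[row["h3_index"]][scenario_id] = row["prediction"]

out = io.StringIO()
writer = csv.DictWriter(out, fieldnames=["cell_code"] + scenario_ids)
writer.writeheader()
for cell_code, predictions in predictions_by_cell_code.items():
    missing_scenario_ids = set(scenario_ids) - set(predictions)
    if missing_scenario_ids:
        raise ValueError(f"Missing predictions for {cell_code}: {missing_scenario_ids}")
    predictions["cell_code"] = cell_code
    writer.writerow(predictions)

print(out.getvalue())  # cell_code,scenario0,scenario1 followed by h300,0.0,1.0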


@@ -232,7 +238,7 @@ def _files_complete(
def _get_file_content(bucket: storage.Bucket, object_name: str) -> list[dict]:
"""Gets the content from a Blob.
-Assumes Blob content is in CSV format with headers h3_index,prediction...
+Assumes Blob content is in CSV format with headers cell_code,prediction...
Args:
bucket: The GCS bucket the Blob is in.
10 changes: 5 additions & 5 deletions in the merge function's test module
@@ -55,8 +55,8 @@ def _create_chunk_file(
h3_indices_to_predictions: dict[str, float], tmp_path: str
) -> str:
rows = ["h3_index,prediction"] + [
f"{h3_index},{prediction}"
for h3_index, prediction in h3_indices_to_predictions.items()
f"{cell_code},{prediction}"
for cell_code, prediction in h3_indices_to_predictions.items()
]
with tempfile.NamedTemporaryFile("w+", dir=tmp_path, delete=False) as fd:
fd.write("\n".join(rows))
@@ -133,10 +133,10 @@ def test_merge_scenario_predictions(
main.merge_scenario_predictions(_create_pubsub_event())

expected_chunk0_contents = (
"h3_index,scenario0,scenario1\n" "h300,0.0,1.0\n" "h301,0.01,1.01\n"
"cell_code,scenario0,scenario1\n" "h300,0.0,1.0\n" "h301,0.01,1.01\n"
)
expected_chunk1_contents = (
"h3_index,scenario0,scenario1\n" "h310,0.1,1.1\n" "h311,0.11,1.11\n"
"cell_code,scenario0,scenario1\n" "h310,0.1,1.1\n" "h311,0.11,1.11\n"
)
with open(output_files["batch/flood/model/nyc/chunk0.csv"]) as fd:
assert fd.read() == expected_chunk0_contents
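The expected-contents literals rely on Python's implicit concatenation of adjacent string literals, so each variable is a single multi-line CSV string rather than a tuple; for the first chunk that evaluates to:

expected = (
    "cell_code,scenario0,scenario1\n" "h300,0.0,1.0\n" "h301,0.01,1.01\n"
)
assert expected == "cell_code,scenario0,scenario1\nh300,0.0,1.0\nh301,0.01,1.01\n"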
@@ -387,7 +387,7 @@ def test_merge_scenario_predictions_missing_chunk_prints_error(

@mock.patch.object(storage, "Client", autospec=True)
@mock.patch.object(firestore, "Client", autospec=True)
-def test_merge_scenario_predictions_missing_scenarios_for_h3_index_prints_error(
+def test_merge_scenario_predictions_missing_scenarios_for_cell_code_prints_error(
mock_firestore_client, mock_storage_client, tmp_path
):
_create_firestore_entries(mock_firestore_client, ["scenario0", "scenario1"], 2)
