From 645f21633720bb6cae06b04f33afec0ab2892683 Mon Sep 17 00:00:00 2001
From: Shreya Keshive
Date: Tue, 16 Jul 2024 16:26:28 -0400
Subject: [PATCH] Update merge function based on live testing (#31)

---
 .../main.py      | 32 +++++++++++--------
 .../main_test.py | 10 +++---
 2 files changed, 24 insertions(+), 18 deletions(-)

diff --git a/cloud_functions/climateiq_merge_scenario_predictions_cf/main.py b/cloud_functions/climateiq_merge_scenario_predictions_cf/main.py
index f9c344f..5afe035 100644
--- a/cloud_functions/climateiq_merge_scenario_predictions_cf/main.py
+++ b/cloud_functions/climateiq_merge_scenario_predictions_cf/main.py
@@ -19,8 +19,8 @@
 )
 # File name pattern for the CSVs for each scenario and chunk.
 CHUNK_FILE_NAME_PATTERN = (
-    r"(?P<batch_id>\w+)/(?P<prediction_type>\w+)/(?P<model_id>\w+)/"
-    r"(?P<study_area_name>\w+)/(?P<scenario_id>\w+)/(?P<chunk_id>\w+)\.csv"
+    r"(?P<batch_id>[^/]+)/(?P<prediction_type>[^/]+)/(?P<model_id>[^/]+)/"
+    r"(?P<study_area_name>[^/]+)/(?P<scenario_id>[^/]+)/(?P<chunk_id>[^/]+)\.csv"
 )
 # ID for the Study Areas collection in Firestore.
 STUDY_AREAS_COLLECTION_ID = "study_areas"
@@ -48,7 +48,12 @@ def merge_scenario_predictions(cloud_event: http.CloudEvent):
     object_name = data["name"]
     match = re.match(CHUNK_FILE_NAME_PATTERN, object_name)
     # Ignore files that don't match the pattern.
-    if not match:
+    if match is None:
+        print(
+            f"Invalid object name format. Expected format: '<batch_id>/<prediction_type>/"
+            f"<model_id>/<study_area_name>/<scenario_id>/<chunk_id>'\n"
+            f"Actual name: '{object_name}'"
+        )
         return
 
     batch_id, prediction_type, model_id, study_area_name = (
@@ -68,7 +73,8 @@ def merge_scenario_predictions(cloud_event: http.CloudEvent):
     storage_client = storage.Client()
     input_bucket = storage_client.bucket(INPUT_BUCKET_NAME)
     blobs = storage_client.list_blobs(
-        INPUT_BUCKET_NAME, f"{batch_id}/{prediction_type}/{model_id}/{study_area_name}"
+        INPUT_BUCKET_NAME,
+        prefix=f"{batch_id}/{prediction_type}/{model_id}/{study_area_name}",
     )
     chunk_ids_by_scenario_id = _get_chunk_ids_to_scenario_id(blobs)
 
@@ -111,10 +117,10 @@ def merge_scenario_predictions(cloud_event: http.CloudEvent):
         blob_to_write = output_bucket.blob(output_file_name)
         with blob_to_write.open("w") as fd:
             # Open the blob and start writing a CSV file with the headers
-            # h3_index,scenario_0,scenario_1...
-            writer = csv.DictWriter(fd, fieldnames=["h3_index"] + scenario_ids)
+            # cell_code,scenario_0,scenario_1...
+            writer = csv.DictWriter(fd, fieldnames=["cell_code"] + scenario_ids)
             writer.writeheader()
-            predictions_by_h3_index: dict[str, dict] = collections.defaultdict(dict)
+            predictions_by_cell_code: dict[str, dict] = collections.defaultdict(dict)
             for scenario_id in scenario_ids:
                 object_name = (
                     f"{batch_id}/{prediction_type}/{model_id}/"
@@ -126,19 +132,19 @@ def merge_scenario_predictions(cloud_event: http.CloudEvent):
                     print(f"Not found: {error}")
                     return
                 for row in rows:
-                    predictions_by_h3_index[row["h3_index"]][scenario_id] = row[
+                    predictions_by_cell_code[row["h3_index"]][scenario_id] = row[
                         "prediction"
                     ]
 
-            for h3_index, predictions in predictions_by_h3_index.items():
+            for cell_code, predictions in predictions_by_cell_code.items():
                 missing_scenario_ids = set(scenario_ids) - set(predictions.keys())
                 if missing_scenario_ids:
                     print(
-                        f"Not found: Missing predictions for {h3_index} for "
+                        f"Not found: Missing predictions for {cell_code} for "
                         f"{', '.join(missing_scenario_ids)}."
                     )
                     return
-                predictions["h3_index"] = h3_index
-                # Output CSV will have the headers: h3_index,scenario_0,scenario_1...
+                predictions["cell_code"] = cell_code
+                # Output CSV will have the headers: cell_code,scenario_0,scenario_1...
                 writer.writerow(predictions)
 
@@ -232,7 +238,7 @@ def _files_complete(
 def _get_file_content(bucket: storage.Bucket, object_name: str) -> list[dict]:
     """Gets the content from a Blob.
 
-    Assumes Blob content is in CSV format with headers h3_index,prediction...
+    Assumes Blob content is in CSV format with headers cell_code,prediction...
 
     Args:
         bucket: The GCS bucket the Blob is in.
diff --git a/cloud_functions/climateiq_merge_scenario_predictions_cf/main_test.py b/cloud_functions/climateiq_merge_scenario_predictions_cf/main_test.py
index a79f2f2..40f328a 100644
--- a/cloud_functions/climateiq_merge_scenario_predictions_cf/main_test.py
+++ b/cloud_functions/climateiq_merge_scenario_predictions_cf/main_test.py
@@ -55,8 +55,8 @@ def _create_chunk_file(
     h3_indices_to_predictions: dict[str, float], tmp_path: str
 ) -> str:
     rows = ["h3_index,prediction"] + [
-        f"{h3_index},{prediction}"
-        for h3_index, prediction in h3_indices_to_predictions.items()
+        f"{cell_code},{prediction}"
+        for cell_code, prediction in h3_indices_to_predictions.items()
     ]
     with tempfile.NamedTemporaryFile("w+", dir=tmp_path, delete=False) as fd:
         fd.write("\n".join(rows))
@@ -133,10 +133,10 @@ def test_merge_scenario_predictions(
     main.merge_scenario_predictions(_create_pubsub_event())
 
     expected_chunk0_contents = (
-        "h3_index,scenario0,scenario1\n" "h300,0.0,1.0\n" "h301,0.01,1.01\n"
+        "cell_code,scenario0,scenario1\n" "h300,0.0,1.0\n" "h301,0.01,1.01\n"
     )
     expected_chunk1_contents = (
-        "h3_index,scenario0,scenario1\n" "h310,0.1,1.1\n" "h311,0.11,1.11\n"
+        "cell_code,scenario0,scenario1\n" "h310,0.1,1.1\n" "h311,0.11,1.11\n"
     )
     with open(output_files["batch/flood/model/nyc/chunk0.csv"]) as fd:
         assert fd.read() == expected_chunk0_contents
@@ -387,7 +387,7 @@ def test_merge_scenario_predictions_missing_chunk_prints_error(
 
 @mock.patch.object(storage, "Client", autospec=True)
 @mock.patch.object(firestore, "Client", autospec=True)
-def test_merge_scenario_predictions_missing_scenarios_for_h3_index_prints_error(
+def test_merge_scenario_predictions_missing_scenarios_for_cell_code_prints_error(
     mock_firestore_client, mock_storage_client, tmp_path
 ):
     _create_firestore_entries(mock_firestore_client, ["scenario0", "scenario1"], 2)
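
A standalone sketch of why the pattern change above matters: `\w+` matches only
word characters, so any object-name segment containing a hyphen or period could
never match the old CHUNK_FILE_NAME_PATTERN and the event was silently ignored,
while `[^/]+` accepts any non-slash segment. The object name below is hypothetical;
the two patterns are copied from the diff:

    import re

    OLD_PATTERN = (
        r"(?P<batch_id>\w+)/(?P<prediction_type>\w+)/(?P<model_id>\w+)/"
        r"(?P<study_area_name>\w+)/(?P<scenario_id>\w+)/(?P<chunk_id>\w+)\.csv"
    )
    NEW_PATTERN = (
        r"(?P<batch_id>[^/]+)/(?P<prediction_type>[^/]+)/(?P<model_id>[^/]+)/"
        r"(?P<study_area_name>[^/]+)/(?P<scenario_id>[^/]+)/(?P<chunk_id>[^/]+)\.csv"
    )

    # Hypothetical object name: the hyphen in "batch-1" is not a \w character.
    name = "batch-1/flood/model/nyc/scenario0/chunk0.csv"

    assert re.match(OLD_PATTERN, name) is None       # old pattern: no match, file skipped
    assert re.match(NEW_PATTERN, name) is not None   # new pattern: all six groups captured

The list_blobs change is a similar live-testing fix: in google-cloud-storage, the
second positional parameter of Client.list_blobs is max_results, not prefix, so the
old call passed the prefix string where a result count was expected; passing prefix=
by keyword makes the listing actually filter on the object path.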