diff --git a/forecasttools/__init__.py b/forecasttools/__init__.py index e1e2dba..b522a0c 100644 --- a/forecasttools/__init__.py +++ b/forecasttools/__init__.py @@ -3,53 +3,55 @@ import arviz as az import polars as pl -from .daily_to_epiweekly import df_aggregate_to_epiweekly -from .idata_w_dates_to_df import ( +from forecasttools.daily_to_epiweekly import df_aggregate_to_epiweekly +from forecasttools.idata_w_dates_to_df import ( add_dates_as_coords_to_idata, idata_forecast_w_dates_to_df, ) -from .recode_locations import loc_abbr_to_flusight_code -from .to_flusight import get_flusight_table -from .trajectories_to_quantiles import trajectories_to_quantiles +from forecasttools.recode_locations import ( + loc_abbr_to_hubverse_code, + loc_hubverse_code_to_abbr, + location_lookup, + to_location_table_column, +) +from forecasttools.to_hubverse import get_hubverse_table +from forecasttools.trajectories_to_quantiles import trajectories_to_quantiles # location table (from Census data) -with importlib.resources.path( - __package__, "location_table.parquet" -) as data_path: - location_table = pl.read_parquet(data_path) +with importlib.resources.files(__package__).joinpath( + "location_table.parquet" +).open("rb") as f: + location_table = pl.read_parquet(f) # load example flusight submission -with importlib.resources.path( - __package__, - "example_flusight_submission.parquet", -) as data_path: - dtypes_d = {"location": pl.Utf8} - example_flusight_submission = pl.read_parquet(data_path) +with importlib.resources.files(__package__).joinpath( + "example_flusight_submission.parquet" +).open("rb") as f: + example_flusight_submission = pl.read_parquet(f) # load example fitting data for COVID (NHSN, as of 2024-09-26) -with importlib.resources.path( - __package__, "nhsn_hosp_COVID.parquet" -) as data_path: - nhsn_hosp_COVID = pl.read_parquet(data_path) +with importlib.resources.files(__package__).joinpath( + "nhsn_hosp_COVID.parquet" +).open("rb") as f: + nhsn_hosp_COVID = 
pl.read_parquet(f) # load example fitting data for influenza (NHSN, as of 2024-09-26) -with importlib.resources.path( - __package__, "nhsn_hosp_flu.parquet" -) as data_path: - nhsn_hosp_flu = pl.read_parquet(data_path) +with importlib.resources.files(__package__).joinpath( + "nhsn_hosp_flu.parquet" +).open("rb") as f: + nhsn_hosp_flu = pl.read_parquet(f) # load light idata NHSN influenza forecast wo dates (NHSN, as of 2024-09-26) -with importlib.resources.path( - __package__, - "example_flu_forecast_wo_dates.nc", -) as data_path: - nhsn_flu_forecast_wo_dates = az.from_netcdf(data_path) +with importlib.resources.files(__package__).joinpath( + "example_flu_forecast_wo_dates.nc" +).open("rb") as f: + nhsn_flu_forecast_wo_dates = az.from_netcdf(f) # load light idata NHSN influenza forecast w dates (NHSN, as of 2024-09-26) -with importlib.resources.path( - __package__, "example_flu_forecast_w_dates.nc" -) as data_path: - nhsn_flu_forecast_w_dates = az.from_netcdf(data_path) +with importlib.resources.files(__package__).joinpath( + "example_flu_forecast_w_dates.nc" +).open("rb") as f: + nhsn_flu_forecast_w_dates = az.from_netcdf(f) __all__ = [ @@ -63,6 +65,9 @@ "add_dates_as_coords_to_idata", "trajectories_to_quantiles", "df_aggregate_to_epiweekly", - "loc_abbr_to_flusight_code", - "get_flusight_table", + "loc_abbr_to_hubverse_code", + "loc_hubverse_code_to_abbr", + "to_location_table_column", + "location_lookup", + "get_hubverse_table", ] diff --git a/forecasttools/data.py b/forecasttools/data.py index 7547d14..3536482 100644 --- a/forecasttools/data.py +++ b/forecasttools/data.py @@ -77,15 +77,17 @@ def make_census_dataset( "long_name": ["United States"], } ) - jurisdictions = pl.read_csv(url, separator="|").select( + jurisdictions = pl.read_csv( + url, separator="|", schema_overrides={"STATE": pl.Utf8} + ).select( [ - pl.col("STATE").alias("location_code").cast(pl.Utf8), + pl.col("STATE").alias("location_code"), pl.col("STUSAB").alias("short_name"), 
pl.col("STATE_NAME").alias("long_name"), ] ) location_table = nation.vstack(jurisdictions) - location_table.write_csv(file_save_path) + location_table.write_parquet(file_save_path) print(f"The file {file_save_path} has been saved.") diff --git a/forecasttools/location_table.parquet b/forecasttools/location_table.parquet index 76568a2..b952674 100644 Binary files a/forecasttools/location_table.parquet and b/forecasttools/location_table.parquet differ diff --git a/forecasttools/recode_locations.py b/forecasttools/recode_locations.py index 885d5d7..7db5268 100644 --- a/forecasttools/recode_locations.py +++ b/forecasttools/recode_locations.py @@ -1,7 +1,7 @@ """ -Functions to work with recoding columns -containing US jurisdiction location codes -and abbreviations. +Functions to work with recoding location +columns containing US jurisdiction location +codes or two-letter abbreviations. """ import polars as pl @@ -9,20 +9,23 @@ import forecasttools -def loc_abbr_to_flusight_code( +def loc_abbr_to_hubverse_code( df: pl.DataFrame, location_col: str ) -> pl.DataFrame: """ - Takes the location columns of a Polars - dataframe and recodes it to FluSight - location codes. - + Takes the location column of a Polars + dataframe (formatted as US two-letter + jurisdictional abbreviations) and recodes + it to hubverse location codes using + location_table, which is a Polars + dataframe contained in forecasttools. Parameters ---------- df A Polars dataframe with a location - column. + column consisting of US + jurisdictional abbreviations. location_col The name of the dataframe's location column. @@ -30,15 +33,208 @@ def loc_abbr_to_flusight_code( Returns ------- pl.DataFrame - A recoded locations dataframe. + A Polars dataframe with the location + column formatted as hubverse location + codes. 
""" - # get location table + # check inputted variable types + if not isinstance(df, pl.DataFrame): + raise TypeError(f"Expected a Polars DataFrame; got {type(df)}.") + if not isinstance(location_col, str): + raise TypeError( + f"Expected a string for location_col; got {type(location_col)}." + ) + # check if dataframe entered is empty + if df.is_empty(): + raise ValueError(f"The dataframe {df} is empty.") + # check if the location column exists + # in the inputted dataframe + if location_col not in df.columns: + raise ValueError( + f"Column '{location_col}' not found in the dataframe; got {df.columns}." + ) + # get location table from forecasttools loc_table = forecasttools.location_table - # recode and replaced existing loc abbrs with loc codes + # check if values in location_col are a + # subset of short_name in location table + location_values = set(df[location_col].to_list()) + valid_values = set(loc_table["short_name"].to_list()) + difference = location_values.difference(valid_values) + if difference: + raise ValueError( + f"The following values in '{location_col}') are not valid jurisdictional codes: {difference}." + ) + # recode existing location abbreviations + # with location codes loc_recoded_df = df.with_columns( - location=pl.col("location").replace( + pl.col(location_col).replace( old=loc_table["short_name"], new=loc_table["location_code"], ) ) return loc_recoded_df + + +def loc_hubverse_code_to_abbr( + df: pl.DataFrame, location_col: str +) -> pl.DataFrame: + """ + Takes the location columns of a Polars + dataframe (formatted as hubverse codes for + US two-letter jurisdictions) and recodes + it to US jurisdictional abbreviations, + using location_table, which is a Polars + dataframe contained in forecasttools. + + Parameters + ---------- + df + A Polars dataframe with a location + column consisting of US + jurisdictional hubverse codes. + location_col + The name of the dataframe's location + column. 
+ + Returns + ------- + pl.DataFrame + A Polars dataframe with the location + column formatted as US two-letter + jurisdictional abbreviations. + """ + # check inputted variable types + if not isinstance(df, pl.DataFrame): + raise TypeError(f"Expected a Polars DataFrame; got {type(df)}.") + if not isinstance(location_col, str): + raise TypeError( + f"Expected a string for location_col; got {type(location_col)}." + ) + # check if dataframe entered is empty + if df.is_empty(): + raise ValueError(f"The dataframe {df} is empty.") + # check if the location column exists + # in the inputted dataframe + if location_col not in df.columns: + raise ValueError( + f"Column '{location_col}' not found in the dataframe; got {df.columns}." + ) + # get location table from forecasttools + loc_table = forecasttools.location_table + # check if values in location_col are a + # subset of location_code in location table + location_values = set(df[location_col].to_list()) + valid_values = set(loc_table["location_code"].to_list()) + difference = location_values.difference(valid_values) + if difference: + raise ValueError( + f"Some values in {difference} (in col '{location_col}') are not valid jurisdictional codes." + ) + # recode existing location codes with + # location abbreviations + loc_recoded_df = df.with_columns( + pl.col(location_col).replace( + old=loc_table["location_code"], new=loc_table["short_name"] + ) + ) + return loc_recoded_df + + +def to_location_table_column(location_format: str) -> str: + """ + Maps a location format string to the + corresponding column name in the hubverse + location table. For example, "hubverse" + maps to "location_code" in forecasttools' + location_table. + + Parameters + ---------- + location_format + The format string ("abbr", + "hubverse", or "long_name"). + + Returns + ------- + str + Returns the corresponding column name + from the location table.
+ """ + # check inputted variable type + assert isinstance( + location_format, str + ), f"Expected a string; got {type(location_format)}." + # return proper column name from input format + col_dict = { + "abbr": "short_name", + "hubverse": "location_code", + "long_name": "long_name", + } + col = col_dict.get(location_format) + if col is None: + raise KeyError( + f"Unknown location format {location_format}. Expected one of:\n{col_dict.keys()}." + ) + return col + + +def location_lookup( + location_vector: list[str], location_format: str +) -> pl.DataFrame: + """ + Look up rows of the hubverse location + table corresponding to the entries + of a given location vector and format. + Retrieves the rows from location_table + in the forecasttools package + corresponding to a given vector of + location identifiers, with possible + repeats. + + Parameters + ---------- + location_vector + A list of location values. + + location_format + The format in which the location + vector is coded. Permitted formats + are: 'abbr', US two-letter + jurisdictional abbreviation; + 'hubverse', legacy 2-digit FIPS code + for states and territories, or 'US' for + the USA as a whole; 'long_name', + full English name for the + jurisdiction. + + Returns + ------- + pl.DataFrame + Rows from location_table that match + the location vector, with repeats + possible. + """ + # check inputted variable types + if not isinstance(location_vector, list): + raise TypeError(f"Expected a list; got {type(location_vector)}.") + if not all(isinstance(loc, str) for loc in location_vector): + raise TypeError("All elements in location_vector must be of type str.") + if not isinstance(location_format, str): + raise TypeError(f"Expected a string; got {type(location_format)}.") + valid_formats = ["abbr", "hubverse", "long_name"] + if location_format not in valid_formats: + raise ValueError( + f"Invalid location format '{location_format}'. Expected one of: {valid_formats}."
+ ) + # check that location vector not empty + if not location_vector: + raise ValueError("The location_vector is empty.") + # get the join key based on the location format + join_key = forecasttools.to_location_table_column(location_format) + # create a dataframe for the location + # vector with the column cast as string + locs_df = pl.DataFrame({join_key: [str(loc) for loc in location_vector]}) + # inner join with the location_table + # based on the join key + locs = locs_df.join(forecasttools.location_table, on=join_key, how="inner") + return locs diff --git a/forecasttools/to_flusight.py b/forecasttools/to_hubverse.py similarity index 94% rename from forecasttools/to_flusight.py rename to forecasttools/to_hubverse.py index e5dd6c6..ab623c0 100644 --- a/forecasttools/to_flusight.py +++ b/forecasttools/to_hubverse.py @@ -1,6 +1,6 @@ """ Takes epiweekly quantilized Polars dataframe -and performs final conversion to the FluSight +and performs final conversion to the hubverse formatted output. """ @@ -10,12 +10,12 @@ import polars as pl -def get_flusight_target_end_dates( +def get_hubverse_target_end_dates( reference_date: str, horizons: list[str] | None = None, ) -> pl.DataFrame: """ - Generates remaining FluSight format + Generates remaining hubverse format columns from a reference date for use in a epiweekly quantilized dataframe. @@ -34,7 +34,7 @@ def get_flusight_target_end_dates( ------- pl.DataFrame A dataframe of columns necessary for - the FluSight submission. + the hubverse submission. 
""" # set default horizons in case of no specification if horizons is None: @@ -72,7 +72,7 @@ def get_flusight_target_end_dates( return data_df -def get_flusight_table( +def get_hubverse_table( quantile_forecasts: pl.DataFrame, reference_date: str, quantile_value_col: str = "quantile_value", @@ -85,7 +85,7 @@ def get_flusight_table( ) -> pl.DataFrame: """ Takes epiweekly quantilized Polars dataframe - and adds target ends dates for FluSight + and adds target ends dates for hubverse formatted output dataframe. Parameters @@ -128,7 +128,7 @@ def get_flusight_table( Returns ------- pl.DataFrame - A flusight formatted dataframe. + A hubverse formatted dataframe. """ # default horizons and locations if horizons is None: @@ -136,7 +136,7 @@ def get_flusight_table( if excluded_locations is None: excluded_locations = ["60", "78"] # get target end dates - targets = get_flusight_target_end_dates(reference_date, horizons=horizons) + targets = get_hubverse_target_end_dates(reference_date, horizons=horizons) # filter and select relevant columns quants = quantile_forecasts.select( [ diff --git a/notebooks/flusight_from_idata.qmd b/notebooks/flusight_from_idata.qmd index 2871bfd..6444891 100644 --- a/notebooks/flusight_from_idata.qmd +++ b/notebooks/flusight_from_idata.qmd @@ -223,7 +223,7 @@ Recode locations: ```{python} -forecast_df_recoded = forecasttools.loc_abbr_to_flusight_code( +forecast_df_recoded = forecasttools.loc_abbr_to_hubverse_code( df=forecast_df, location_col="location") forecast_df_recoded ``` @@ -231,7 +231,7 @@ forecast_df_recoded Format to FluSight: ```{python} -flusight_output = forecasttools.get_flusight_table( +flusight_output = forecasttools.get_hubverse_table( quantile_forecasts=forecast_df_recoded, quantile_value_col="quantile_value", quantile_level_col="quantile_level", diff --git a/tests/test_recoding_locations.py b/tests/test_recoding_locations.py new file mode 100644 index 0000000..37ffe7f --- /dev/null +++ b/tests/test_recoding_locations.py @@ -0,0 
+1,199 @@ +""" +Test file for functions contained +within recode_locations.py +""" + +import polars as pl +import pytest + +import forecasttools + + +@pytest.mark.parametrize( + "function, df, location_col, expected_output", + [ + ( + forecasttools.loc_abbr_to_hubverse_code, + pl.DataFrame({"location": ["AL", "AK", "CA", "TX", "US"]}), + "location", + ["01", "02", "06", "48", "US"], + ), + ( + forecasttools.loc_hubverse_code_to_abbr, + pl.DataFrame({"location": ["01", "02", "06", "48", "US"]}), + "location", + ["AL", "AK", "CA", "TX", "US"], + ), + ], +) +def test_recode_valid_location_correct_input( + function, df, location_col, expected_output +): + """ + Test both recode functions (loc_abbr_to_hubverse_code + and loc_hubverse_code_to_abbr) for valid + location code and abbreviation output. + """ + df_w_loc_recoded = function(df=df, location_col=location_col) + loc_output = df_w_loc_recoded["location"].to_list() + assert ( + loc_output == expected_output + ), f"Expected {expected_output}, Got: {loc_output}" + + +@pytest.mark.parametrize( + "function, df, location_col, expected_exception", + [ + ( + forecasttools.loc_abbr_to_hubverse_code, + "not_a_dataframe", # not a dataframe type error + "location_col", + TypeError, + ), + ( + forecasttools.loc_abbr_to_hubverse_code, + pl.DataFrame({"location": ["AL", "AK"]}), + 123, # location column type failure + TypeError, + ), + ( + forecasttools.loc_abbr_to_hubverse_code, + pl.DataFrame(), + "location", # empty df failure + ValueError, + ), + ( + forecasttools.loc_abbr_to_hubverse_code, + pl.DataFrame({"location": ["AL", "AK"]}), + "non_existent_col", # location column name failure + ValueError, + ), + ( + forecasttools.loc_abbr_to_hubverse_code, + pl.DataFrame({"location": ["XX"]}), # abbr value failure + "location", + ValueError, + ), + ( + forecasttools.loc_hubverse_code_to_abbr, + "not_a_dataframe", # not a dataframe type error + "location_col", + TypeError, + ), + ( + forecasttools.loc_hubverse_code_to_abbr, + 
pl.DataFrame({"location": ["01", "02"]}), + 123, # location column type failure + TypeError, + ), + ( + forecasttools.loc_hubverse_code_to_abbr, + pl.DataFrame(), + "location", # empty df failure + ValueError, + ), + ( + forecasttools.loc_hubverse_code_to_abbr, + pl.DataFrame({"location": ["01", "02"]}), + "non_existent_col", # location column name failure + ValueError, + ), + ( + forecasttools.loc_hubverse_code_to_abbr, + pl.DataFrame({"location": ["99"]}), # code value failure + "location", + ValueError, + ), + ], +) +def test_loc_conversation_funcs_invalid_input( + function, df, location_col, expected_exception +): + """ + Test that loc_hubverse_code_to_abbr and + loc_abbr_to_hubverse_code handle type + errors for the dataframe and location + column name, value errors for the + location entries, and value errors if the + dataframe is empty. + """ + with pytest.raises(expected_exception): + function(df, location_col) + + +@pytest.mark.parametrize( + "location_format, expected_column", + [ + ("abbr", "short_name"), + ("hubverse", "location_code"), + ("long_name", "long_name"), + ], +) +def test_to_location_table_column_correct_input( + location_format, expected_column +): + """ + Test to_location_table_column for + expected column names + when given different location formats. + """ + result_column = forecasttools.to_location_table_column(location_format) + assert ( + result_column == expected_column + ), f"Expected column '{expected_column}' for format '{location_format}', but got '{result_column}'" + + +@pytest.mark.parametrize( + "location_format, expected_exception", + [ + (123, AssertionError), # invalid location type + ("unknown_format", KeyError), # bad location name + ], +) +def test_to_location_table_column_exception_handling( + location_format, expected_exception +): + """ + Test to_location_table_column for + exception handling. 
+ """ + with pytest.raises(expected_exception): + forecasttools.to_location_table_column(location_format) + + +@pytest.mark.parametrize( + "location_vector, location_format, expected_exception", + [ + ("invalid_string", "abbr", TypeError), # invalid location vec type + ([1, 2, 3], "abbr", TypeError), # non-string elts in location vec + ( + ["AL", "CA"], + 123, + TypeError, + ), # invalid location format type (not str) + ( + ["AL", "CA"], + "invalid_format", + ValueError, + ), # invalid location_format value (not one of valid) + ([], "abbr", ValueError), # empty location_vector (edge) + (["AL", "CA"], "abbr", None), # valid inputs (expected no exception) + ], +) +def test_location_lookup_exceptions( + location_vector, location_format, expected_exception +): + """ + Test location_lookup for exception handling + and input validation. + """ + if expected_exception: + with pytest.raises(expected_exception): + forecasttools.location_lookup(location_vector, location_format) + else: + result = forecasttools.location_lookup( + location_vector, location_format + ) + assert isinstance( + result, pl.DataFrame + ), "Expected a Polars DataFrame as output."