Skip to content

Commit

Permalink
Add Remaining Location Code, Abbreviation, And Table Utilities (#33)
Browse files Browse the repository at this point in the history
  • Loading branch information
AFg6K7h4fhy2 authored Nov 12, 2024
1 parent 4de2b41 commit 31e95cf
Show file tree
Hide file tree
Showing 7 changed files with 462 additions and 60 deletions.
73 changes: 39 additions & 34 deletions forecasttools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,53 +3,55 @@
import arviz as az
import polars as pl

from .daily_to_epiweekly import df_aggregate_to_epiweekly
from .idata_w_dates_to_df import (
from forecasttools.daily_to_epiweekly import df_aggregate_to_epiweekly
from forecasttools.idata_w_dates_to_df import (
add_dates_as_coords_to_idata,
idata_forecast_w_dates_to_df,
)
from .recode_locations import loc_abbr_to_flusight_code
from .to_flusight import get_flusight_table
from .trajectories_to_quantiles import trajectories_to_quantiles
from forecasttools.recode_locations import (
loc_abbr_to_hubverse_code,
loc_hubverse_code_to_abbr,
location_lookup,
to_location_table_column,
)
from forecasttools.to_hubverse import get_hubverse_table
from forecasttools.trajectories_to_quantiles import trajectories_to_quantiles

# location table (from Census data)
with importlib.resources.path(
__package__, "location_table.parquet"
) as data_path:
location_table = pl.read_parquet(data_path)
with importlib.resources.files(__package__).joinpath(
"location_table.parquet"
).open("rb") as f:
location_table = pl.read_parquet(f)

# load example flusight submission
with importlib.resources.path(
__package__,
"example_flusight_submission.parquet",
) as data_path:
dtypes_d = {"location": pl.Utf8}
example_flusight_submission = pl.read_parquet(data_path)
with importlib.resources.files(__package__).joinpath(
"example_flusight_submission.parquet"
).open("rb") as f:
example_flusight_submission = pl.read_parquet(f)

# load example fitting data for COVID (NHSN, as of 2024-09-26)
with importlib.resources.path(
__package__, "nhsn_hosp_COVID.parquet"
) as data_path:
nhsn_hosp_COVID = pl.read_parquet(data_path)
with importlib.resources.files(__package__).joinpath(
"nhsn_hosp_COVID.parquet"
).open("rb") as f:
nhsn_hosp_COVID = pl.read_parquet(f)

# load example fitting data for influenza (NHSN, as of 2024-09-26)
with importlib.resources.path(
__package__, "nhsn_hosp_flu.parquet"
) as data_path:
nhsn_hosp_flu = pl.read_parquet(data_path)
with importlib.resources.files(__package__).joinpath(
"nhsn_hosp_flu.parquet"
).open("rb") as f:
nhsn_hosp_flu = pl.read_parquet(f)

# load light idata NHSN influenza forecast wo dates (NHSN, as of 2024-09-26)
with importlib.resources.path(
__package__,
"example_flu_forecast_wo_dates.nc",
) as data_path:
nhsn_flu_forecast_wo_dates = az.from_netcdf(data_path)
with importlib.resources.files(__package__).joinpath(
"example_flu_forecast_wo_dates.nc"
).open("rb") as f:
nhsn_flu_forecast_wo_dates = az.from_netcdf(f)

# load light idata NHSN influenza forecast w dates (NHSN, as of 2024-09-26)
with importlib.resources.path(
__package__, "example_flu_forecast_w_dates.nc"
) as data_path:
nhsn_flu_forecast_w_dates = az.from_netcdf(data_path)
with importlib.resources.files(__package__).joinpath(
"example_flu_forecast_w_dates.nc"
).open("rb") as f:
nhsn_flu_forecast_w_dates = az.from_netcdf(f)


__all__ = [
Expand All @@ -63,6 +65,9 @@
"add_dates_as_coords_to_idata",
"trajectories_to_quantiles",
"df_aggregate_to_epiweekly",
"loc_abbr_to_flusight_code",
"get_flusight_table",
"loc_abbr_to_hubverse_code",
"loc_hubverse_code_to_abbr",
"to_location_table_column",
"location_lookup",
"get_hubverse_table",
]
8 changes: 5 additions & 3 deletions forecasttools/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,17 @@ def make_census_dataset(
"long_name": ["United States"],
}
)
jurisdictions = pl.read_csv(url, separator="|").select(
jurisdictions = pl.read_csv(
url, separator="|", schema_overrides={"STATE": pl.Utf8}
).select(
[
pl.col("STATE").alias("location_code").cast(pl.Utf8),
pl.col("STATE").alias("location_code"),
pl.col("STUSAB").alias("short_name"),
pl.col("STATE_NAME").alias("long_name"),
]
)
location_table = nation.vstack(jurisdictions)
location_table.write_csv(file_save_path)
location_table.write_parquet(file_save_path)
print(f"The file {file_save_path} has been saved.")


Expand Down
Binary file modified forecasttools/location_table.parquet
Binary file not shown.
222 changes: 209 additions & 13 deletions forecasttools/recode_locations.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,240 @@
"""
Functions to work with recoding columns
containing US jurisdiction location codes
and abbreviations.
Functions to work with recoding location
columns containing US jurisdiction location
codes or two-letter abbreviations.
"""

import polars as pl

import forecasttools


def loc_abbr_to_hubverse_code(
    df: pl.DataFrame, location_col: str
) -> pl.DataFrame:
    """
    Takes the location column of a Polars
    dataframe (formatted as US two-letter
    jurisdictional abbreviations) and recodes
    it to hubverse location codes using
    location_table, which is a Polars
    dataframe contained in forecasttools.

    Parameters
    ----------
    df
        A Polars dataframe with a location
        column consisting of US
        jurisdictional abbreviations.
    location_col
        The name of the dataframe's location
        column.

    Returns
    -------
    pl.DataFrame
        A Polars dataframe with the location
        column formatted as hubverse location
        codes.

    Raises
    ------
    TypeError
        If df is not a Polars dataframe or
        location_col is not a string.
    ValueError
        If df is empty, location_col is not
        a column of df, or the column holds
        values not in location_table's
        short_name column.
    """
    # check inputted variable types
    if not isinstance(df, pl.DataFrame):
        raise TypeError(f"Expected a Polars DataFrame; got {type(df)}.")
    if not isinstance(location_col, str):
        raise TypeError(
            f"Expected a string for location_col; got {type(location_col)}."
        )
    # check if dataframe entered is empty
    if df.is_empty():
        raise ValueError(f"The dataframe {df} is empty.")
    # check if the location column exists
    # in the inputted dataframe
    if location_col not in df.columns:
        raise ValueError(
            f"Column '{location_col}' not found in the dataframe; got {df.columns}."
        )
    # get location table from forecasttools
    loc_table = forecasttools.location_table
    # check if values in location_col are a
    # subset of short_name in location table
    location_values = set(df[location_col].to_list())
    valid_values = set(loc_table["short_name"].to_list())
    difference = location_values.difference(valid_values)
    if difference:
        # message fixed: removed stray ')' and named the right entity
        # (these are abbreviations, not codes)
        raise ValueError(
            f"The following values in '{location_col}' are not valid jurisdictional abbreviations: {difference}."
        )
    # recode existing location abbreviations
    # with location codes
    loc_recoded_df = df.with_columns(
        pl.col(location_col).replace(
            old=loc_table["short_name"],
            new=loc_table["location_code"],
        )
    )
    return loc_recoded_df


def loc_hubverse_code_to_abbr(
    df: pl.DataFrame, location_col: str
) -> pl.DataFrame:
    """
    Recode a hubverse-coded location column
    to US two-letter jurisdictional
    abbreviations.

    The mapping comes from location_table,
    the Polars dataframe shipped with
    forecasttools.

    Parameters
    ----------
    df
        A Polars dataframe whose location
        column holds hubverse location
        codes.
    location_col
        Name of the location column in df.

    Returns
    -------
    pl.DataFrame
        The same dataframe with the location
        column rewritten as two-letter
        abbreviations.

    Raises
    ------
    TypeError
        If df or location_col has the wrong
        type.
    ValueError
        If df is empty, the column is
        missing, or it contains unknown
        codes.
    """
    # guard: argument types
    if not isinstance(df, pl.DataFrame):
        raise TypeError(f"Expected a Polars DataFrame; got {type(df)}.")
    if not isinstance(location_col, str):
        raise TypeError(
            f"Expected a string for location_col; got {type(location_col)}."
        )
    # guard: nothing to recode in an empty frame
    if df.is_empty():
        raise ValueError(f"The dataframe {df} is empty.")
    # guard: the named column must be present
    if location_col not in df.columns:
        raise ValueError(
            f"Column '{location_col}' not found in the dataframe; got {df.columns}."
        )
    table = forecasttools.location_table
    # every code in the column must be known to the location table
    present = set(df[location_col].to_list())
    known = set(table["location_code"].to_list())
    difference = present - known
    if difference:
        raise ValueError(
            f"Some values in {difference} (in col '{location_col}') are not valid jurisdictional codes."
        )
    # substitute each hubverse code with its abbreviation
    recoded = df.with_columns(
        pl.col(location_col).replace(
            old=table["location_code"], new=table["short_name"]
        )
    )
    return recoded


def to_location_table_column(location_format: str) -> str:
    """
    Maps a location format string to the
    corresponding column name in the hubverse
    location table. For example, "hubverse"
    maps to "location_code" in forecasttool's
    location_table.

    Parameters
    ----------
    location_format
        The format string ("abbr",
        "hubverse", or "long_name").

    Returns
    -------
    str
        Returns the corresponding column name
        from the location table.

    Raises
    ------
    TypeError
        If location_format is not a string.
    KeyError
        If location_format is not one of the
        recognized formats.
    """
    # check inputted variable type; raise (not assert) so the
    # check survives `python -O`, which strips assertions
    if not isinstance(location_format, str):
        raise TypeError(f"Expected a string; got {type(location_format)}.")
    # return proper column name from input format
    col_dict = {
        "abbr": "short_name",
        "hubverse": "location_code",
        "long_name": "long_name",
    }
    col = col_dict.get(location_format)
    if col is None:
        raise KeyError(
            f"Unknown location format {location_format}. Expected one of:\n{col_dict.keys()}."
        )
    return col


def location_lookup(
    location_vector: list[str], location_format: str
) -> pl.DataFrame:
    """
    Fetch hubverse location-table rows for a
    vector of location identifiers.

    Each entry of location_vector is matched
    (with possible repeats) against the
    column of forecasttools.location_table
    selected by location_format.

    Parameters
    ----------
    location_vector
        A list of location values.
    location_format
        How the entries are coded: 'abbr'
        (US two-letter jurisdictional
        abbreviation), 'hubverse' (legacy
        2-digit FIPS code for states and
        territories, plus 'US' for the USA
        as a whole), or 'long_name' (full
        English jurisdiction name).

    Returns
    -------
    pl.DataFrame
        The matching rows of location_table,
        repeats possible.

    Raises
    ------
    TypeError
        If the arguments have the wrong
        types.
    ValueError
        If the format is unrecognized or the
        vector is empty.
    """
    # validate argument types up front
    if not isinstance(location_vector, list):
        raise TypeError(f"Expected a list; got {type(location_vector)}.")
    if not all(isinstance(entry, str) for entry in location_vector):
        raise TypeError("All elements in location_vector must be of type str.")
    if not isinstance(location_format, str):
        raise TypeError(f"Expected a string; got {type(location_format)}.")
    permitted = ["abbr", "hubverse", "long_name"]
    if location_format not in permitted:
        raise ValueError(
            f"Invalid location format '{location_format}'. Expected one of: {permitted}."
        )
    # an empty vector would silently produce an empty join; reject it
    if not location_vector:
        raise ValueError("The location_vector is empty.")
    # resolve which location_table column this format keys on
    key = forecasttools.to_location_table_column(location_format)
    # one-column frame of the requested identifiers, cast to string
    requested = pl.DataFrame({key: [str(entry) for entry in location_vector]})
    # inner join keeps exactly the rows whose key appears in the vector
    return requested.join(forecasttools.location_table, on=key, how="inner")
Loading

0 comments on commit 31e95cf

Please sign in to comment.