test: Further testing for plexos utils
pesap committed Oct 23, 2024
1 parent 9cb02be commit c19e4fe
Showing 5 changed files with 216 additions and 58 deletions.
2 changes: 0 additions & 2 deletions src/r2x/models/costs.py
@@ -28,8 +28,6 @@ def variable_type(self) -> str | None:
def value_curve_type(self) -> str | None:
"""Create attribute that holds the class name."""
try:
if not attrgetter("variable.value_curve")(self):
return None
return type(attrgetter("variable.value_curve")(self)).__name__
except AttributeError:
return None
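The hunk above drops the explicit None check and leans entirely on the except AttributeError. For reference, a minimal standalone sketch (hypothetical Cost/Variable stand-ins, not the actual r2x models) of why that is enough: operator.attrgetter with a dotted path raises AttributeError as soon as any link in the chain is missing.

from operator import attrgetter


class Variable:
    """Hypothetical stand-in for the real variable object."""

    value_curve = None


class Cost:
    """Hypothetical stand-in for the r2x cost model."""

    def __init__(self, variable=None):
        if variable is not None:
            self.variable = variable

    @property
    def value_curve_type(self) -> str | None:
        try:
            # attrgetter("variable.value_curve") raises AttributeError if either
            # `variable` or `value_curve` is absent, so no pre-check is needed.
            return type(attrgetter("variable.value_curve")(self)).__name__
        except AttributeError:
            return None


print(Cost().value_curve_type)            # None -> 'variable' is missing
print(Cost(Variable()).value_curve_type)  # 'NoneType' -> the attribute chain resolves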
5 changes: 5 additions & 0 deletions src/r2x/parser/plexos.py
@@ -390,6 +390,9 @@ def _construct_branches(self, default_model=MonitoredLine):
values="property_value",
aggregate_function="first",
)
if lines_pivot.is_empty():
logger.warning("No line objects found on the system.")
return

lines_pivot_memberships = self.db.get_memberships(
*lines_pivot["name"].to_list(), object_class=ClassEnum.Line
@@ -808,6 +811,8 @@ def _construct_interfaces(self, default_model=TransmissionInterface):

# Add lines memberships
lines = [line["name"] for line in self.system.to_records(MonitoredLine)]
if not lines:
return
lines_memberships = self.db.get_memberships(
*lines,
object_class=ClassEnum.Line,
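Both guards follow the same pattern: bail out early when the query yields nothing. A small sketch (synthetic data, not the parser's real pivot) of the polars side of that check:

import polars as pl

# An empty pivot result, as when the model defines no Line objects.
lines_pivot = pl.DataFrame({"name": [], "Max Flow": []})

if lines_pivot.is_empty():
    print("No line objects found on the system.")  # the parser logs a warning and returns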
168 changes: 112 additions & 56 deletions src/r2x/parser/plexos_utils.py
@@ -1,15 +1,17 @@
"""Compilation of functions used on the PLEXOS parser."""

# ruff: noqa

from datetime import datetime
import re
from datetime import datetime, timedelta
from enum import Enum
from typing import Any
from collections.abc import Sequence

from numpy.typing import NDArray
import pint
import polars as pl
import numpy as np
from loguru import logger
from infrasys import SingleTimeSeries

PLEXOS_ACTION_MAP = {
"×": np.multiply, # noqa
@@ -20,23 +22,23 @@
}


class DATAFILE_COLUMNS(Enum):
class DATAFILE_COLUMNS(Enum): # noqa: N801
"""Enum of possible Data file columns in Plexos."""

NV = ["name", "value"]
Y = ["year"]
PV = ["pattern", "value"]
TS_NPV = ["name", "pattern", "value"]
TS_NYV = ["name", "year", "value"]
TS_NDV = ["name", "DateTime", "value"]
TS_YMDP = ["year", "month", "day", "period"]
TS_YMDPV = ["year", "month", "day", "period", "value"]
TS_NYMDV = ["name", "year", "month", "day", "value"]
TS_NYMDPV = ["name", "year", "month", "day", "period", "value"]
TS_YM = ["year", "month"]
TS_MDP = ["month", "day", "period"]
TS_NMDP = ["name", "month", "day", "period"]
TS_YMDH = [
NV = ("name", "value")
Y = ("year",)
PV = ("pattern", "value")
TS_NPV = ("name", "pattern", "value")
TS_NYV = ("name", "year", "value")
TS_NDV = ("name", "DateTime", "value")
TS_YMDP = ("year", "month", "day", "period")
TS_YMDPV = ("year", "month", "day", "period", "value")
TS_NYMDV = ("name", "year", "month", "day", "value")
TS_NYMDPV = ("name", "year", "month", "day", "period", "value")
TS_YM = ("year", "month")
TS_MDP = ("month", "day", "period")
TS_NMDP = ("name", "month", "day", "period")
TS_YMDH = (
"year",
"month",
"day",
@@ -64,8 +66,8 @@ class DATAFILE_COLUMNS(Enum):
"22",
"23",
"24",
]
TS_NYMDH = [
)
TS_NYMDH = (
"name",
"year",
"month",
@@ -94,8 +96,8 @@ class DATAFILE_COLUMNS(Enum):
"22",
"23",
"24",
]
TS_NMDH = [
)
TS_NMDH = (
"name",
"month",
"day",
@@ -123,8 +125,8 @@ class DATAFILE_COLUMNS(Enum):
"22",
"23",
"24",
]
TS_NM = [
)
TS_NM = (
"name",
"m01",
"m02",
@@ -138,7 +140,7 @@
"m10",
"m11",
"m12",
]
)
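One reading of the list-to-tuple switch (not stated in the commit, so treat it as an assumption): tuple values are hashable and immutable, which keeps reverse lookup by value cheap and protects the column sets from accidental mutation. A minimal sketch:

from enum import Enum


class COLS(Enum):
    NV = ("name", "value")
    PV = ("pattern", "value")


# Hashable values support reverse lookup by value...
print(COLS(("name", "value")))  # COLS.NV
# ...and the usual subset test against incoming file columns still works.
print(set(COLS.NV.value).issubset({"name", "value", "extra"}))  # True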


def get_column_enum(columns: list[str]) -> DATAFILE_COLUMNS | None:
@@ -314,32 +316,80 @@ def parse_ts_nymdh(data_file):
return data_file


def parse_patterns(key):
def parse_patterns(key: str) -> list[tuple[str, list[int]]]:
"""Parse a key for time slice patterns (e.g., 'M1-3', 'H1-6') and return a list of tuples.
Parameters
----------
key : str
A string pattern representing time slices, such as months ('M1-12'), hours ('H1-24'),
weekdays ('W1-7'), and days of the month ('D1-31').
Returns
-------
List[tuple[str, List[int]]]
A list of tuples where the first element is the time slice type (e.g., 'M', 'H', 'W', 'D'),
and the second element is the list of integers representing the range of values for that time slice.
Raises
------
TypeError
If the input is not a string.
ValueError
If the ranges are invalid (e.g., 'M13', 'H25').
Examples
--------
>>> parse_patterns("M1-3")
[('M', [1, 2, 3])]
>>> parse_patterns("H1-6,H18-24")
[('H', [1, 2, 3, 4, 5, 6]), ('H', [18, 19, 20, 21, 22, 23, 24])]
>>> parse_patterns("W1,H1-6")
[('W', [1]), ('H', [1, 2, 3, 4, 5, 6])]
"""
if not isinstance(key, str):
raise TypeError(f"Expected 'key' to be a str, got {type(key).__name__}")

ranges = key.split(";")
month_list = []
pattern_list = []

for rng in ranges:
# Match ranges like 'M5-10' and single months like 'M1'
match = re.match(r"M(\d+)(?:-(\d+))?", rng)
if match:
start_month = int(match.group(1))
end_month = int(match.group(2)) if match.group(2) else start_month
# Generate the list of months from the range
month_list.extend(range(start_month, end_month + 1))
return month_list
time_slice_matches = re.finditer(r"([MWHD])(\d+)(?:-(\d+))?", rng)
for match in time_slice_matches:
time_slice_type = match.group(1)
start_value = int(match.group(2))
end_value = int(match.group(3)) if match.group(3) else start_value

# Validating ranges based on time slice type
if time_slice_type == "M" and not (1 <= start_value <= 12 and 1 <= end_value <= 12):
raise ValueError(f"Invalid month range: {start_value}-{end_value}")
if time_slice_type == "H" and not (1 <= start_value <= 24 and 1 <= end_value <= 24):
raise ValueError(f"Invalid hour range: {start_value}-{end_value}")
if time_slice_type == "W" and not (1 <= start_value <= 7 and 1 <= end_value <= 7):
raise ValueError(f"Invalid weekday range: {start_value}-{end_value}")
if time_slice_type == "D" and not (1 <= start_value <= 31 and 1 <= end_value <= 31):
raise ValueError(f"Invalid day of month range: {start_value}-{end_value}")

pattern_list.append((time_slice_type, list(range(start_value, end_value + 1))))

return pattern_list


def time_slice_handler(
records: list[dict[str, Any]],
hourly_time_index: pl.DataFrame,
hourly_time_index: pl.DataFrame | NDArray[np.datetime64] | Sequence[datetime],
pattern_key: str = "pattern",
) -> np.ndarray:
"""Deconstruct a dict of time slices and return a NumPy array representing a time series.
Parameters
----------
records : List[dict[str, Any]]
records : list[dict[str, Any]]
A list of dictionaries containing timeslice records.
hourly_time_index : pl.DataFrame
hourly_time_index : pl.DataFrame | NDArray[np.datetime64] | Sequence[datetime]
A polars DataFrame with a single datetime column, a NumPy array of datetime64 values, or a sequence of datetimes supplying the hourly time index.
pattern_key : str, optional
Key used to extract patterns from records (default is 'pattern').
@@ -360,35 +410,41 @@
Examples
--------
>>> records = [{"pattern": "M1-2", "value": np.array([5])}, {"pattern": "M3", "value": np.array([10])}]
>>> datetime_values = [datetime(2024, 1, 1), datetime(2024, 2, 1), datetime(2024, 3, 1)]
>>> hourly_time_index = pl.DataFrame({"datetime": datetime_values})
>>> time_slice_handler(records, hourly_time_index)
array([5., 5., 10.])
>>> hourly_time_index = pl.DataFrame({"wrong_column": [1, 2, 3]}) # Raises ValueError
>>> from datetime import datetime, timedelta
>>> records = [{"pattern": "M1-2", "value": 200}, {"pattern": "M3-12", "value": 100}]
>>> year = 2024
>>> start = datetime(year, 1, 1)
>>> end = datetime(year + 1, 1, 1)
>>> delta = timedelta(hours=1)
>>> datetime_index = tuple(start + i * delta for i in range((end - start) // delta))
>>> result = time_slice_handler(records, datetime_index)
"""
if not isinstance(hourly_time_index, pl.DataFrame):
raise TypeError(
f"Expected 'hourly_time_index' to be a polars DataFrame, got {type(hourly_time_index).__name__}"
)
if isinstance(hourly_time_index, pl.DataFrame):
hourly_time_index = hourly_time_index.to_numpy()

if not all(isinstance(record, dict) for record in records):
raise TypeError("All records must be dictionaries")

if not all(record[pattern_key].startswith("M") for record in records if pattern_key in record):
raise NotImplementedError("All records must contain valid month patterns starting with 'M'")

if "datetime" not in hourly_time_index.columns:
raise ValueError("Hourly time index does not have 'datetime' column")

hourly_time_index = hourly_time_index["datetime"]
if isinstance(hourly_time_index, np.ndarray):
hourly_time_index = hourly_time_index.astype(datetime).flatten().tolist()

months = np.array([dt.month for dt in hourly_time_index])
# hours = np.array([dt.hour for dt in hourly_time_index])
month_datetime_series = np.zeros(len(hourly_time_index), dtype=float)

for record in records:
months_in_key = parse_patterns(record[pattern_key])
for month in months_in_key:
month_datetime_series[months == month] = record["value"].magnitude
patterns = parse_patterns(record[pattern_key])
for pattern in patterns:
match pattern[0]:
case "M":
month_datetime_series[np.isin(months, pattern[1])] = (
record["value"].magnitude
if isinstance(record["value"], pint.Quantity)
else record["value"]
)
case _:
raise NotImplementedError

return month_datetime_series
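The heart of the rewritten handler is a vectorised mask assignment per parsed pattern; a tiny standalone illustration of that step (toy month values, not the real hourly index):

import numpy as np

months = np.array([1, 1, 2, 3, 12])      # month of each timestamp in the index
series = np.zeros(len(months), dtype=float)

series[np.isin(months, [1, 2])] = 200.0              # pattern ('M', [1, 2])
series[np.isin(months, list(range(3, 13)))] = 100.0  # pattern ('M', [3, ..., 12])

print(series)  # [200. 200. 200. 100. 100.]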
5 changes: 5 additions & 0 deletions tests/test_plexos_parser.py
@@ -58,3 +58,8 @@ def test_raise_if_no_map_provided(tmp_path, data_folder):
)
with pytest.raises(ParserError):
_ = get_parser_data(scenario, parser_class=PlexosParser)


def test_parser_system(plexos_parser_instance):
...
# plexos_parser = plexos_parser_instance.build_system()
94 changes: 94 additions & 0 deletions tests/test_plexos_utils.py
@@ -0,0 +1,94 @@
import pytest
from datetime import datetime, timedelta
import polars as pl
from r2x.parser.plexos_utils import DATAFILE_COLUMNS, get_column_enum, time_slice_handler


def test_get_column_enum():
"""Test multiple cases for get_column_enum function."""
# Case 1: Exact match for NV
columns = ["name", "value"]
assert get_column_enum(columns) == DATAFILE_COLUMNS.NV

# Case 2: Exact match for TS_YMDPV
columns = ["year", "month", "day", "period", "value"]
assert get_column_enum(columns) == DATAFILE_COLUMNS.TS_YMDPV

# Case 3: Subset match for TS_YM (with an extra column)
columns = ["year", "month", "extra"]
assert get_column_enum(columns) == DATAFILE_COLUMNS.TS_YM

# Case 4: No match (completely unrelated columns)
columns = ["random", "columns"]
assert get_column_enum(columns) is None

# Case 5: Partial match for NV (extra column in input)
columns = ["name", "value", "extra"]
assert get_column_enum(columns) == DATAFILE_COLUMNS.NV

# Case 6: Exact match for TS_NMDH (large set of columns)
columns = [
"name",
"month",
"day",
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
"10",
"11",
"12",
"13",
"14",
"15",
"16",
"17",
"18",
"19",
"20",
"21",
"22",
"23",
"24",
]
assert get_column_enum(columns) == DATAFILE_COLUMNS.TS_NMDH


def test_time_slice_handler():
records = [{"pattern": "M1-2", "value": 200}, {"pattern": "M3-12", "value": 100}]

year = 2012
hourly_time_index = pl.datetime_range(
datetime(year, 1, 1), datetime(year + 1, 1, 1), interval="1h", eager=True, closed="left"
).to_frame("datetime")
result_polars = time_slice_handler(records, hourly_time_index)

assert all(result_polars[:100] == 200)
assert all(result_polars[-100:] == 100)

start = datetime(year, 1, 1)
end = datetime(year + 1, 1, 1)
delta = timedelta(hours=1)
datetime_index = tuple(start + i * delta for i in range((end - start) // delta))
result_datetime = time_slice_handler(records, datetime_index)
assert all(result_datetime == result_polars)


def test_time_slice_handler_raises():
year = 2020
start = datetime(year, 1, 1)
end = datetime(year + 1, 1, 1)
delta = timedelta(hours=1)
datetime_index = tuple(start + i * delta for i in range((end - start) // delta))
records = [{"pattern": "M1-2", "value": 200}, {"pattern": "M3-12", "value": 100}, [1, 2]]
with pytest.raises(TypeError):
_ = time_slice_handler(records, datetime_index)

records = [{"pattern": "H1-2", "value": 200}, {"pattern": "M3-12", "value": 100}]
with pytest.raises(NotImplementedError):
_ = time_slice_handler(records, datetime_index)
