From 127fd5d82922bb63135070cef642046c5b9fe5c9 Mon Sep 17 00:00:00 2001 From: Nathan McDougall Date: Wed, 10 Apr 2024 16:49:08 +1200 Subject: [PATCH 1/6] Add tests for `task` Signed-off-by: Nathan McDougall --- labtech/tasks.py | 13 +-- tests/labtech/test_tasks.py | 181 +++++++++++++++++++++++++++++++++++- 2 files changed, 187 insertions(+), 7 deletions(-) diff --git a/labtech/tasks.py b/labtech/tasks.py index 1e33f38..fa97af5 100644 --- a/labtech/tasks.py +++ b/labtech/tasks.py @@ -19,6 +19,11 @@ class CacheDefault: CACHE_DEFAULT = CacheDefault() +_RESERVED_ATTRS = [ + '_lt', '_is_task', 'cache_key', 'result', '_results_map', '_set_results_map', + 'result_meta', '_set_result_meta', 'context', 'set_context', +] +"""Reserved attribute names for task types.""" def immutable_param_value(key: str, value: Any) -> Any: """Converts a parameter value to an immutable equivalent that is hashable.""" @@ -187,16 +192,12 @@ def run(self): def decorator(cls): nonlocal cache - reserved_attrs = [ - '_lt', '_is_task', 'cache_key', 'result', '_results_map', '_set_results_map', - 'result_meta', '_set_result_meta', 'context', 'set_context', - ] if not is_task_type(cls): - for reserved_attr in reserved_attrs: + for reserved_attr in _RESERVED_ATTRS: if hasattr(cls, reserved_attr): raise AttributeError(f"Task type already defines reserved attribute '{reserved_attr}'.") - post_init = getattr(cls, 'post_init', None) + post_init = getattr(cls, '__post_init__', None) cls.__post_init__ = _task_post_init cls = dataclass(frozen=True, eq=True, order=True)(cls) diff --git a/tests/labtech/test_tasks.py b/tests/labtech/test_tasks.py index 088c189..bebdf1e 100644 --- a/tests/labtech/test_tasks.py +++ b/tests/labtech/test_tasks.py @@ -1,10 +1,15 @@ +from dataclasses import FrozenInstanceError, dataclass import re from enum import Enum +from typing import Literal +import labtech.tasks import pytest from frozendict import frozendict +from labtech.cache import BaseCache, NullCache, PickleCache from labtech.exceptions import TaskError -from labtech.tasks import immutable_param_value +from labtech.tasks import _RESERVED_ATTRS, immutable_param_value +from labtech.types import ResultT, Storage, Task, TaskInfo class _BadObject: @@ -18,6 +23,180 @@ class _ExampleEnum(Enum): B = 2 +class BadCache(BaseCache): + """A pretend cache that returns fake values.""" + + KEY_PREFIX = "bad__" + + def save_result(self, storage: Storage, task: Task[ResultT], result: ResultT): + raise NotImplementedError + + def load_result(self, storage: Storage, task: Task[ResultT]) -> ResultT: + raise NotImplementedError + + +class TestTask: + def test_defaults(self) -> None: + @labtech.task + class SimpleTask: + def run(self) -> None: + return None + + task = SimpleTask() + task_info: TaskInfo = task._lt + + assert isinstance(task_info.cache, PickleCache) + assert task_info.max_parallel is None + assert task_info.mlflow_run is False + assert task_info.orig_post_init is None + + def test_null_cache(self) -> None: + @labtech.task(cache=None) + class SimpleTask: + def run(self) -> None: + return None + + task = SimpleTask() + task_info: TaskInfo = task._lt + assert isinstance(task_info.cache, NullCache) + + def test_nondefault_cache(self) -> None: + @labtech.task(cache=BadCache()) + class SimpleTask: + def run(self) -> None: + return None + + task = SimpleTask() + task_info: TaskInfo = task._lt + assert isinstance(task_info.cache, BadCache) + + @pytest.mark.parametrize("max_parallel", [None, 1, 2, 3]) + def test_max_parallel(self, max_parallel: int | None) -> None: + @labtech.task(max_parallel=max_parallel) + class SimpleTask: + def run(self) -> None: + pass + + task = SimpleTask() + task_info: TaskInfo = task._lt + assert task_info.max_parallel == max_parallel + + def test_mlflow_run(self) -> None: + @labtech.task(mlflow_run=True) + class SimpleTask: + def run(self) -> None: + pass + + task = SimpleTask() + task_info: TaskInfo = task._lt + assert task_info.mlflow_run is True + + def test_reserved_lt_attr(self) -> None: + match = re.escape("Task type already defines reserved attribute '_lt'.") + with pytest.raises(AttributeError, match=match): + + @labtech.task + class SimpleTask: + def run(self) -> None: + pass + + def _lt(self) -> None: + pass + + @pytest.mark.parametrize("badattr", _RESERVED_ATTRS) + def test_fail_reserved_attrs(self, badattr: str) -> None: + class SimpleTaskBase: + pass + + setattr(SimpleTaskBase, badattr, None) + + match = re.escape(f"Task type already defines reserved attribute '{badattr}'.") + with pytest.raises(AttributeError, match=match): + + @labtech.task(mlflow_run=True) + class SimpleTask(SimpleTaskBase): + def run(self) -> None: + pass + + def test_stored_post_init(self) -> None: + @labtech.task + class SimpleTask: + def __post_init__(self): + return "It's me!" + + def run(self) -> None: + pass + + task = SimpleTask() + task_info: TaskInfo = task._lt + assert task_info.orig_post_init is not None + assert task_info.orig_post_init(task) == "It's me!" + + def test_frozen(self) -> None: + @labtech.task + class SimpleTask: + a: int + + def run(self) -> None: + return None + + task = SimpleTask(a=1) + + # Check the dataclass is now frozen + with pytest.raises(FrozenInstanceError): + task.a = 2 + + def test_order(self) -> None: + @labtech.task + class SimpleTask: + a: int + b: str + + def run(self) -> None: + return None + + task1 = SimpleTask(a=1, b="hello") + task2 = SimpleTask(b="hello", a=2) + task3 = SimpleTask(a=1, b="zzz") + + assert task1 < task2 + assert task2 > task3 + assert task1 <= task3 + assert task1 != task2 + assert task1 == task1 + + def test_fail_no_run(self) -> None: + match = re.escape("Task type 'SimpleTask' must define a 'run' method") + with pytest.raises(NotImplementedError, match=match): + + @labtech.task + class SimpleTask: + pass + + def test_fail_noncallable_run(self) -> None: + match = re.escape("Task type 'SimpleTask' must define a 'run' method") + with pytest.raises(NotImplementedError, match=match): + + @labtech.task + class SimpleTask: + run: int + + def test_post_init_missing_dunder(self) -> None: + """This checks a bug is fixed where we were storing `post_init` instead of `__post_init__`.""" + + @labtech.task + class SimpleTask: + def post_init(self): + pass + + def run(self) -> None: + pass + + task = SimpleTask() + task_info: TaskInfo = task._lt + assert task_info.orig_post_init is None + + class TestImmutableParamValue: def test_empty_list(self) -> None: assert immutable_param_value("hello", []) == () From 9b98dd136f6fe692a111481eabc14da7615d5d7d Mon Sep 17 00:00:00 2001 From: Nathan McDougall Date: Wed, 10 Apr 2024 16:53:07 +1200 Subject: [PATCH 2/6] Comply with previous API for post_init Signed-off-by: Nathan McDougall --- labtech/tasks.py | 4 ++-- tests/labtech/test_tasks.py | 22 +++++++++------------- 2 files changed, 11 insertions(+), 15 deletions(-) diff --git a/labtech/tasks.py b/labtech/tasks.py index fa97af5..dd52b1f 100644 --- a/labtech/tasks.py +++ b/labtech/tasks.py @@ -21,7 +21,7 @@ class CacheDefault: _RESERVED_ATTRS = [ '_lt', '_is_task', 'cache_key', 'result', '_results_map', '_set_results_map', - 'result_meta', '_set_result_meta', 'context', 'set_context', + 'result_meta', '_set_result_meta', 'context', 'set_context', '__post_init__', ] """Reserved attribute names for task types.""" @@ -197,7 +197,7 @@ def decorator(cls): if hasattr(cls, reserved_attr): raise AttributeError(f"Task type already defines reserved attribute '{reserved_attr}'.") - post_init = getattr(cls, '__post_init__', None) + post_init = getattr(cls, 'post_init', None) cls.__post_init__ = _task_post_init cls = dataclass(frozen=True, eq=True, order=True)(cls) diff --git a/tests/labtech/test_tasks.py b/tests/labtech/test_tasks.py index bebdf1e..df0ebf4 100644 --- a/tests/labtech/test_tasks.py +++ b/tests/labtech/test_tasks.py @@ -121,7 +121,7 @@ def run(self) -> None: def test_stored_post_init(self) -> None: @labtech.task class SimpleTask: - def __post_init__(self): + def post_init(self): return "It's me!" def run(self) -> None: @@ -182,19 +182,15 @@ class SimpleTask: run: int def test_post_init_missing_dunder(self) -> None: - """This checks a bug is fixed where we were storing `post_init` instead of `__post_init__`.""" - - @labtech.task - class SimpleTask: - def post_init(self): - pass - - def run(self) -> None: - pass + match = re.escape("Task type already defines reserved attribute '__post_init__'.") + with pytest.raises(AttributeError, match=match): + @labtech.task + class SimpleTask: + def __post_init__(self): + pass - task = SimpleTask() - task_info: TaskInfo = task._lt - assert task_info.orig_post_init is None + def run(self) -> None: + pass class TestImmutableParamValue: From 3b852448f086408b9422076de4a572c4de4fba66 Mon Sep 17 00:00:00 2001 From: Nathan McDougall Date: Wed, 10 Apr 2024 16:55:44 +1200 Subject: [PATCH 3/6] More robust test for _lt attribute case Signed-off-by: Nathan McDougall --- tests/labtech/test_tasks.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/labtech/test_tasks.py b/tests/labtech/test_tasks.py index df0ebf4..6589f60 100644 --- a/tests/labtech/test_tasks.py +++ b/tests/labtech/test_tasks.py @@ -103,6 +103,9 @@ def run(self) -> None: def _lt(self) -> None: pass + def _is_task(self) -> None: + pass + @pytest.mark.parametrize("badattr", _RESERVED_ATTRS) def test_fail_reserved_attrs(self, badattr: str) -> None: class SimpleTaskBase: From 52912d7f9c82d949b7b39fe0672e723bf9428a5a Mon Sep 17 00:00:00 2001 From: Nathan McDougall Date: Wed, 10 Apr 2024 17:00:15 +1200 Subject: [PATCH 4/6] Test inheritance Signed-off-by: Nathan McDougall --- tests/labtech/test_tasks.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/labtech/test_tasks.py b/tests/labtech/test_tasks.py index 6589f60..3dadda3 100644 --- a/tests/labtech/test_tasks.py +++ b/tests/labtech/test_tasks.py @@ -195,6 +195,19 @@ def __post_init__(self): def run(self) -> None: pass + def test_inheritance(self) -> None: + @labtech.task + class SimpleTask: + def run(self) -> None: + pass + + @labtech.task + class SubTask(SimpleTask): + def run(self) -> None: + pass + + # Check we don't get an error trying to do this. + SubTask() class TestImmutableParamValue: def test_empty_list(self) -> None: From fc862a894b77f1b433f67fd2d7ea12439496b4eb Mon Sep 17 00:00:00 2001 From: Nathan McDougall Date: Wed, 10 Apr 2024 23:30:42 +1200 Subject: [PATCH 5/6] Linter compliance Signed-off-by: Nathan McDougall --- labtech/tasks.py | 1 + tests/labtech/test_tasks.py | 9 ++++++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/labtech/tasks.py b/labtech/tasks.py index dd52b1f..4e9dd53 100644 --- a/labtech/tasks.py +++ b/labtech/tasks.py @@ -25,6 +25,7 @@ class CacheDefault: ] """Reserved attribute names for task types.""" + def immutable_param_value(key: str, value: Any) -> Any: """Converts a parameter value to an immutable equivalent that is hashable.""" if isinstance(value, list) or isinstance(value, tuple): diff --git a/tests/labtech/test_tasks.py b/tests/labtech/test_tasks.py index 3dadda3..67c7979 100644 --- a/tests/labtech/test_tasks.py +++ b/tests/labtech/test_tasks.py @@ -1,7 +1,6 @@ -from dataclasses import FrozenInstanceError, dataclass import re +from dataclasses import FrozenInstanceError from enum import Enum -from typing import Literal import labtech.tasks import pytest @@ -185,8 +184,11 @@ class SimpleTask: run: int def test_post_init_missing_dunder(self) -> None: - match = re.escape("Task type already defines reserved attribute '__post_init__'.") + match = re.escape( + "Task type already defines reserved attribute '__post_init__'." + ) with pytest.raises(AttributeError, match=match): + @labtech.task class SimpleTask: def __post_init__(self): @@ -209,6 +211,7 @@ def run(self) -> None: # Check we don't get an error trying to do this. SubTask() + class TestImmutableParamValue: def test_empty_list(self) -> None: assert immutable_param_value("hello", []) == () From 7a7b69074958c04b2fd2935fa3f851437013cde3 Mon Sep 17 00:00:00 2001 From: Nathan McDougall Date: Wed, 24 Apr 2024 21:24:42 +1200 Subject: [PATCH 6/6] Make task detection functions more robust --- labtech/types.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/labtech/types.py b/labtech/types.py index 358f18d..83e8c32 100644 --- a/labtech/types.py +++ b/labtech/types.py @@ -86,12 +86,12 @@ def run(self): def is_task_type(cls): """Returns `True` if the given `cls` is a class decorated with [`labtech.task`][labtech.task].""" - return isclass(cls) and hasattr(cls, '_lt') + return isclass(cls) and isinstance(getattr(cls, '_lt', None), TaskInfo) def is_task(obj): """Returns `True` if the given `obj` is an instance of a task class.""" - return hasattr(obj, '_is_task') + return is_task_type(type(obj)) and hasattr(obj, '_is_task') class Storage(ABC):