Skip to content

Commit

Permalink
Merge pull request #11 from nathanjmcdougall/feature/test-task
Browse files Browse the repository at this point in the history
Add tests for `task`
  • Loading branch information
ben-denham authored Apr 27, 2024
2 parents ef08cfe + 7a7b690 commit f61942d
Show file tree
Hide file tree
Showing 3 changed files with 204 additions and 8 deletions.
12 changes: 7 additions & 5 deletions labtech/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ class CacheDefault:

CACHE_DEFAULT = CacheDefault()

_RESERVED_ATTRS = [
'_lt', '_is_task', 'cache_key', 'result', '_results_map', '_set_results_map',
'result_meta', '_set_result_meta', 'context', 'set_context', '__post_init__',
]
"""Reserved attribute names for task types."""


def immutable_param_value(key: str, value: Any) -> Any:
"""Converts a parameter value to an immutable equivalent that is hashable."""
Expand Down Expand Up @@ -187,12 +193,8 @@ def run(self):
def decorator(cls):
nonlocal cache

reserved_attrs = [
'_lt', '_is_task', 'cache_key', 'result', '_results_map', '_set_results_map',
'result_meta', '_set_result_meta', 'context', 'set_context',
]
if not is_task_type(cls):
for reserved_attr in reserved_attrs:
for reserved_attr in _RESERVED_ATTRS:
if hasattr(cls, reserved_attr):
raise AttributeError(f"Task type already defines reserved attribute '{reserved_attr}'.")

Expand Down
4 changes: 2 additions & 2 deletions labtech/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,12 @@ def run(self):
def is_task_type(cls):
"""Returns `True` if the given `cls` is a class decorated with
[`labtech.task`][labtech.task]."""
return isclass(cls) and hasattr(cls, '_lt')
return isclass(cls) and isinstance(getattr(cls, '_lt', None), TaskInfo)


def is_task(obj):
"""Returns `True` if the given `obj` is an instance of a task class."""
return hasattr(obj, '_is_task')
return is_task_type(type(obj)) and hasattr(obj, '_is_task')


class Storage(ABC):
Expand Down
196 changes: 195 additions & 1 deletion tests/labtech/test_tasks.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import re
from dataclasses import FrozenInstanceError
from enum import Enum

import labtech.tasks
import pytest
from frozendict import frozendict
from labtech.cache import BaseCache, NullCache, PickleCache
from labtech.exceptions import TaskError
from labtech.tasks import immutable_param_value
from labtech.tasks import _RESERVED_ATTRS, immutable_param_value
from labtech.types import ResultT, Storage, Task, TaskInfo


class _BadObject:
Expand All @@ -18,6 +22,196 @@ class _ExampleEnum(Enum):
B = 2


class BadCache(BaseCache):
"""A pretend cache that returns fake values."""

KEY_PREFIX = "bad__"

def save_result(self, storage: Storage, task: Task[ResultT], result: ResultT):
raise NotImplementedError

def load_result(self, storage: Storage, task: Task[ResultT]) -> ResultT:
raise NotImplementedError


class TestTask:
def test_defaults(self) -> None:
@labtech.task
class SimpleTask:
def run(self) -> None:
return None

task = SimpleTask()
task_info: TaskInfo = task._lt

assert isinstance(task_info.cache, PickleCache)
assert task_info.max_parallel is None
assert task_info.mlflow_run is False
assert task_info.orig_post_init is None

def test_null_cache(self) -> None:
@labtech.task(cache=None)
class SimpleTask:
def run(self) -> None:
return None

task = SimpleTask()
task_info: TaskInfo = task._lt
assert isinstance(task_info.cache, NullCache)

def test_nondefault_cache(self) -> None:
@labtech.task(cache=BadCache())
class SimpleTask:
def run(self) -> None:
return None

task = SimpleTask()
task_info: TaskInfo = task._lt
assert isinstance(task_info.cache, BadCache)

@pytest.mark.parametrize("max_parallel", [None, 1, 2, 3])
def test_max_parallel(self, max_parallel: int | None) -> None:
@labtech.task(max_parallel=max_parallel)
class SimpleTask:
def run(self) -> None:
pass

task = SimpleTask()
task_info: TaskInfo = task._lt
assert task_info.max_parallel == max_parallel

def test_mlflow_run(self) -> None:
@labtech.task(mlflow_run=True)
class SimpleTask:
def run(self) -> None:
pass

task = SimpleTask()
task_info: TaskInfo = task._lt
assert task_info.mlflow_run is True

def test_reserved_lt_attr(self) -> None:
match = re.escape("Task type already defines reserved attribute '_lt'.")
with pytest.raises(AttributeError, match=match):

@labtech.task
class SimpleTask:
def run(self) -> None:
pass

def _lt(self) -> None:
pass

def _is_task(self) -> None:
pass

@pytest.mark.parametrize("badattr", _RESERVED_ATTRS)
def test_fail_reserved_attrs(self, badattr: str) -> None:
class SimpleTaskBase:
pass

setattr(SimpleTaskBase, badattr, None)

match = re.escape(f"Task type already defines reserved attribute '{badattr}'.")
with pytest.raises(AttributeError, match=match):

@labtech.task(mlflow_run=True)
class SimpleTask(SimpleTaskBase):
def run(self) -> None:
pass

def test_stored_post_init(self) -> None:
@labtech.task
class SimpleTask:
def post_init(self):
return "It's me!"

def run(self) -> None:
pass

task = SimpleTask()
task_info: TaskInfo = task._lt
assert task_info.orig_post_init is not None
assert task_info.orig_post_init(task) == "It's me!"

def test_frozen(self) -> None:
@labtech.task
class SimpleTask:
a: int

def run(self) -> None:
return None

task = SimpleTask(a=1)

# Check the dataclass is now frozen
with pytest.raises(FrozenInstanceError):
task.a = 2

def test_order(self) -> None:
@labtech.task
class SimpleTask:
a: int
b: str

def run(self) -> None:
return None

task1 = SimpleTask(a=1, b="hello")
task2 = SimpleTask(b="hello", a=2)
task3 = SimpleTask(a=1, b="zzz")

assert task1 < task2
assert task2 > task3
assert task1 <= task3
assert task1 != task2
assert task1 == task1

def test_fail_no_run(self) -> None:
match = re.escape("Task type 'SimpleTask' must define a 'run' method")
with pytest.raises(NotImplementedError, match=match):

@labtech.task
class SimpleTask:
pass

def test_fail_noncallable_run(self) -> None:
match = re.escape("Task type 'SimpleTask' must define a 'run' method")
with pytest.raises(NotImplementedError, match=match):

@labtech.task
class SimpleTask:
run: int

def test_post_init_missing_dunder(self) -> None:
match = re.escape(
"Task type already defines reserved attribute '__post_init__'."
)
with pytest.raises(AttributeError, match=match):

@labtech.task
class SimpleTask:
def __post_init__(self):
pass

def run(self) -> None:
pass

def test_inheritance(self) -> None:
@labtech.task
class SimpleTask:
def run(self) -> None:
pass

@labtech.task
class SubTask(SimpleTask):
def run(self) -> None:
pass

# Check we don't get an error trying to do this.
SubTask()


class TestImmutableParamValue:
def test_empty_list(self) -> None:
assert immutable_param_value("hello", []) == ()
Expand Down

0 comments on commit f61942d

Please sign in to comment.