-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add template + testing setup for pelt
- Loading branch information
Showing
3 changed files
with
334 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,247 @@ | ||
"""Extension template for series annotation. | ||
Purpose of this implementation template: | ||
quick implementation of new estimators following the template | ||
NOT a concrete class to import! This is NOT a base class or concrete class! | ||
This is to be used as a "fill-in" coding template. | ||
How to use this implementation template to implement a new estimator: | ||
- make a copy of the template in a suitable location, give it a descriptive name. | ||
- work through all the "todo" comments below | ||
- fill in code for mandatory methods, and optionally for optional methods | ||
- you can add more private methods, but do not override BaseEstimator's private methods | ||
an easy way to be safe is to prefix your methods with "_custom" | ||
- change docstrings for functions and the file | ||
- ensure interface compatibility by running sktime.utils.estimator_checks.check_estimator | ||
- once complete: use as a local library, or contribute to sktime via PR | ||
- more details: | ||
https://www.sktime.net/en/stable/developer_guide/add_estimators.html | ||
Mandatory implements: | ||
fitting - _fit(self, X, Y=None) | ||
annotating - _predict(self, X) | ||
Optional implements: | ||
updating - _update(self, X, Y=None) | ||
Testing - required for sktime test framework and check_estimator usage: | ||
get default parameters for test instance(s) - get_test_params() | ||
copyright: sktime developers, BSD-3-Clause License (see LICENSE file) | ||
""" | ||
|
||
from sktime.annotation.base import BaseSeriesAnnotator | ||
|
||
# todo: add any necessary imports here | ||
|
||
|
||
class Pelt(BaseSeriesAnnotator):
    """Series annotator skeleton for Pelt, based on the sktime extension template.

    The fit, predict and update hooks below are still unimplemented
    placeholders; only the constructor does real work (it stores its arguments).

    todo: write docstring, describing your custom annotator

    Parameters
    ----------
    est : sktime.estimator, BaseEstimator descendant
        descriptive explanation of est
    parama : int
        descriptive explanation of parama
    est2 : another estimator, optional (default=None)
        descriptive explanation of est2
    paramb : string, optional (default='default')
        descriptive explanation of paramb
    paramc : boolean, optional (default=None)
        descriptive explanation of paramc
    fmt : str {"dense", "sparse"}, optional (default="dense")
        Annotation output format:
        * If "sparse", a sub-series of labels for only the outliers in X is returned,
        * If "dense", a series of labels for all values in X is returned.
    labels : str {"indicator", "score"}, optional (default="indicator")
        Annotation output labels:
        * If "indicator", returned values are boolean, indicating whether a value is
          an outlier,
        * If "score", returned values are floats, giving the outlier score.
    """

    # todo: add any hyper-parameters and components to constructor
    def __init__(
        self,
        est,
        parama,
        est2=None,
        paramb="default",
        paramc=None,
        fmt="dense",
        labels="indicator",
    ):
        # estimators should precede parameters
        # if estimators have default values, set None and initialize below

        # todo: write any hyper-parameters and components to self
        # Every constructor argument is stored unchanged under its own name,
        # including est2, so that none of them is lost after construction.
        self.est = est
        self.parama = parama
        self.est2 = est2
        self.paramb = paramb
        self.paramc = paramc

        # leave this as is
        super().__init__(fmt=fmt, labels=labels)

        # todo: optional, parameter checking logic (if applicable) should happen here
        # if writes derived values to self, should *not* overwrite self.parama etc
        # instead, write to self._parama, self._newparam (starting with _)

        # todo: default estimators should have None arg defaults
        # and be initialized here
        # do this only with default estimators, not with parameters
        # if est2 is None:
        #     self._est2 = MyDefaultEstimator()

        # todo: if tags of estimator depend on component tags, set these here
        # only needed if estimator is a composite
        # tags set in the constructor apply to the object and override the class
        #
        # example 1: conditional setting of a tag
        # if est.foo == 42:
        #     self.set_tags(**{"handles-missing-data": True})
        # example 2: cloning tags from component
        # self.clone_tags(est2, ["enforce_index_type", "handles-missing-data"])

    # todo: implement this, mandatory
    def _fit(self, X, Y=None):
        """Fit to training data.

        core logic

        Parameters
        ----------
        X : pd.DataFrame
            training data to fit model to, time series
        Y : pd.Series, optional
            ground truth annotations for training if annotator is supervised

        Returns
        -------
        self : returns a reference to self

        State change
        ------------
        creates fitted model (attributes ending in "_")
        """
        # NOTE(review): not implemented yet - the body is empty, so this
        # currently returns None rather than self.

        # implement here
        # IMPORTANT: avoid side effects to X, Y

    # todo: implement this, mandatory
    def _predict(self, X):
        """Create annotations on test/deployment data.

        core logic

        Parameters
        ----------
        X : pd.DataFrame - data to annotate, time series

        Returns
        -------
        Y : pd.Series - annotations for sequence X
            exact format depends on annotation type
        """
        # NOTE(review): not implemented yet - currently returns None.

        # implement here
        # IMPORTANT: avoid side effects to X

    # todo: consider implementing this, optional
    # if not implementing, delete the _update method
    def _update(self, X, Y=None):
        """Update model with new data and optional ground truth annotations.

        core logic

        Parameters
        ----------
        X : pd.DataFrame
            training data to update model with, time series
        Y : pd.Series, optional
            ground truth annotations for training if annotator is supervised

        Returns
        -------
        self : returns a reference to self

        State change
        ------------
        updates fitted model (attributes ending in "_")
        """
        # NOTE(review): not implemented yet - currently returns None.

        # implement here
        # IMPORTANT: avoid side effects to X, Y

    # todo: return default parameters, so that a test instance can be created
    # required for automated unit and integration testing of estimator
    @classmethod
    def get_test_params(cls, parameter_set="default"):
        """Return testing parameter settings for the estimator.

        Parameters
        ----------
        parameter_set : str, default="default"
            Name of the set of test parameters to return, for use in tests. If no
            special parameters are defined for a value, will return `"default"` set.
            There are currently no reserved values for annotators.

        Returns
        -------
        params : dict or list of dict, default = {}
            Parameters to create testing instances of the class
            Each dict are parameters to construct an "interesting" test instance, i.e.,
            `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
            `create_test_instance` uses the first (or only) dictionary in `params`
        """
        # NOTE(review): not implemented yet - nothing is returned, so callers
        # receive None. `est` and `parama` have no defaults in __init__, so a
        # real parameter dict is needed before a test instance can be built.

        # todo: set the testing parameters for the estimators
        # Testing parameters can be dictionary or list of dictionaries
        # Testing parameter choice should cover internal cases well.
        #
        # this method can, if required, use:
        #   class properties (e.g., inherited); parent class test case
        #   imported objects such as estimators from sktime or sklearn
        # important: all such imports should be *inside get_test_params*, not at the top
        #            since imports are used only at testing time
        #
        # The parameter_set argument is not used for automated, module level tests.
        #   It can be used in custom, estimator specific tests, for "special" settings.
        # A parameter dictionary must be returned *for all values* of parameter_set,
        #   i.e., "parameter_set not available" errors should never be raised.
        #
        # A good parameter set should primarily satisfy two criteria,
        #   1. Chosen set of parameters should have a low testing time,
        #      ideally in the magnitude of few seconds for the entire test suite.
        #      This is vital for the cases where default values result in
        #      "big" models which not only increases test time but also
        #      run into the risk of test workers crashing.
        #   2. There should be a minimum two such parameter sets with different
        #      sets of values to ensure a wide range of code coverage is provided.
        #
        # example 1: specify params as dictionary
        # any number of params can be specified
        # params = {"est": value0, "parama": value1, "paramb": value2}
        #
        # example 2: specify params as list of dictionary
        # note: Only first dictionary will be used by create_test_instance
        # params = [{"est": value1, "parama": value2},
        #           {"est": value3, "parama": value4}]
        # return params
        #
        # example 3: parameter set depending on param_set value
        #   note: only needed if a separate parameter set is needed in tests
        # if parameter_set == "special_param_set":
        #     params = {"est": value1, "parama": value2}
        #     return params
        #
        # # "default" params - always returned except for "special_param_set" value
        # params = {"est": value3, "parama": value4}
        # return params
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
"""Simple PELT test.""" | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
||
from sktime.annotation.clasp import ClaSPSegmentation | ||
from sktime.datasets import load_gun_point_segmentation | ||
from sktime.tests.test_switch import run_test_for_class | ||
from sktime.utils._testing.annotation import make_annotation_problem | ||
|
||
from skchange.change_detectors.pelt import Pelt | ||
|
||
|
||
def test_output_type():
    """Test that ``Pelt.predict`` returns a ``pd.Series`` or ``np.ndarray``.

    The estimator is fitted on a generated 50-point annotation problem and
    then asked to predict on a separate 10-point problem. Nothing is built or
    run when ``run_test_for_class`` reports that the class should be skipped.
    """
    Estimator = Pelt
    # Consult the skip switch first, so a skipped test never constructs the
    # estimator (construction itself may be expensive or may fail).
    if not run_test_for_class(Estimator):
        return None

    estimator = Estimator.create_test_instance()

    train = make_annotation_problem(
        n_timepoints=50, estimator_type=estimator.get_tag("distribution_type")
    )
    estimator.fit(train)
    test = make_annotation_problem(
        n_timepoints=10, estimator_type=estimator.get_tag("distribution_type")
    )
    y_pred = estimator.predict(test)
    assert isinstance(y_pred, (pd.Series, np.ndarray))
|
||
|
||
def test_pelt_sparse():
    """Check sparse change point output on the gun point segmentation data.

    NOTE(review): despite its name, this test exercises ``ClaSPSegmentation``
    and never touches ``Pelt`` - presumably a placeholder copied from the
    ClaSP tests; confirm and port once Pelt is implemented.
    """
    # The third returned value (the reference change points) is not used here.
    series, window, _ = load_gun_point_segmentation()

    # Fit a ClaSP segmenter that looks for a single change point.
    segmenter = ClaSPSegmentation(window, n_cps=1)
    segmenter.fit(series)
    found_cps = segmenter.predict(series)
    scores = segmenter.predict_scores(series)

    # Exactly one change point, located at index 893.
    assert len(found_cps) == 1
    assert found_cps[0] == 893
    # Exactly one score, and it exceeds 0.74.
    assert len(scores) == 1
    assert scores[0] > 0.74
|
||
|
||
def test_pelt_dense():
    """Check dense segmentation output on the gun point segmentation data.

    NOTE(review): despite its name, this test exercises ``ClaSPSegmentation``
    with ``fmt="dense"`` and never touches ``Pelt`` - presumably a placeholder
    copied from the ClaSP tests; confirm and port once Pelt is implemented.
    """
    # The third returned value (the reference change points) is not used here.
    series, window, _ = load_gun_point_segmentation()

    # Fit a ClaSP segmenter that looks for a single change point, dense output.
    segmenter = ClaSPSegmentation(window, n_cps=1, fmt="dense")
    segmenter.fit(series)
    segmentation = segmenter.predict(series)
    scores = segmenter.predict_scores(series)

    # Two segments, the first one ending at index 893.
    assert len(segmentation) == 2
    assert segmentation[0].right == 893
    # The score profile peaks at that same index.
    assert np.argmax(scores) == 893
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import numpy as np | ||
import pytest | ||
|
||
from sktime.utils.estimator_checks import check_estimator | ||
|
||
# from sktime.registry import all_estimators | ||
from sktime.tests.test_switch import run_test_for_class | ||
from sktime.utils.estimator_checks import check_estimator | ||
|
||
from skchange.change_detectors.pelt import Pelt | ||
|
||
# Only Pelt is checked here. The commented-out alternative below would collect
# every registered series annotator instead:
# ALL_ANNOTATORS = all_estimators(estimator_types="series-annotator", return_names=False)
ALL_ANNOTATORS = [Pelt]


@pytest.mark.parametrize("Estimator", ALL_ANNOTATORS)
def test_output_type(Estimator):
    """Run ``check_estimator`` on each annotator class in ``ALL_ANNOTATORS``.

    The check is skipped silently (the test simply passes) when
    ``run_test_for_class`` reports that the class should not be tested.
    Failures inside ``check_estimator`` are raised, not collected.
    """
    if run_test_for_class(Estimator):
        check_estimator(Estimator, raise_exceptions=True)