Skip to content

Commit 4e70bd4

Browse files
committed
TST add estimator test for Delayer
1 parent 8986421 commit 4e70bd4

2 files changed

Lines changed: 18 additions & 4 deletions

File tree

voxelwise_tutorials/delayer.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
from sklearn.base import BaseEstimator, TransformerMixin
3+
from sklearn.utils.validation import check_is_fitted, check_array
34

45

56
class Delayer(BaseEstimator, TransformerMixin):
@@ -26,19 +27,23 @@ class Delayer(BaseEstimator, TransformerMixin):
2627
Example
2728
-------
2829
>>> from sklearn.pipeline import make_pipeline
29-
>>> from voxelwise.delayer import Delayer
30+
>>> from voxelwise_tutorials.delayer import Delayer
3031
>>> from himalaya.kernel_ridge import KernelRidgeCV
31-
>>> pipeline = make_pipeline(Delayer(delays=[1, 2, 3, 4), KernelRidgeCV())
32-
>>> pipeline.fit(..., ...)
32+
>>> pipeline = make_pipeline(Delayer(delays=[1, 2, 3, 4]), KernelRidgeCV())
3333
"""
34-
def __init__(self, delays=[1, 2, 3, 4]):
34+
35+
def __init__(self, delays=None):
3536
self.delays = delays
3637

3738
def fit(self, X, y=None):
39+
X = self._validate_data(X, dtype='numeric')
3840
self.n_features_in_ = X.shape[1]
3941
return self
4042

4143
def transform(self, X):
44+
check_is_fitted(self)
45+
X = check_array(X, copy=True)
46+
4247
n_samples, n_features = X.shape
4348
if n_features != self.n_features_in_:
4449
raise ValueError(
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import sklearn.kernel_ridge
2+
import sklearn.utils.estimator_checks
3+
4+
from voxelwise_tutorials.delayer import Delayer
5+
6+
7+
@sklearn.utils.estimator_checks.parametrize_with_checks([Delayer()])
8+
def test_check_estimator(estimator, check):
9+
check(estimator)

0 commit comments

Comments
 (0)