Skip to content

Commit beaa046

Browse files
committed
improve testing of setup_exists
1 parent 9a9917a commit beaa046

3 files changed

Lines changed: 62 additions & 42 deletions

File tree

openml/runs/functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import warnings
88

99
import numpy as np
10-
import sklearn
10+
import sklearn.pipeline
1111
import six
1212
import xmltodict
1313

openml/setups/functions.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,7 @@ def setup_exists(flow, model=None):
3434
if model is None:
3535
model = flow.model
3636
else:
37-
converted_flow = sklearn_to_flow(model)
38-
exists = flow_exists(converted_flow.name,
39-
converted_flow.external_version)
37+
exists = flow_exists(flow.name, flow.external_version)
4038
if exists != flow.flow_id:
4139
raise ValueError('This should not happen!')
4240

tests/test_setups/test_setup_functions.py

Lines changed: 60 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,9 @@
88

99
from sklearn.ensemble import BaggingClassifier
1010
from sklearn.tree import DecisionTreeClassifier
11-
12-
if sys.version_info[0] >= 3:
13-
from unittest import mock
14-
else:
15-
import mock
11+
from sklearn.linear_model import LogisticRegression
12+
from sklearn.naive_bayes import GaussianNB
13+
from sklearn.base import BaseEstimator, ClassifierMixin
1614

1715

1816
def get_sentinel():
@@ -26,6 +24,29 @@ def get_sentinel():
2624
return sentinel
2725

2826

27+
class ParameterFreeClassifier(BaseEstimator, ClassifierMixin):
    """Classifier that exposes no hyperparameters.

    ``get_params`` always returns an empty dict and ``set_params`` ignores
    whatever it is given. Fitting and prediction are delegated to a
    ``DecisionTreeClassifier`` that is created inside ``fit``.
    """

    def __init__(self):
        # Filled in by fit(); stays None until then.
        self.estimator = None

    def fit(self, X, y):
        """Fit a fresh default ``DecisionTreeClassifier`` on ``X``/``y``.

        Returns
        -------
        self
        """
        self.estimator = DecisionTreeClassifier()
        self.estimator.fit(X, y)
        # Mirror the wrapped tree's class labels on this object.
        self.classes_ = self.estimator.classes_
        return self

    def predict(self, X):
        """Return the wrapped tree's predictions for ``X``."""
        return self.estimator.predict(X)

    def predict_proba(self, X):
        """Return the wrapped tree's class probabilities for ``X``."""
        return self.estimator.predict_proba(X)

    def set_params(self, **params):
        """Accept and ignore any parameters; there is nothing to set.

        Returns
        -------
        self
            Returned (rather than ``None``) so that callers relying on the
            scikit-learn convention ``est = est.set_params(...)`` keep a
            usable estimator.
        """
        return self

    def get_params(self, deep=True):
        """Return an empty dict: this classifier has no hyperparameters.

        ``deep`` is accepted for signature compatibility and is not used.
        """
        return {}
48+
49+
2950

3051
class TestRun(TestBase):
3152

@@ -45,39 +66,40 @@ def test_nonexisting_setup_exists(self):
4566
self.assertFalse(setup_id)
4667

4768
def test_existing_setup_exists(self):
48-
# first publish a nonexiting flow
49-
50-
# because of the sentinel, we can not use flows that contain subflows
51-
classif = DecisionTreeClassifier(max_depth=5,
52-
min_samples_split=3,
53-
# Not setting the random state will
54-
# make this flow fail as running it
55-
# will add a random random_state.
56-
random_state=1)
57-
flow = openml.flows.sklearn_to_flow(classif)
58-
flow.name = 'TEST%s%s' % (get_sentinel(), flow.name)
59-
60-
# Replace the flow by a flow in which the ID got set up correctly
61-
flow = flow.publish()
62-
flow = openml.flows.get_flow(flow.flow_id)
63-
64-
# although the flow exists, we can be sure there are no
65-
# setups (yet) as it hasn't been ran
66-
setup_id = openml.setups.setup_exists(flow)
67-
self.assertFalse(setup_id)
68-
69-
# now run the flow on an easy task:
70-
task = openml.tasks.get_task(115) # diabetes
71-
run = openml.runs.run_flow_on_task(task, flow)
72-
# spoof flow id, otherwise the sentinel is ignored
73-
run.flow_id = flow.flow_id
74-
run = run.publish()
75-
# download the run, as it contains the right setup id
76-
run = openml.runs.get_run(run.run_id)
77-
78-
# execute the function we are interested in
79-
setup_id = openml.setups.setup_exists(flow)
80-
self.assertEquals(setup_id, run.setup_id)
69+
clfs = [ParameterFreeClassifier(), # zero hyperparameters
70+
GaussianNB(), # one hyperparameter
71+
DecisionTreeClassifier(max_depth=5, # many hyperparameters
72+
min_samples_split=3,
73+
# Not setting the random state will
74+
# make this flow fail as running it
75+
# will add a random random_state.
76+
random_state=1)]
77+
78+
for classif in clfs:
79+
# first publish a nonexistent flow
80+
flow = openml.flows.sklearn_to_flow(classif)
81+
flow.name = 'TEST%s%s' % (get_sentinel(), flow.name)
82+
flow.publish()
83+
84+
# although the flow exists, we can be sure there are no
85+
# setups (yet) as it hasn't been run
86+
setup_id = openml.setups.setup_exists(flow)
87+
self.assertFalse(setup_id)
88+
setup_id = openml.setups.setup_exists(flow, classif)
89+
self.assertFalse(setup_id)
90+
91+
# now run the flow on an easy task:
92+
task = openml.tasks.get_task(115) # diabetes
93+
run = openml.runs.run_flow_on_task(task, flow)
94+
# spoof flow id, otherwise the sentinel is ignored
95+
run.flow_id = flow.flow_id
96+
run.publish()
97+
# download the run, as it contains the right setup id
98+
run = openml.runs.get_run(run.run_id)
99+
100+
# execute the function we are interested in
101+
setup_id = openml.setups.setup_exists(flow)
102+
self.assertEquals(setup_id, run.setup_id)
81103

82104
def test_get_setup(self):
83105
# no setups in default test server

0 commit comments

Comments
 (0)