Skip to content

Commit 32b2df3

Browse files
authored
Merge pull request #340 from openml/fix_218
Add unit test for loading non-sklearn flows, fixes #218
2 parents 7210b3e + f5289fb commit 32b2df3

2 files changed

Lines changed: 38 additions & 4 deletions

File tree

openml/runs/functions.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -677,10 +677,18 @@ def _create_run_from_xml(xml):
677677
'description XML' % run_id)
678678

679679
if 'predictions' not in files:
680-
# JvR: actually, I am not sure whether this error should be raised.
681-
# a run can consist without predictions. But for now let's keep it
682-
raise ValueError('No prediction files for run %d in run '
683-
'description XML' % run_id)
680+
task = openml.tasks.get_task(task_id)
681+
if task.task_type_id == 8:
682+
raise NotImplementedError(
683+
'Subgroup discovery tasks are not yet supported.'
684+
)
685+
else:
686+
# JvR: actually, I am not sure whether this error should be raised.
687+
# a run can consist without predictions. But for now let's keep it
688+
# Matthias: yes, it should stay as long as we do not really handle
689+
# this stuff
690+
raise ValueError('No prediction files for run %d in run '
691+
'description XML' % run_id)
684692

685693
tags = openml.utils.extract_xml_tags('oml:tag', run)
686694

tests/test_flows/test_flow.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,3 +348,29 @@ def test_extract_tags(self):
348348
flow_dict = xmltodict.parse(flow_xml)
349349
tags = openml.utils.extract_xml_tags('oml:tag', flow_dict['oml:flow'])
350350
self.assertEqual(tags, ['OpenmlWeka', 'weka'])
351+
352+
def test_download_non_scikit_learn_flows(self):
353+
openml.config.server = self.production_server
354+
355+
flow = openml.flows.get_flow(6742)
356+
self.assertIsInstance(flow, openml.OpenMLFlow)
357+
self.assertEqual(flow.flow_id, 6742)
358+
self.assertEqual(len(flow.parameters), 19)
359+
self.assertEqual(len(flow.components), 1)
360+
self.assertIsNone(flow.model)
361+
362+
subflow_1 = list(flow.components.values())[0]
363+
self.assertIsInstance(subflow_1, openml.OpenMLFlow)
364+
self.assertEqual(subflow_1.flow_id, 6743)
365+
self.assertEqual(len(subflow_1.parameters), 8)
366+
self.assertEqual(subflow_1.parameters['U'], '0')
367+
self.assertEqual(len(subflow_1.components), 1)
368+
self.assertIsNone(subflow_1.model)
369+
370+
subflow_2 = list(subflow_1.components.values())[0]
371+
self.assertIsInstance(subflow_2, openml.OpenMLFlow)
372+
self.assertEqual(subflow_2.flow_id, 5888)
373+
self.assertEqual(len(subflow_2.parameters), 4)
374+
self.assertIsNone(subflow_2.parameters['batch-size'])
375+
self.assertEqual(len(subflow_2.components), 0)
376+
self.assertIsNone(subflow_2.model)

0 commit comments

Comments
 (0)