Skip to content

Commit 08b4228

Browse files
committed
Feature: download flow list
1 parent 1f9bbb0 commit 08b4228

3 files changed

Lines changed: 60 additions & 10 deletions

File tree

openml/apiconnector.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -832,7 +832,7 @@ def get_runs_list(self, task_id=None, flow_id=None, setup_id=None):
832832
Returns
833833
-------
834834
list
835-
A list of all runs run IDs for a given ID.
835+
A list of all runs for a given ID.
836836
"""
837837
test = [task_id is None, flow_id is None, setup_id is None]
838838
if np.nansum(test) != 2:
@@ -843,7 +843,7 @@ def get_runs_list(self, task_id=None, flow_id=None, setup_id=None):
843843
if task_id is not None:
844844
call += "?task_id=%d" % task_id
845845
elif flow_id is not None:
846-
call += "?implementation_id=%d" % flow_id
846+
call += "?flow_id=%d" % flow_id
847847
elif setup_id is not None:
848848
call += "?setup_id=%d" % setup_id
849849

@@ -865,8 +865,9 @@ def get_runs_list(self, task_id=None, flow_id=None, setup_id=None):
865865
run = {'run_id': int(runs_['oml:run_id']),
866866
'task_id': int(runs_['oml:task_id']),
867867
'setup_id': int(runs_['oml:setup_id']),
868-
'implementation_id': int(runs_['oml:implementation_id']),
869-
'uploader': int(runs_['oml:uploader'])}
868+
'flow_id': int(runs_['oml:flow_id']),
869+
'uploader': int(runs_['oml:uploader']),
870+
'error_message': runs_['oml:error_message']}
870871

871872
runs.append(run)
872873
runs.sort(key=lambda t: t['run_id'])
@@ -957,10 +958,46 @@ def _create_run_from_xml(self, xml):
957958

958959
return OpenMLRun(
959960
dic[u"oml:run_id"], dic[u"oml:uploader"],
960-
dic[u"oml:task_id"], dic[u"oml:implementation_id"],
961+
dic[u"oml:task_id"], dic[u"oml:flow_id"],
961962
dic[u"oml:setup_string"], dic[u'oml:setup_id'],
962963
tags, datasets, files, evaluations)
963964

965+
############################################################################
966+
# Flows
967+
def get_flow_list(self):
968+
"""Return a list of all flows on OpenML.
969+
970+
Returns
971+
-------
972+
list
973+
A list of all flows.
974+
"""
975+
return_code, xml_string = self._perform_api_call("/flow/list")
976+
datasets_dict = xmltodict.parse(xml_string)
977+
978+
if isinstance(datasets_dict['oml:flows']['oml:flow'], dict):
979+
flows = [datasets_dict['oml:implementations']['oml:implementation']]
980+
else:
981+
# Minimalistic check if the XML is useful
982+
assert type(datasets_dict['oml:flows']['oml:flow']) == list, \
983+
type(datasets_dict['oml:flows']['oml:flow'])
984+
assert datasets_dict['oml:flows']['@xmlns:oml'] == \
985+
'http://openml.org/openml'
986+
987+
flows = []
988+
for flow_ in datasets_dict['oml:flows']['oml:flow']:
989+
flow = {'id': int(flow_['oml:id']),
990+
'full_name': flow_['oml:full_name'],
991+
'name': flow_['oml:name'],
992+
'version': flow_['oml:version'],
993+
'external_version': flow_['oml:external_version'],
994+
'uploader': int(flow_['oml:uploader'])}
995+
996+
flows.append(flow)
997+
flows.sort(key=lambda t: t['id'])
998+
999+
return flows
1000+
9641001
############################################################################
9651002
# Internal stuff
9661003
def _perform_api_call(self, call, data=None, file_path=None):

source/progress.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ API call implemented tested properly test
3232
/task/delete
3333
/tasktype/list
3434
/tasktype/{task_id}
35+
/flow/list yes
3536
/flow/tag
3637
/flow/untag
3738
/flow/{flow_id}

tests/test_apiconnector.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def test_download_split(self):
226226
def test_download_run_list(self):
227227
def check_run(run):
228228
self.assertIsInstance(run, dict)
229-
self.assertEqual(len(run), 5)
229+
self.assertEqual(len(run), 6)
230230

231231
runs = self.connector.get_runs_list(task_id=1)
232232
# 1759 as the number of supervised classification tasks retrieved
@@ -238,13 +238,13 @@ def check_run(run):
238238

239239
runs = self.connector.get_runs_list(flow_id=1)
240240
self.assertGreaterEqual(len(runs), 1)
241-
for task in runs:
242-
check_run(task)
241+
for run in runs:
242+
check_run(run)
243243

244244
runs = self.connector.get_runs_list(setup_id=1)
245245
self.assertGreaterEqual(len(runs), 261)
246-
for task in runs:
247-
check_run(task)
246+
for run in runs:
247+
check_run(run)
248248

249249
def test_download_run(self):
250250
run = self.connector.download_run(473350)
@@ -254,6 +254,18 @@ def test_download_run(self):
254254
self.assertGreaterEqual(len(run.evaluations), 18)
255255
self.assertEqual(len(run.evaluations['f_measure']), 2)
256256

257+
# ###########################################################################
258+
# Flows
259+
def test_download_flow_list(self):
260+
def check_flow(flow):
261+
self.assertIsInstance(flow, dict)
262+
self.assertEqual(len(flow), 6)
263+
264+
flows = self.connector.get_flow_list()
265+
self.assertGreaterEqual(len(flows), 1448)
266+
for flow in flows:
267+
check_flow(flow)
268+
257269
def test_upload_dataset(self):
258270

259271
dataset = self.connector.download_dataset(3)

0 commit comments

Comments
 (0)