Skip to content

Commit 41a4a74

Browse files
author
janvanrijn
committed
fix 136, return task list as dictionary
1 parent 10fa379 commit 41a4a74

6 files changed

Lines changed: 79 additions & 91 deletions

File tree

openml/tasks/functions.py

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -145,12 +145,15 @@ def _list_tasks(api_call):
145145
'"oml:runs"/@xmlns:oml is not '
146146
'"http://openml.org/openml": %s'
147147
% str(tasks_dict))
148+
148149
try:
149-
tasks = []
150+
tasks = dict();
150151
procs = _get_estimation_procedure_list()
151152
proc_dict = dict((x['id'], x) for x in procs)
152153
for task_ in tasks_dict['oml:tasks']['oml:task']:
153-
task = {'tid': int(task_['oml:task_id']),
154+
tid = int(task_['oml:task_id'])
155+
task = {'tid': tid,
156+
'ttid': int(task_['oml:task_type_id']),
154157
'did': int(task_['oml:did']),
155158
'name': task_['oml:name'],
156159
'task_type': task_['oml:task_type'],
@@ -170,12 +173,10 @@ def _list_tasks(api_call):
170173
if abs(int(quality['#text']) - quality['#text']) < 0.0000001:
171174
quality['#text'] = int(quality['#text'])
172175
task[quality['@name']] = quality['#text']
173-
tasks.append(task)
176+
tasks[tid] = task
174177
except KeyError as e:
175178
raise KeyError("Invalid xml for task: %s" % e)
176179

177-
tasks.sort(key=lambda t: t['tid'])
178-
179180
return tasks
180181

181182

@@ -245,7 +246,7 @@ def _create_task_from_xml(xml):
245246
estimation_parameters[name] = text
246247

247248
return OpenMLTask(
248-
dic["oml:task_id"], dic["oml:task_type"],
249+
dic["oml:task_id"], dic['oml:task_type_id'], dic["oml:task_type"],
249250
inputs["source_data"]["oml:data_set"]["oml:data_set_id"],
250251
inputs["source_data"]["oml:data_set"]["oml:target_feature"],
251252
inputs["estimation_procedure"]["oml:estimation_procedure"][

openml/tasks/task.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -9,8 +9,8 @@
99

1010

1111
class OpenMLTask(object):
12-
def __init__(self, task_id, task_type, data_set_id, target_name,
13-
estimation_procedure_type, data_splits_url,
12+
def __init__(self, task_id, task_type_id, task_type, data_set_id,
13+
target_name, estimation_procedure_type, data_splits_url,
1414
estimation_parameters, evaluation_measure, cost_matrix,
1515
class_labels=None):
1616
self.task_id = int(task_id)

tests/files/tasks/1/task.xml

Lines changed: 20 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -1,34 +1,35 @@
11
<oml:task xmlns:oml="http://openml.org/openml">
22
<oml:task_id>1</oml:task_id>
3+
<oml:task_type_id>1</oml:task_type_id>
34
<oml:task_type>Supervised Classification</oml:task_type>
45
<oml:input name="source_data">
5-
<oml:data_set>
6-
<oml:data_set_id>1</oml:data_set_id>
7-
<oml:target_feature>class</oml:target_feature>
6+
<oml:data_set>
7+
<oml:data_set_id>1</oml:data_set_id>
8+
<oml:target_feature>class</oml:target_feature>
89
</oml:data_set> </oml:input>
910
<oml:input name="estimation_procedure">
10-
<oml:estimation_procedure>
11-
<oml:type>crossvalidation</oml:type>
12-
<oml:data_splits_url>http://www.openml.org/api_splits/get/1/Task_1_splits.arff</oml:data_splits_url>
13-
<oml:parameter name="number_repeats">1</oml:parameter>
14-
<oml:parameter name="number_folds">10</oml:parameter>
15-
<oml:parameter name="percentage"></oml:parameter>
16-
<oml:parameter name="stratified_sampling">true</oml:parameter>
11+
<oml:estimation_procedure>
12+
<oml:type>crossvalidation</oml:type>
13+
<oml:data_splits_url>http://www.openml.org/api_splits/get/1/Task_1_splits.arff</oml:data_splits_url>
14+
<oml:parameter name="number_repeats">1</oml:parameter>
15+
<oml:parameter name="number_folds">10</oml:parameter>
16+
<oml:parameter name="percentage"></oml:parameter>
17+
<oml:parameter name="stratified_sampling">true</oml:parameter>
1718
</oml:estimation_procedure> </oml:input>
1819
<oml:input name="cost_matrix">
1920
<oml:cost_matrix></oml:cost_matrix> </oml:input>
2021
<oml:input name="evaluation_measures">
21-
<oml:evaluation_measures>
22-
<oml:evaluation_measure>predictive_accuracy</oml:evaluation_measure>
22+
<oml:evaluation_measures>
23+
<oml:evaluation_measure>predictive_accuracy</oml:evaluation_measure>
2324
</oml:evaluation_measures> </oml:input>
2425
<oml:output name="predictions">
25-
<oml:predictions>
26-
<oml:format>ARFF</oml:format>
27-
<oml:feature name="repeat" type="integer"/>
28-
<oml:feature name="fold" type="integer"/>
29-
<oml:feature name="row_id" type="integer"/>
30-
<oml:feature name="confidence.classname" type="numeric"/>
31-
<oml:feature name="prediction" type="string"/>
26+
<oml:predictions>
27+
<oml:format>ARFF</oml:format>
28+
<oml:feature name="repeat" type="integer"/>
29+
<oml:feature name="fold" type="integer"/>
30+
<oml:feature name="row_id" type="integer"/>
31+
<oml:feature name="confidence.classname" type="numeric"/>
32+
<oml:feature name="prediction" type="string"/>
3233
</oml:predictions> </oml:output>
3334
<oml:tag>basic</oml:tag>
3435
<oml:tag>study_1</oml:tag>

tests/files/tasks/1882/task.xml

Lines changed: 20 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -1,34 +1,35 @@
11
<oml:task xmlns:oml="http://openml.org/openml">
22
<oml:task_id>1882</oml:task_id>
3+
<oml:task_type_id>1</oml:task_type_id>
34
<oml:task_type>Supervised Classification</oml:task_type>
45
<oml:input name="source_data">
5-
<oml:data_set>
6-
<oml:data_set_id>2</oml:data_set_id>
7-
<oml:target_feature>class</oml:target_feature>
6+
<oml:data_set>
7+
<oml:data_set_id>2</oml:data_set_id>
8+
<oml:target_feature>class</oml:target_feature>
89
</oml:data_set> </oml:input>
910
<oml:input name="estimation_procedure">
10-
<oml:estimation_procedure>
11-
<oml:type>crossvalidation</oml:type>
12-
<oml:data_splits_url>http://capa.win.tue.nl/api_splits/get/1882/Task_1882_splits.arff</oml:data_splits_url>
13-
<oml:parameter name="number_repeats">10</oml:parameter>
14-
<oml:parameter name="number_folds">10</oml:parameter>
15-
<oml:parameter name="percentage"></oml:parameter>
16-
<oml:parameter name="stratified_sampling">true</oml:parameter>
11+
<oml:estimation_procedure>
12+
<oml:type>crossvalidation</oml:type>
13+
<oml:data_splits_url>http://capa.win.tue.nl/api_splits/get/1882/Task_1882_splits.arff</oml:data_splits_url>
14+
<oml:parameter name="number_repeats">10</oml:parameter>
15+
<oml:parameter name="number_folds">10</oml:parameter>
16+
<oml:parameter name="percentage"></oml:parameter>
17+
<oml:parameter name="stratified_sampling">true</oml:parameter>
1718
</oml:estimation_procedure> </oml:input>
1819
<oml:input name="cost_matrix">
1920
<oml:cost_matrix></oml:cost_matrix> </oml:input>
2021
<oml:input name="evaluation_measures">
21-
<oml:evaluation_measures>
22-
<oml:evaluation_measure>predictive_accuracy</oml:evaluation_measure>
22+
<oml:evaluation_measures>
23+
<oml:evaluation_measure>predictive_accuracy</oml:evaluation_measure>
2324
</oml:evaluation_measures> </oml:input>
2425
<oml:output name="predictions">
25-
<oml:predictions>
26-
<oml:format>ARFF</oml:format>
27-
<oml:feature name="repeat" type="integer"/>
28-
<oml:feature name="fold" type="integer"/>
29-
<oml:feature name="row_id" type="integer"/>
30-
<oml:feature name="confidence.classname" type="numeric"/>
31-
<oml:feature name="prediction" type="string"/>
26+
<oml:predictions>
27+
<oml:format>ARFF</oml:format>
28+
<oml:feature name="repeat" type="integer"/>
29+
<oml:feature name="fold" type="integer"/>
30+
<oml:feature name="row_id" type="integer"/>
31+
<oml:feature name="confidence.classname" type="numeric"/>
32+
<oml:feature name="prediction" type="string"/>
3233
</oml:predictions> </oml:output>
3334
<oml:tag>under100k</oml:tag>
3435
<oml:tag>under1m</oml:tag>

tests/files/tasks/3/task.xml

Lines changed: 20 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -1,34 +1,35 @@
11
<oml:task xmlns:oml="http://openml.org/openml">
22
<oml:task_id>3</oml:task_id>
3+
<oml:task_type_id>1</oml:task_type_id>
34
<oml:task_type>Supervised Classification</oml:task_type>
45
<oml:input name="source_data">
5-
<oml:data_set>
6-
<oml:data_set_id>3</oml:data_set_id>
7-
<oml:target_feature>class</oml:target_feature>
6+
<oml:data_set>
7+
<oml:data_set_id>3</oml:data_set_id>
8+
<oml:target_feature>class</oml:target_feature>
89
</oml:data_set> </oml:input>
910
<oml:input name="estimation_procedure">
10-
<oml:estimation_procedure>
11-
<oml:type>crossvalidation</oml:type>
12-
<oml:data_splits_url>http://www.openml.org/api_splits/get/3/Task_3_splits.arff</oml:data_splits_url>
13-
<oml:parameter name="number_repeats">1</oml:parameter>
14-
<oml:parameter name="number_folds">10</oml:parameter>
15-
<oml:parameter name="percentage"></oml:parameter>
16-
<oml:parameter name="stratified_sampling">true</oml:parameter>
11+
<oml:estimation_procedure>
12+
<oml:type>crossvalidation</oml:type>
13+
<oml:data_splits_url>http://www.openml.org/api_splits/get/3/Task_3_splits.arff</oml:data_splits_url>
14+
<oml:parameter name="number_repeats">1</oml:parameter>
15+
<oml:parameter name="number_folds">10</oml:parameter>
16+
<oml:parameter name="percentage"></oml:parameter>
17+
<oml:parameter name="stratified_sampling">true</oml:parameter>
1718
</oml:estimation_procedure> </oml:input>
1819
<oml:input name="cost_matrix">
1920
<oml:cost_matrix></oml:cost_matrix> </oml:input>
2021
<oml:input name="evaluation_measures">
21-
<oml:evaluation_measures>
22-
<oml:evaluation_measure>predictive_accuracy</oml:evaluation_measure>
22+
<oml:evaluation_measures>
23+
<oml:evaluation_measure>predictive_accuracy</oml:evaluation_measure>
2324
</oml:evaluation_measures> </oml:input>
2425
<oml:output name="predictions">
25-
<oml:predictions>
26-
<oml:format>ARFF</oml:format>
27-
<oml:feature name="repeat" type="integer"/>
28-
<oml:feature name="fold" type="integer"/>
29-
<oml:feature name="row_id" type="integer"/>
30-
<oml:feature name="confidence.classname" type="numeric"/>
31-
<oml:feature name="prediction" type="string"/>
26+
<oml:predictions>
27+
<oml:format>ARFF</oml:format>
28+
<oml:feature name="repeat" type="integer"/>
29+
<oml:feature name="fold" type="integer"/>
30+
<oml:feature name="row_id" type="integer"/>
31+
<oml:feature name="confidence.classname" type="numeric"/>
32+
<oml:feature name="prediction" type="string"/>
3233
</oml:predictions> </oml:output>
3334
<oml:tag>basic</oml:tag>
3435
<oml:tag>mythbusting</oml:tag>

tests/tasks/test_task_functions.py

Lines changed: 10 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -53,37 +53,29 @@ def _check_task(self, task):
5353
def test_list_tasks_by_type(self):
5454
tasks = openml.tasks.list_tasks(task_type_id=3)
5555
self.assertGreaterEqual(len(tasks), 300)
56-
for task in tasks:
57-
self._check_task(task)
56+
for tid in tasks:
57+
self._check_task(tasks[tid])
5858

5959
def test_list_tasks_by_tag(self):
6060
tasks = openml.tasks.list_tasks(tag='basic')
6161
self.assertGreaterEqual(len(tasks), 57)
62-
for task in tasks:
63-
self._check_task(task)
62+
for tid in tasks:
63+
self._check_task(tasks[tid])
6464

6565
def test_list_tasks(self):
6666
tasks = openml.tasks.list_tasks()
6767
self.assertGreaterEqual(len(tasks), 2000)
68-
for task in tasks:
69-
self._check_task(task)
68+
for tid in tasks:
69+
self._check_task(tasks[tid])
7070

7171
def test_list_tasks_paginate(self):
7272
size = 10
7373
max = 100
7474
for i in range(0, max, size):
7575
tasks = openml.tasks.list_tasks(offset=i, size=size)
7676
self.assertGreaterEqual(size, len(tasks))
77-
for task in tasks:
78-
self.assertEqual(type(task), dict)
79-
self.assertGreaterEqual(len(task), 4)
80-
self.assertIn('tid', task)
81-
self.assertIsInstance(task['tid'], int)
82-
self.assertIn('did', task)
83-
self.assertIsInstance(task['did'], int)
84-
self.assertIn('status', task)
85-
self.assertTrue(is_string(task['status']))
86-
self.assertIn(task['status'], ['in_preparation', 'active', 'deactivated'])
77+
for tid in tasks:
78+
self._check_task(tasks[tid])
8779

8880
def test_list_tasks_per_type_paginate(self):
8981
size = 10
@@ -93,16 +85,8 @@ def test_list_tasks_per_type_paginate(self):
9385
for i in range(0, max, size):
9486
tasks = openml.tasks.list_tasks(task_type_id=j, offset=i, size=size)
9587
self.assertGreaterEqual(size, len(tasks))
96-
for task in tasks:
97-
self.assertEqual(type(task), dict)
98-
self.assertGreaterEqual(len(task), 4)
99-
self.assertIn('tid', task)
100-
self.assertIsInstance(task['tid'], int)
101-
self.assertIn('did', task)
102-
self.assertIsInstance(task['did'], int)
103-
self.assertIn('status', task)
104-
self.assertTrue(is_string(task['status']))
105-
self.assertIn(task['status'], ['in_preparation', 'active', 'deactivated'])
88+
for tid in tasks:
89+
self._check_task(tasks[tid])
10690

10791
def test__get_task(self):
10892
openml.config.set_cache_directory(self.static_cache_dir)

0 commit comments

Comments
 (0)