Skip to content

Commit 5405778

Browse files
committed
add sort to list_evaluations
1 parent 3fab583 commit 5405778

2 files changed

Lines changed: 36 additions & 0 deletions

File tree

openml/evaluations/functions.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def list_evaluations(
1919
uploader: Optional[List] = None,
2020
tag: Optional[str] = None,
2121
per_fold: Optional[bool] = None,
22+
sort: Optional[str] = None,
2223
output_format: str = 'object'
2324
) -> Union[Dict, pd.DataFrame]:
2425
"""
@@ -48,6 +49,9 @@ def list_evaluations(
4849
4950
per_fold : bool, optional
5051
52+
sort : str, optional
53+
order of sorting evaluations, ascending ("asc") or descending ("desc")
54+
5155
output_format: str, optional (default='object')
5256
The parameter decides the format of the output.
5357
- If 'object' the output is a dict of OpenMLEvaluation objects
@@ -77,6 +81,7 @@ def list_evaluations(
7781
flow=flow,
7882
uploader=uploader,
7983
tag=tag,
84+
sort=sort,
8085
per_fold=per_fold_str)
8186

8287

@@ -87,6 +92,7 @@ def _list_evaluations(
8792
setup: Optional[List] = None,
8893
flow: Optional[List] = None,
8994
uploader: Optional[List] = None,
95+
sort: Optional[str] = None,
9096
output_format: str = 'object',
9197
**kwargs
9298
) -> Union[Dict, pd.DataFrame]:
@@ -114,6 +120,9 @@ def _list_evaluations(
114120
kwargs: dict, optional
115121
Legal filter operators: tag, limit, offset.
116122
123+
sort : str, optional
124+
order of sorting evaluations, ascending ("asc") or descending ("desc")
125+
117126
output_format: str, optional (default='dict')
118127
The parameter decides the format of the output.
119128
- If 'dict' the output is a dict of dict
@@ -141,6 +150,8 @@ def _list_evaluations(
141150
api_call += "/flow/%s" % ','.join([str(int(i)) for i in flow])
142151
if uploader is not None:
143152
api_call += "/uploader/%s" % ','.join([str(int(i)) for i in uploader])
153+
if sort is not None:
154+
api_call += "/sort/%s" % sort
144155

145156
return __list_evaluations(api_call, output_format=output_format)
146157

tests/test_evaluations/test_evaluation_functions.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,28 @@ def test_evaluation_list_per_fold(self):
116116
for run_id in evaluations.keys():
117117
self.assertIsNotNone(evaluations[run_id].value)
118118
self.assertIsNone(evaluations[run_id].values)
119+
120+
def test_evaluation_list_sort(self):
121+
openml.config.server = self.production_server
122+
size = 10
123+
task_id = 1769
124+
# Get all evaluations of the task
125+
unsorted_eval = openml.evaluations.list_evaluations(
126+
"predictive_accuracy", offset=0, task=[task_id])
127+
# Get top 10 evaluations of the same task
128+
sorted_eval = openml.evaluations.list_evaluations(
129+
"predictive_accuracy", size=size, offset=0, task=[task_id], sort="desc")
130+
131+
sorted_output = []
132+
unsorted_output = []
133+
for run_id in sorted_eval.keys():
134+
sorted_output.append(sorted_eval[run_id].value)
135+
for run_id in unsorted_eval.keys():
136+
unsorted_output.append(unsorted_eval[run_id].value)
137+
138+
# Check if output from sort is sorted in the right order
139+
self.assertTrue(sorted(sorted_output, reverse=True) == sorted_output)
140+
141+
# Compare manual sorting against sorted output
142+
test_output = sorted(unsorted_output, reverse=True)
143+
self.assertTrue(test_output[:size] == sorted_output)

0 commit comments

Comments
 (0)