Skip to content

Commit 1a3f456

Browse files
committed
Intermediate changes; pipeline additions remain
1 parent 2796b9a commit 1a3f456

1 file changed

Lines changed: 15 additions & 6 deletions

File tree

examples/40_paper/2018_neurips_perrone_example.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,9 @@
3939
a tabular format that can be used to build models.
4040
"""
4141

42-
def fetch_evaluations(run_full=False, flow_type='svm', metric = 'area_under_roc_curve'):
42+
def fetch_evaluations(run_full=False,
43+
flow_type='svm',
44+
metric='area_under_roc_curve'):
4345
'''
4446
Fetch a list of evaluations based on the flows and tasks used in the experiments.
4547
@@ -77,13 +79,19 @@ def fetch_evaluations(run_full=False, flow_type='svm', metric = 'area_under_roc_
7779
flow_id = 5891 if flow_type == 'svm' else 6767
7880

7981
# Fetching evaluations
80-
eval_df = openml.evaluations.list_evaluations(function=metric, task=task_ids, flow=[flow_id],
81-
uploader=[2702], output_format='dataframe')
82+
eval_df = openml.evaluations.list_evaluations(function=metric,
83+
task=task_ids,
84+
flow=[flow_id],
85+
uploader=[2702],
86+
output_format='dataframe')
8287
return eval_df, task_ids, flow_id
8388

8489

85-
def create_table_from_evaluations(eval_df, flow_type='svm', run_count=np.iinfo(np.int64).max,
86-
metric = 'area_under_roc_curve', task_ids=None):
90+
def create_table_from_evaluations(eval_df,
91+
flow_type='svm',
92+
run_count=np.iinfo(np.int64).max,
93+
metric = 'area_under_roc_curve',
94+
task_ids=None):
8795
'''
8896
Create tabular data with its ground truth from a dataframe of evaluations.
8997
Optionally, can filter out records based on task ids.
@@ -108,7 +116,6 @@ def create_table_from_evaluations(eval_df, flow_type='svm', run_count=np.iinfo(n
108116
'''
109117
if task_ids is not None:
110118
eval_df = eval_df.loc[eval_df.task_id.isin(task_ids)]
111-
ncols = 4 if flow_type == 'svm' else 10 # ncols determine the number of hyperparameters
112119
if flow_type == 'svm':
113120
ncols = 4
114121
colnames = ['cost', 'degree', 'gamma', 'kernel']
@@ -165,6 +172,8 @@ def preprocess(eval_table, flow_type='svm'):
165172
eval_df, task_ids, flow_id = fetch_evaluations(run_full=False)
166173
X, y = create_table_from_evaluations(eval_df, run_count=1000)
167174
X = preprocess(X)
175+
print("Type: {}; Shape: {}".format(type(X), X.shape))
176+
print(X[:5])
168177

169178

170179
#############################################################################

0 commit comments

Comments (0)