66import warnings
77import sklearn
88import time
9+ import six
10+ import json
911
1012from ..exceptions import PyOpenMLError
1113from .. import config
1517
1618from ..exceptions import OpenMLCacheException , OpenMLServerException
1719from ..util import URLError , version_complies
18- from .._api_calls import _perform_api_call
20+ from .._api_calls import _perform_api_call , fileid_to_url
1921from .run import OpenMLRun , _get_version_information
22+ from .trace import OpenMLRunTrace , OpenMLTraceIteration
2023
2124
2225# _get_version_info, _get_dict and _create_setup_string are in run.py to avoid
@@ -112,6 +115,45 @@ def initialize_model_from_run(run_id):
112115 run = get_run (run_id )
113116 return initialize_model (run .setup_id )
114117
118+ def initialize_model_from_trace (run_id , repeat , fold , iteration = None ):
119+ '''
120+ Initialized a model based on the parameters that were set
121+ by an optimization procedure (i.e., using the exact same
122+ parameter settings)
123+
124+ Parameters
125+ ----------
126+ run_id : int
127+ The Openml run_id. Should contain a trace file
128+
129+ repeat: int
130+ The repeat nr (column in trace file)
131+
132+ fold: int
133+ The fold nr (column in trace file)
134+
135+ iteration: int
136+ The iteration nr (column in trace file)
137+
138+ Returns
139+ -------
140+ model : sklearn model
141+ the scikitlearn model with all parameters initailized
142+ '''
143+ run = get_run (run_id )
144+ if 'trace' not in run .output_files :
145+ raise PyOpenMLError ('Run does not contain trace file' )
146+ trace_url = fileid_to_url (run .output_files ['trace' ], 'trace.arff' )
147+ #print(trace_url)
148+ trace_xml = _perform_api_call ('run/trace/%d' % run_id )
149+ run_trace = _create_trace_from_description (trace_xml )
150+
151+ request = (repeat , fold , iteration )
152+ if request not in run_trace .trace_iterations :
153+ raise ValueError ('Combination repeat, fold, iteration not availavle' )
154+ current = run_trace .trace_iterations [(repeat , fold , iteration )]
155+
156+
115157def _run_exists (task_id , setup_id ):
116158 '''
117159 Checks whether a task/setup combination is already present on the server.
@@ -306,7 +348,7 @@ def _extract_arfftrace(model, rep_no, fold_no):
306348 arff_line = [rep_no , fold_no , itt_no , test_score , selected ]
307349 for key in model .cv_results_ :
308350 if key .startswith ("param_" ):
309- arff_line .append (str (model .cv_results_ [key ][itt_no ]))
351+ arff_line .append (sklearn_to_flow (model .cv_results_ [key ][itt_no ]))
310352 arff_tracecontent .append (arff_line )
311353 return arff_tracecontent
312354
@@ -326,15 +368,20 @@ def _extract_arfftrace_attributes(model):
326368
327369 # model dependent attributes for trace arff
328370 for key in model .cv_results_ :
329- if key .startswith ("param_" ):
371+ if key .startswith ('param_' ):
372+ # supported types should include all types, including bool, int float
373+ supported_types = (bool , int , float , six .string_types )
330374 if all (isinstance (i , (bool )) for i in model .cv_results_ [key ]):
331375 type = ['True' , 'False' ]
332376 elif all (isinstance (i , (int , float )) for i in model .cv_results_ [key ]):
333377 type = 'NUMERIC'
378+ elif all (isinstance (i , supported_types ) or i is None for i in model .cv_results_ [key ]):
379+ type = 'STRING'
334380 else :
335- values = list (set (model .cv_results_ [key ])) # unique values
336- type = [str (i ) for i in values ]
381+ raise TypeError ('Unsupported param type in param grid' )
337382
383+ # we renamed the attribute param to parameter, as this is a required
384+ # OpenML convention
338385 attribute = ("parameter_" + key [6 :], type )
339386 trace_attributes .append (attribute )
340387 return trace_attributes
@@ -439,45 +486,52 @@ def _create_run_from_xml(xml):
439486
440487 dataset_id = int (run ['oml:input_data' ]['oml:dataset' ]['oml:did' ])
441488
442- predictions_url = None
443- if isinstance (run ['oml:output_data' ]['oml:file' ], dict ):
444- # only one result.. probably due to an upload error
445- file_dict = run ['oml:output_data' ]['oml:file' ]
446- if file_dict ['oml:name' ] == 'predictions' :
447- predictions_url = file_dict ['oml:url' ]
448- else :
449- # multiple files, the normal case
450- for file_dict in run ['oml:output_data' ]['oml:file' ]:
451- if file_dict ['oml:name' ] == 'predictions' :
452- predictions_url = file_dict ['oml:url' ]
453- if predictions_url is None :
454- raise ValueError ('No URL to download predictions for run %d in run '
455- 'description XML' % run_id )
489+ files = dict ()
456490 evaluations = dict ()
457491 detailed_evaluations = defaultdict (lambda : defaultdict (dict ))
458- evaluation_flows = dict ()
459- if 'oml:output_data' in run and 'oml:evaluation' in run ['oml:output_data' ]:
460- for evaluation_dict in run ['oml:output_data' ]['oml:evaluation' ]:
461- key = evaluation_dict ['oml:name' ]
462- if 'oml:value' in evaluation_dict :
463- value = float (evaluation_dict ['oml:value' ])
464- elif 'oml:array_data' in evaluation_dict :
465- value = evaluation_dict ['oml:array_data' ]
466- else :
467- raise ValueError ('Could not find keys "value" or "array_data" '
468- 'in %s' % str (evaluation_dict .keys ()))
469-
470- if '@repeat' in evaluation_dict and '@fold' in evaluation_dict :
471- repeat = int (evaluation_dict ['@repeat' ])
472- fold = int (evaluation_dict ['@fold' ])
473- repeat_dict = detailed_evaluations [key ]
474- fold_dict = repeat_dict [repeat ]
475- fold_dict [fold ] = value
476- else :
477- evaluations [key ] = value
478- evaluation_flows [key ] = flow_id
492+ if 'oml:output_data' not in run :
493+ raise ValueError ('Run does not contain output_data (OpenML server error?)' )
494+ else :
495+ if isinstance (run ['oml:output_data' ]['oml:file' ], dict ):
496+ # only one result.. probably due to an upload error
497+ file_dict = run ['oml:output_data' ]['oml:file' ]
498+ files [file_dict ['oml:name' ]] = int (file_dict ['oml:file_id' ])
499+ else :
500+ # multiple files, the normal case
501+ for file_dict in run ['oml:output_data' ]['oml:file' ]:
502+ files [file_dict ['oml:name' ]] = int (file_dict ['oml:file_id' ])
503+ if 'oml:evaluation' in run ['oml:output_data' ]:
504+ # in normal cases there should be evaluations, but in case there
505+ # was an error these could be absent
506+ for evaluation_dict in run ['oml:output_data' ]['oml:evaluation' ]:
507+ key = evaluation_dict ['oml:name' ]
508+ if 'oml:value' in evaluation_dict :
509+ value = float (evaluation_dict ['oml:value' ])
510+ elif 'oml:array_data' in evaluation_dict :
511+ value = evaluation_dict ['oml:array_data' ]
512+ else :
513+ raise ValueError ('Could not find keys "value" or "array_data" '
514+ 'in %s' % str (evaluation_dict .keys ()))
515+
516+ if '@repeat' in evaluation_dict and '@fold' in evaluation_dict :
517+ repeat = int (evaluation_dict ['@repeat' ])
518+ fold = int (evaluation_dict ['@fold' ])
519+ repeat_dict = detailed_evaluations [key ]
520+ fold_dict = repeat_dict [repeat ]
521+ fold_dict [fold ] = value
522+ else :
523+ evaluations [key ] = value
524+
525+ if 'description' not in files :
526+ raise ValueError ('No description file for run %d in run '
527+ 'description XML' % run_id )
528+
529+ if 'predictions' not in files :
530+ # JvR: actually, I am not sure whether this error should be raised.
531+ # a run can consist without predictions. But for now let's keep it
532+ raise ValueError ('No prediction files for run %d in run '
533+ 'description XML' % run_id )
479534
480- evaluation_flows [key ] = flow_id
481535 tags = None
482536 if 'oml:tag' in run :
483537 if isinstance (run ['oml:tag' ], str ):
@@ -487,18 +541,40 @@ def _create_run_from_xml(xml):
487541 else :
488542 raise ValueError ('Received not string and non list as tag item' )
489543
490-
491544 return OpenMLRun (run_id = run_id , uploader = uploader ,
492545 uploader_name = uploader_name , task_id = task_id ,
493546 task_type = task_type ,
494547 task_evaluation_measure = task_evaluation_measure ,
495548 flow_id = flow_id , flow_name = flow_name ,
496549 setup_id = setup_id , setup_string = setup_string ,
497550 parameter_settings = parameters ,
498- dataset_id = dataset_id , predictions_url = predictions_url ,
551+ dataset_id = dataset_id , output_files = files ,
499552 evaluations = evaluations ,
500553 detailed_evaluations = detailed_evaluations , tags = tags )
501554
555+ def _create_trace_from_description (xml ):
556+ result_dict = xmltodict .parse (xml )['oml:trace' ]
557+
558+ run_id = result_dict ['oml:run_id' ]
559+ trace = dict ()
560+
561+ if 'oml:trace_iteration' not in result_dict :
562+ raise ValueError ('Run does not contain valid trace. ' )
563+
564+ for itt in result_dict ['oml:trace_iteration' ]:
565+ repeat = int (itt ['oml:repeat' ])
566+ fold = int (itt ['oml:fold' ])
567+ iteration = int (itt ['oml:iteration' ])
568+ setup_string = json .loads (itt ['oml:setup_string' ])
569+ evaluation = float (itt ['oml:evaluation' ])
570+ selected = bool (itt ['oml:selected' ])
571+
572+ current = OpenMLTraceIteration (repeat , fold , iteration ,
573+ setup_string , evaluation ,
574+ selected )
575+ trace [(repeat , fold , iteration )] = current
576+
577+ return OpenMLRunTrace (run_id , trace )
502578
503579def _get_cached_run (run_id ):
504580 """Load a run from the cache."""
0 commit comments