1212import openml ._api_calls
1313from ..exceptions import PyOpenMLError
1414from ..flows import get_flow
15- from ..tasks import get_task , TaskTypeEnum
15+ from ..tasks import (get_task ,
16+ TaskTypeEnum ,
17+ OpenMLClassificationTask ,
18+ OpenMLLearningCurveTask ,
19+ OpenMLClusteringTask ,
20+ OpenMLRegressionTask
21+ )
1622from ..utils import _tag_entity
1723
1824
@@ -69,7 +75,7 @@ def _repr_pretty_(self, pp, cycle):
6975 pp .text (str (self ))
7076
7177 @classmethod
72- def from_filesystem (cls , directory , expect_model = True ):
78+ def from_filesystem (cls , directory : str , expect_model : bool = True ) -> 'OpenMLRun' :
7379 """
7480 The inverse of the to_filesystem method. Instantiates an OpenMLRun
7581 object based on files stored on the file system.
@@ -109,24 +115,24 @@ def from_filesystem(cls, directory, expect_model=True):
109115 if not os .path .isfile (model_path ) and expect_model :
110116 raise ValueError ('Could not find model.pkl' )
111117
112- with open (description_path , 'r' ) as fp :
113- xml_string = fp .read ()
118+ with open (description_path , 'r' ) as fht :
119+ xml_string = fht .read ()
114120 run = openml .runs .functions ._create_run_from_xml (xml_string , from_server = False )
115121
116122 if run .flow_id is None :
117123 flow = openml .flows .OpenMLFlow .from_filesystem (directory )
118124 run .flow = flow
119125 run .flow_name = flow .name
120126
121- with open (predictions_path , 'r' ) as fp :
122- predictions = arff .load (fp )
127+ with open (predictions_path , 'r' ) as fht :
128+ predictions = arff .load (fht )
123129 run .data_content = predictions ['data' ]
124130
125131 if os .path .isfile (model_path ):
126132 # note that it will load the model if the file exists, even if
127133 # expect_model is False
128- with open (model_path , 'rb' ) as fp :
129- run .model = pickle .load (fp )
134+ with open (model_path , 'rb' ) as fhb :
135+ run .model = pickle .load (fhb )
130136
131137 if os .path .isfile (trace_path ):
132138 run .trace = openml .runs .OpenMLRunTrace ._from_filesystem (trace_path )
@@ -209,7 +215,18 @@ def _generate_arff_dict(self) -> 'OrderedDict[str, Any]':
209215 arff_dict ['relation' ] = \
210216 'openml_task_{}_predictions' .format (task .task_id )
211217
212- if task .task_type_id == TaskTypeEnum .SUPERVISED_CLASSIFICATION :
218+ if isinstance (task , OpenMLLearningCurveTask ):
219+ class_labels = task .class_labels # type: ignore
220+ arff_dict ['attributes' ] = [('repeat' , 'NUMERIC' ),
221+ ('fold' , 'NUMERIC' ),
222+ ('sample' , 'NUMERIC' ),
223+ ('row_id' , 'NUMERIC' )] + \
224+ [('confidence.' + class_labels [i ],
225+ 'NUMERIC' ) for i in
226+ range (len (class_labels ))] + \
227+ [('prediction' , class_labels ),
228+ ('correct' , class_labels )]
229+ elif isinstance (task , OpenMLClassificationTask ):
213230 class_labels = task .class_labels
214231 instance_specifications = [('repeat' , 'NUMERIC' ),
215232 ('fold' , 'NUMERIC' ),
@@ -223,27 +240,14 @@ def _generate_arff_dict(self) -> 'OrderedDict[str, Any]':
223240 arff_dict ['attributes' ] = (instance_specifications
224241 + prediction_confidences
225242 + prediction_and_true )
226-
227- elif task .task_type_id == TaskTypeEnum .LEARNING_CURVE :
228- class_labels = task .class_labels
229- arff_dict ['attributes' ] = [('repeat' , 'NUMERIC' ),
230- ('fold' , 'NUMERIC' ),
231- ('sample' , 'NUMERIC' ),
232- ('row_id' , 'NUMERIC' )] + \
233- [('confidence.' + class_labels [i ],
234- 'NUMERIC' ) for i in
235- range (len (class_labels ))] + \
236- [('prediction' , class_labels ),
237- ('correct' , class_labels )]
238-
239- elif task .task_type_id == TaskTypeEnum .SUPERVISED_REGRESSION :
243+ elif isinstance (task , OpenMLRegressionTask ):
240244 arff_dict ['attributes' ] = [('repeat' , 'NUMERIC' ),
241245 ('fold' , 'NUMERIC' ),
242246 ('row_id' , 'NUMERIC' ),
243247 ('prediction' , 'NUMERIC' ),
244248 ('truth' , 'NUMERIC' )]
245249
246- elif task . task_type == TaskTypeEnum . CLUSTERING :
250+ elif isinstance ( task , OpenMLClusteringTask ) :
247251 arff_dict ['attributes' ] = [('repeat' , 'NUMERIC' ),
248252 ('fold' , 'NUMERIC' ),
249253 ('row_id' , 'NUMERIC' ),
@@ -461,7 +465,7 @@ def _create_description_xml(self):
461465 description_xml = xmltodict .unparse (description , pretty = True )
462466 return description_xml
463467
464- def push_tag (self , tag ) :
468+ def push_tag (self , tag : str ) -> None :
465469 """Annotates this run with a tag on the server.
466470
467471 Parameters
@@ -471,7 +475,7 @@ def push_tag(self, tag):
471475 """
472476 _tag_entity ('run' , self .run_id , tag )
473477
474- def remove_tag (self , tag ) :
478+ def remove_tag (self , tag : str ) -> None :
475479 """Removes a tag from this run on the server.
476480
477481 Parameters
0 commit comments