@@ -23,21 +23,19 @@ def _get_cached_tasks():
2323 # description
2424
2525 for filename in directory_content :
26- match = re .match (r"(tid)_([0-9]*)\.xml" , filename )
27- if match :
28- tid = match .group (2 )
29- tid = int (tid )
26+ if not re .match (r"[0-9]*" , filename ):
27+ continue
3028
31- tasks [tid ] = _get_cached_task (tid )
29+ tid = int (filename )
30+ tasks [tid ] = _get_cached_task (tid )
3231
3332 return tasks
3433
3534
3635def _get_cached_task (tid ):
3736 for cache_dir in [config .get_cache_directory (), config .get_private_directory ()]:
3837 task_cache_dir = os .path .join (cache_dir , "tasks" )
39- task_file = os .path .join (task_cache_dir ,
40- "tid_%d.xml" % int (tid ))
38+ task_file = os .path .join (task_cache_dir , str (tid ), "task.xml" )
4139
4240 try :
4341 with open (task_file ) as fh :
@@ -50,7 +48,7 @@ def _get_cached_task(tid):
5048 "cached" % tid )
5149
5250
53- def get_estimation_procedure_list ():
51+ def _get_estimation_procedure_list ():
5452 """Return a list of all estimation procedures which are on OpenML.
5553
5654 Returns
@@ -65,9 +63,18 @@ def get_estimation_procedure_list():
6563 "estimationprocedure/list" )
6664 procs_dict = xmltodict .parse (xml_string )
6765 # Minimalistic check if the XML is useful
68- assert procs_dict ['oml:estimationprocedures' ]['@xmlns:oml' ] == \
69- 'http://openml.org/openml'
70- assert type (procs_dict ['oml:estimationprocedures' ]['oml:estimationprocedure' ]) == list
66+ if 'oml:estimationprocedures' not in procs_dict :
67+ raise ValueError ('Error in return XML, does not contain tag '
68+ 'oml:estimationprocedures.' )
69+ elif '@xmlns:oml' not in procs_dict ['oml:estimationprocedures' ]:
70+ raise ValueError ('Error in return XML, does not contain tag '
71+ '@xmlns:oml as a child of oml:estimationprocedures.' )
72+ elif procs_dict ['oml:estimationprocedures' ]['@xmlns:oml' ] != \
73+ 'http://openml.org/openml' :
74+ raise ValueError ('Error in return XML, value of '
75+ 'oml:estimationprocedures/@xmlns:oml is not '
76+ 'http://openml.org/openml, but %s' %
77+ str (procs_dict ['oml:estimationprocedures' ]['@xmlns:oml' ]))
7178
7279 procs = []
7380 for proc_ in procs_dict ['oml:estimationprocedures' ]['oml:estimationprocedure' ]:
@@ -156,7 +163,7 @@ def _list_tasks(api_call):
156163 % str (tasks_dict ))
157164 try :
158165 tasks = []
159- procs = get_estimation_procedure_list ()
166+ procs = _get_estimation_procedure_list ()
160167 proc_dict = dict ((x ['id' ], x ) for x in procs )
161168 for task_ in tasks_dict ['oml:tasks' ]['oml:task' ]:
162169 task = {'tid' : int (task_ ['oml:task_id' ]),
@@ -217,21 +224,12 @@ def get_task(task_id):
217224 print (e )
218225 raise e
219226
220- # Cache the xml task file
221- if os .path .exists (xml_file ):
222- with open (xml_file ) as fh :
223- local_xml = fh .read ()
224-
225- if task_xml != local_xml :
226- raise ValueError ("Task description of task %d cached at %s "
227- "has changed." % (task_id , xml_file ))
228-
229- else :
230- with open (xml_file , "w" ) as fh :
231- fh .write (task_xml )
227+ with open (xml_file , "w" ) as fh :
228+ fh .write (task_xml )
232229
233230 task = _create_task_from_xml (task_xml )
234231
232+ # TODO extract this to a function
235233 task .download_split ()
236234 dataset = datasets .get_dataset (task .dataset_id )
237235
0 commit comments