Skip to content

Commit 58a6609

Browse files
committed
Fixing edge cases to pass tests
1 parent 6dc4345 commit 58a6609

4 files changed

Lines changed: 144 additions & 100 deletions

File tree

openml/extensions/sklearn/extension.py

Lines changed: 105 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,8 @@ def _get_sklearn_description(self, model: Any, char_lim: int = 1024) -> str:
501501
def match_format(s):
502502
return "{}\n{}\n".format(s, len(s) * '-')
503503
s = inspect.getdoc(model)
504+
if s is None:
505+
return ''
504506
if len(s) <= char_lim:
505507
# if the fetched docstring is smaller than char_lim, no trimming required
506508
return s.strip()
@@ -528,6 +530,105 @@ def match_format(s):
528530
s = "{}...".format(s[:char_lim - 3])
529531
return s.strip()
530532

533+
def _extract_sklearn_parameter_docstring(self, model) -> Union[None, str]:
534+
'''Extracts the part of sklearn docstring containing parameter information
535+
536+
Fetches the entire docstring and trims just the Parameter section.
537+
The assumption is that 'Parameters' is the first section in sklearn docstrings,
538+
followed by other sections titled 'Attributes', 'See also', 'Note', 'References',
539+
appearing in that order if defined.
540+
Returns a None if no section with 'Parameters' can be found in the docstring.
541+
542+
Parameters
543+
----------
544+
model : sklearn model
545+
546+
Returns
547+
-------
548+
str, or None
549+
'''
550+
def match_format(s):
551+
return "{}\n{}\n".format(s, len(s) * '-')
552+
s = inspect.getdoc(model)
553+
if s is None:
554+
return None
555+
try:
556+
index1 = s.index(match_format("Parameters"))
557+
except ValueError as e:
558+
# when sklearn docstring has no 'Parameters' section
559+
print("{} {}".format(match_format("Parameters"), e))
560+
return None
561+
562+
headings = ["Attributes", "Notes", "See also", "Note", "References"]
563+
for h in headings:
564+
try:
565+
# to find end of Parameters section
566+
index2 = s.index(match_format(h))
567+
break
568+
except ValueError:
569+
print("{} not available in docstring".format(h))
570+
continue
571+
else:
572+
# in the case only 'Parameters' exist, trim till end of docstring
573+
index2 = len(s)
574+
s = s[index1:index2]
575+
return s.strip()
576+
577+
def _extract_sklearn_param_info(self, model, char_lim=1024) -> Union[None, Dict]:
578+
'''Parses parameter type and description from sklearn dosctring
579+
580+
Parameters
581+
----------
582+
model : sklearn model
583+
char_lim : int
584+
Specifying the max length of the returned string.
585+
OpenML servers have a constraint of 1024 characters string fields.
586+
587+
Returns
588+
-------
589+
Dict, or None
590+
'''
591+
docstring = self._extract_sklearn_parameter_docstring(model)
592+
if docstring is None:
593+
# when sklearn docstring has no 'Parameters' section
594+
return None
595+
596+
n = re.compile("[.]*\n", flags=IGNORECASE)
597+
lines = n.split(docstring)
598+
p = re.compile("[a-z0-9_ ]+ : [a-z0-9_']+[a-z0-9_ ]*", flags=IGNORECASE)
599+
parameter_docs = OrderedDict() # type: Dict
600+
description = [] # type: List
601+
602+
# collecting parameters and their descriptions
603+
for i, s in enumerate(lines):
604+
param = p.findall(s)
605+
if param != []:
606+
if len(description) > 0:
607+
description[-1] = '\n'.join(description[-1]).strip()
608+
if len(description[-1]) > char_lim:
609+
description[-1] = "{}...".format(description[-1][:char_lim - 3])
610+
description.append([])
611+
else:
612+
if len(description) > 0:
613+
description[-1].append(s)
614+
description[-1] = '\n'.join(description[-1]).strip()
615+
if len(description[-1]) > char_lim:
616+
description[-1] = "{}...".format(description[-1][:char_lim - 3])
617+
618+
# collecting parameters and their types
619+
matches = p.findall(docstring)
620+
for i, param in enumerate(matches):
621+
key, value = param.split(':')
622+
parameter_docs[key.strip()] = [value.strip(), description[i]]
623+
624+
# to avoid KeyError for missing parameters
625+
param_list_true = list(model.get_params().keys())
626+
param_list_found = list(parameter_docs.keys())
627+
for param in list(set(param_list_true) - set(param_list_found)):
628+
parameter_docs[param] = [None, None]
629+
630+
return parameter_docs
631+
531632
def _serialize_model(self, model: Any) -> OpenMLFlow:
532633
"""Create an OpenMLFlow.
533634
@@ -656,97 +757,6 @@ def _check_multiple_occurence_of_component_in_flow(
656757
known_sub_components.add(visitee.name)
657758
to_visit_stack.extend(visitee.components.values())
658759

659-
def _extract_sklearn_parameter_docstring(self, model) -> Union[None, str]:
660-
'''Extracts the part of sklearn docstring containing parameter information
661-
662-
Fetches the entire docstring and trims just the Parameter section.
663-
The assumption is that 'Parameters' is the first section in sklearn docstrings,
664-
followed by other sections titled 'Attributes', 'See also', 'Note', 'References',
665-
appearing in that order if defined.
666-
Returns a None if no section with 'Parameters' can be found in the docstring.
667-
668-
Parameters
669-
----------
670-
model : sklearn model
671-
672-
Returns
673-
-------
674-
str, or None
675-
'''
676-
def match_format(s):
677-
return "{}\n{}\n".format(s, len(s) * '-')
678-
s = inspect.getdoc(model)
679-
try:
680-
index1 = s.index(match_format("Parameters"))
681-
except ValueError as e:
682-
# when sklearn docstring has no 'Parameters' section
683-
print("{} {}".format(match_format("Parameters"), e))
684-
return None
685-
686-
headings = ["Attributes", "Notes", "See also", "Note", "References"]
687-
for h in headings:
688-
try:
689-
# to find end of Parameters section
690-
index2 = s.index(match_format(h))
691-
break
692-
except ValueError:
693-
print("{} not available in docstring".format(h))
694-
continue
695-
else:
696-
# in the case only 'Parameters' exist, trim till end of docstring
697-
index2 = len(s)
698-
s = s[index1:index2]
699-
return s.strip()
700-
701-
def _extract_sklearn_param_info(self, model, char_lim=1024) -> Union[None, Dict]:
702-
'''Parses parameter type and description from sklearn dosctring
703-
704-
Parameters
705-
----------
706-
model : sklearn model
707-
char_lim : int
708-
Specifying the max length of the returned string.
709-
OpenML servers have a constraint of 1024 characters string fields.
710-
711-
Returns
712-
-------
713-
Dict, or None
714-
'''
715-
docstring = self._extract_sklearn_parameter_docstring(model)
716-
if docstring is None:
717-
# when sklearn docstring has no 'Parameters' section
718-
return None
719-
720-
n = re.compile("[.]*\n", flags=IGNORECASE)
721-
lines = n.split(docstring)
722-
p = re.compile("[a-z0-9_ ]+ : [a-z0-9_']+[a-z0-9_ ]*", flags=IGNORECASE)
723-
parameter_docs = OrderedDict() # type: Dict
724-
description = [] # type: List
725-
726-
# collecting parameters and their descriptions
727-
for i, s in enumerate(lines):
728-
param = p.findall(s)
729-
if param != []:
730-
if len(description) > 0:
731-
description[-1] = '\n'.join(description[-1]).strip()
732-
if len(description[-1]) > char_lim:
733-
description[-1] = "{}...".format(description[-1][:char_lim - 3])
734-
description.append([])
735-
else:
736-
if len(description) > 0:
737-
description[-1].append(s)
738-
description[-1] = '\n'.join(description[-1]).strip()
739-
if len(description[-1]) > char_lim:
740-
description[-1] = "{}...".format(description[-1][:char_lim - 3])
741-
742-
# collecting parameters and their types
743-
matches = p.findall(docstring)
744-
for i, param in enumerate(matches):
745-
key, value = param.split(':')
746-
parameter_docs[key.strip()] = [value.strip(), description[i]]
747-
748-
return parameter_docs
749-
750760
def _extract_information_from_model(
751761
self,
752762
model: Any,
@@ -890,6 +900,10 @@ def flatten_all(list_):
890900
parameters[k] = None
891901

892902
if parameters_docs is not None:
903+
# print(type(model))
904+
# print(sorted(parameters_docs.keys()))
905+
# print(sorted(model_parameters.keys()))
906+
# print()
893907
data_type, description = parameters_docs[k]
894908
parameters_meta_info[k] = OrderedDict((('description', description),
895909
('data_type', data_type)))

openml/flows/functions.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,10 @@ def assert_flows_equal(flow1: OpenMLFlow, flow2: OpenMLFlow,
366366
ignore_custom_name_if_none)
367367
elif key == '_extension':
368368
continue
369+
elif key == 'description':
370+
# to ignore matching of descriptions since sklearn based flows may have
371+
# altering docstrings and is not guaranteed to be consistent
372+
continue
369373
else:
370374
if key == 'parameters':
371375
if ignore_parameter_values or \
@@ -397,6 +401,33 @@ def assert_flows_equal(flow1: OpenMLFlow, flow2: OpenMLFlow,
397401
# Helps with backwards compatibility as `custom_name` is now auto-generated, but
398402
# before it used to be `None`.
399403
continue
404+
elif key == 'parameters_meta_info':
405+
# this value is a dictionary where each key is a parameter name, containing another
406+
# dictionary with keys specifying the parameter's 'description' and 'data_type'
407+
# check of descriptions can be ignored since that might change
408+
# data type check can be ignored if one of them is not defined, i.e., None
409+
params1 = set(flow1.parameters_meta_info.keys())
410+
params2 = set(flow2.parameters_meta_info.keys())
411+
if params1 != params2:
412+
raise ValueError('Parameter list in meta info for parameters differ in the two flows.')
413+
# iterating over the parameter's meta info list
414+
for param in params1:
415+
if isinstance(flow1.parameters_meta_info[param], Dict) and \
416+
isinstance(flow2.parameters_meta_info[param], Dict) and \
417+
'data_type' in flow1.parameters_meta_info[param] and \
418+
'data_type' in flow2.parameters_meta_info[param]:
419+
value1 = flow1.parameters_meta_info[param]['data_type']
420+
value2 = flow2.parameters_meta_info[param]['data_type']
421+
else:
422+
value1 = flow1.parameters_meta_info[param]
423+
value2 = flow2.parameters_meta_info[param]
424+
if value1 is None or value2 is None:
425+
continue
426+
elif value1 != value2:
427+
raise ValueError("Flow {}: data type for parameter {} in parameters_meta_info differ as "
428+
"{}\nvs\n{}".format(flow1.name, key, value1, value2))
429+
# the continue is to avoid the 'attr != attr2' check at end of function
430+
continue
400431

401432
if attr1 != attr2:
402433
raise ValueError("Flow %s: values for attribute '%s' differ: "

tests/test_extensions/test_sklearn_extension/test_sklearn_extension.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def test_serialize_model(self):
7575

7676
fixture_name = 'sklearn.tree.tree.DecisionTreeClassifier'
7777
fixture_short_name = 'sklearn.DecisionTreeClassifier'
78-
fixture_description = 'Automatically created scikit-learn flow.'
78+
fixture_description = self.extension._get_sklearn_description(model)
7979
version_fixture = 'sklearn==%s\nnumpy>=1.6.1\nscipy>=0.9' \
8080
% sklearn.__version__
8181
# min_impurity_decrease has been introduced in 0.20
@@ -143,7 +143,7 @@ def test_serialize_model_clustering(self):
143143

144144
fixture_name = 'sklearn.cluster.k_means_.KMeans'
145145
fixture_short_name = 'sklearn.KMeans'
146-
fixture_description = 'Automatically created scikit-learn flow.'
146+
fixture_description = self.extension._get_sklearn_description(model)
147147
version_fixture = 'sklearn==%s\nnumpy>=1.6.1\nscipy>=0.9' \
148148
% sklearn.__version__
149149
# n_jobs default has changed to None in 0.20
@@ -207,10 +207,10 @@ def test_serialize_model_with_subcomponent(self):
207207
'(base_estimator=sklearn.tree.tree.DecisionTreeClassifier)'
208208
fixture_class_name = 'sklearn.ensemble.weight_boosting.AdaBoostClassifier'
209209
fixture_short_name = 'sklearn.AdaBoostClassifier'
210-
fixture_description = 'Automatically created scikit-learn flow.'
210+
fixture_description = self.extension._get_sklearn_description(model)
211211
fixture_subcomponent_name = 'sklearn.tree.tree.DecisionTreeClassifier'
212212
fixture_subcomponent_class_name = 'sklearn.tree.tree.DecisionTreeClassifier'
213-
fixture_subcomponent_description = 'Automatically created scikit-learn flow.'
213+
fixture_subcomponent_description = self.extension._get_sklearn_description(model.base_estimator)
214214
fixture_structure = {
215215
fixture_name: [],
216216
'sklearn.tree.tree.DecisionTreeClassifier': ['base_estimator']
@@ -264,7 +264,7 @@ def test_serialize_pipeline(self):
264264
'scaler=sklearn.preprocessing.data.StandardScaler,' \
265265
'dummy=sklearn.dummy.DummyClassifier)'
266266
fixture_short_name = 'sklearn.Pipeline(StandardScaler,DummyClassifier)'
267-
fixture_description = 'Automatically created scikit-learn flow.'
267+
fixture_description = self.extension._get_sklearn_description(model)
268268
fixture_structure = {
269269
fixture_name: [],
270270
'sklearn.preprocessing.data.StandardScaler': ['scaler'],
@@ -353,7 +353,7 @@ def test_serialize_pipeline_clustering(self):
353353
'scaler=sklearn.preprocessing.data.StandardScaler,' \
354354
'clusterer=sklearn.cluster.k_means_.KMeans)'
355355
fixture_short_name = 'sklearn.Pipeline(StandardScaler,KMeans)'
356-
fixture_description = 'Automatically created scikit-learn flow.'
356+
fixture_description = self.extension._get_sklearn_description(model)
357357
fixture_structure = {
358358
fixture_name: [],
359359
'sklearn.preprocessing.data.StandardScaler': ['scaler'],
@@ -445,7 +445,7 @@ def test_serialize_column_transformer(self):
445445
'numeric=sklearn.preprocessing.data.StandardScaler,' \
446446
'nominal=sklearn.preprocessing._encoders.OneHotEncoder)'
447447
fixture_short_name = 'sklearn.ColumnTransformer'
448-
fixture_description = 'Automatically created scikit-learn flow.'
448+
fixture_description = self.extension._get_sklearn_description(model)
449449
fixture_structure = {
450450
fixture: [],
451451
'sklearn.preprocessing.data.StandardScaler': ['numeric'],
@@ -504,7 +504,7 @@ def test_serialize_column_transformer_pipeline(self):
504504
fixture_name: [],
505505
}
506506

507-
fixture_description = 'Automatically created scikit-learn flow.'
507+
fixture_description = self.extension._get_sklearn_description(model)
508508
serialization = self.extension.model_to_flow(model)
509509
structure = serialization.get_structure('name')
510510
self.assertEqual(serialization.name, fixture_name)

tests/test_flows/test_flow_functions.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,6 @@ def test_are_flows_equal(self):
9595
# Test most important values that can be set by a user
9696
openml.flows.functions.assert_flows_equal(flow, flow)
9797
for attribute, new_value in [('name', 'Tes'),
98-
('description', 'Test flo'),
9998
('external_version', '2'),
10099
('language', 'english'),
101100
('dependencies', 'ab'),

0 commit comments

Comments
 (0)