Skip to content

Commit d90f333

Browse files
committed
More robust failure checks + improved docstrings
1 parent b0ad048 commit d90f333

1 file changed

Lines changed: 73 additions & 23 deletions

File tree

openml/extensions/sklearn/extension.py

Lines changed: 73 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -480,31 +480,48 @@ def _is_sklearn_flow(cls, flow: OpenMLFlow) -> bool:
480480
def _get_sklearn_description(self, model: Any, char_lim: int = 1024) -> str:
481481
'''Fetches the sklearn function docstring for the flow description
482482
483+
Retrieves the sklearn docstring available and does the following:
484+
* If length of docstring <= char_lim, then returns the complete docstring
485+
* Else, trims the docstring till it encounters a 'Read more in the :ref:'
486+
* Or till it encounters a 'Parameters\n----------\n'
487+
The final string returned is at most of length char_lim with leading and
488+
trailing whitespaces removed.
489+
483490
Parameters
484491
----------
485-
model: The sklearn model object
486-
char_lim: int, specifying the max length of the returned string
492+
model : sklearn model
493+
char_lim : int
494+
Specifying the max length of the returned string
487495
OpenML servers have a constraint of 1024 characters for the 'description' field.
488496
489497
Returns
490498
-------
491-
string of length <= char_lim
499+
str
492500
'''
493501
def match_format(s):
494502
return "{}\n{}\n".format(s, len(s) * '-')
495503
s = inspect.getdoc(model)
496504
if len(s) <= char_lim:
497-
return s
505+
# if the fetched docstring is no longer than char_lim, no trimming is required
506+
return s.strip()
498507
try:
499-
pattern = "Read more in the :ref:" # "Parameters"
508+
# trim till 'Read more'
509+
pattern = "Read more in the :ref:"
500510
index = s.index(pattern)
501511
except ValueError:
512+
pass
513+
try:
514+
# if 'Read more' doesn't exist, trim till 'Parameters'
502515
pattern = "Parameters"
503516
index = s.index(match_format(pattern))
517+
except ValueError:
518+
# returning full docstring
519+
index = len(s)
504520
s = s[:index]
521+
# trimming docstring to be within char_lim
505522
if len(s) > char_lim:
506523
s = "{}...".format(s[:char_lim - 3])
507-
return s
524+
return s.strip()
508525

509526
def _serialize_model(self, model: Any) -> OpenMLFlow:
510527
"""Create an OpenMLFlow.
@@ -634,38 +651,69 @@ def _check_multiple_occurence_of_component_in_flow(
634651
known_sub_components.add(visitee.name)
635652
to_visit_stack.extend(visitee.components.values())
636653

637-
def _extract_sklearn_parameter_docstring(self, model):
654+
def _extract_sklearn_parameter_docstring(self, model) -> Union[None, str]:
655+
'''Extracts the part of sklearn docstring containing parameter information
656+
657+
Fetches the entire docstring and extracts just the 'Parameters' section.
658+
The assumption is that 'Parameters' is the first section in sklearn docstrings,
659+
followed by other sections titled 'Attributes', 'See also', 'Note', 'References',
660+
appearing in that order if defined.
661+
Returns None if no 'Parameters' section can be found in the docstring.
662+
663+
Parameters
664+
----------
665+
model : sklearn model
666+
667+
Returns
668+
-------
669+
str, or None
670+
'''
638671
def match_format(s):
639672
return "{}\n{}\n".format(s, len(s) * '-')
640673
s = inspect.getdoc(model)
641-
s1 = "Parameters"
642-
s2 = ["Attributes", "See also", "Note", "References"]
643674
try:
644-
index1 = s.index(match_format(s1))
675+
index1 = s.index(match_format("Parameters"))
645676
except ValueError as e:
646-
print("Parameter {}".format(e))
647-
# returns the whole sklearn docstring available
648-
return s
649-
for h in s2:
677+
# when sklearn docstring has no 'Parameters' section
678+
print("{} {}".format(match_format("Parameters"), e))
679+
return None
680+
681+
headings = ["Attributes", "See also", "Note", "References"]
682+
for h in headings:
650683
try:
684+
# to find end of Parameters section
651685
index2 = s.index(match_format(h))
652686
break
653687
except ValueError:
654688
print("{} not available in docstring".format(h))
655689
continue
656690
else:
657-
# in the case only 'Parameters' exist
691+
# in the case only 'Parameters' exist, trim till end of docstring
658692
index2 = len(s)
659693
s = s[index1:index2]
660-
return s
694+
return s.strip()
695+
696+
def _extract_sklearn_param_info(self, model) -> Union[None, Dict]:
697+
'''Parses parameter type and description from sklearn docstring
698+
699+
Parameters
700+
----------
701+
model : sklearn model
661702
662-
def _extract_sklearn_param_info(self, model):
703+
Returns
704+
-------
705+
Dict, or None
706+
'''
663707
docstring = self._extract_sklearn_parameter_docstring(model)
708+
if docstring is None:
709+
# when sklearn docstring has no 'Parameters' section
710+
return None
711+
664712
n = re.compile("[.]*\n", flags=IGNORECASE)
665713
lines = n.split(docstring)
666714
p = re.compile("[a-z0-9_ ]+ : [a-z0-9_]+[a-z0-9_ ]*", flags=IGNORECASE)
667-
parameter_docs = OrderedDict()
668-
description = []
715+
parameter_docs = OrderedDict() # type: Dict
716+
description = [] # type: List
669717

670718
# collecting parameters and their descriptions
671719
for i, s in enumerate(lines):
@@ -681,7 +729,6 @@ def _extract_sklearn_param_info(self, model):
681729

682730
# collecting parameters and their types
683731
matches = p.findall(docstring)
684-
parameter_docs = OrderedDict()
685732
for i, param in enumerate(matches):
686733
key, value = param.split(':')
687734
parameter_docs[key.strip()] = [value.strip(), description[i]]
@@ -830,9 +877,12 @@ def flatten_all(list_):
830877
else:
831878
parameters[k] = None
832879

833-
data_type, description = parameters_docs[k]
834-
parameters_meta_info[k] = OrderedDict((('description', description),
835-
('data_type', data_type)))
880+
if parameters_docs is not None:
881+
data_type, description = parameters_docs[k]
882+
parameters_meta_info[k] = OrderedDict((('description', description),
883+
('data_type', data_type)))
884+
else:
885+
parameters_meta_info[k] = OrderedDict((('description', None), ('data_type', None)))
836886

837887
return parameters, parameters_meta_info, sub_components, sub_components_explicit
838888

0 commit comments

Comments
 (0)