Skip to content

Commit c39b9f7

Browse files
authored
Merge pull request #756 from openml/fix_175
Adding sklearn docstring to flow
2 parents 27521ac + 7d685e1 commit c39b9f7

4 files changed

Lines changed: 296 additions & 13 deletions

File tree

openml/extensions/sklearn/extension.py

Lines changed: 172 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import json
77
import logging
88
import re
9+
from re import IGNORECASE
910
import sys
1011
import time
1112
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
@@ -476,6 +477,167 @@ def _is_sklearn_flow(cls, flow: OpenMLFlow) -> bool:
476477
or ',sklearn==' in flow.external_version
477478
)
478479

480+
def _get_sklearn_description(self, model: Any, char_lim: int = 1024) -> str:
481+
'''Fetches the sklearn function docstring for the flow description
482+
483+
Retrieves the sklearn docstring available and does the following:
484+
* If length of docstring <= char_lim, then returns the complete docstring
485+
* Else, trims the docstring till it encounters a 'Read more in the :ref:'
486+
* Or till it encounters a 'Parameters\n----------\n'
487+
The final string returned is at most of length char_lim with leading and
488+
trailing whitespaces removed.
489+
490+
Parameters
491+
----------
492+
model : sklearn model
493+
char_lim : int
494+
Specifying the max length of the returned string.
495+
OpenML servers have a constraint of 1024 characters for the 'description' field.
496+
497+
Returns
498+
-------
499+
str
500+
'''
501+
def match_format(s):
502+
return "{}\n{}\n".format(s, len(s) * '-')
503+
s = inspect.getdoc(model)
504+
if s is None:
505+
return ''
506+
try:
507+
# trim till 'Read more'
508+
pattern = "Read more in the :ref:"
509+
index = s.index(pattern)
510+
s = s[:index]
511+
# trimming docstring to be within char_lim
512+
if len(s) > char_lim:
513+
s = "{}...".format(s[:char_lim - 3])
514+
return s.strip()
515+
except ValueError:
516+
logging.warning("'Read more' not found in descriptions. "
517+
"Trying to trim till 'Parameters' if available in docstring.")
518+
pass
519+
try:
520+
# if 'Read more' doesn't exist, trim till 'Parameters'
521+
pattern = "Parameters"
522+
index = s.index(match_format(pattern))
523+
except ValueError:
524+
# returning full docstring
525+
logging.warning("'Parameters' not found in docstring. Omitting docstring trimming.")
526+
index = len(s)
527+
s = s[:index]
528+
# trimming docstring to be within char_lim
529+
if len(s) > char_lim:
530+
s = "{}...".format(s[:char_lim - 3])
531+
return s.strip()
532+
533+
def _extract_sklearn_parameter_docstring(self, model) -> Union[None, str]:
534+
'''Extracts the part of sklearn docstring containing parameter information
535+
536+
Fetches the entire docstring and trims just the Parameter section.
537+
The assumption is that 'Parameters' is the first section in sklearn docstrings,
538+
followed by other sections titled 'Attributes', 'See also', 'Note', 'References',
539+
appearing in that order if defined.
540+
Returns a None if no section with 'Parameters' can be found in the docstring.
541+
542+
Parameters
543+
----------
544+
model : sklearn model
545+
546+
Returns
547+
-------
548+
str, or None
549+
'''
550+
def match_format(s):
551+
return "{}\n{}\n".format(s, len(s) * '-')
552+
s = inspect.getdoc(model)
553+
if s is None:
554+
return None
555+
try:
556+
index1 = s.index(match_format("Parameters"))
557+
except ValueError as e:
558+
# when sklearn docstring has no 'Parameters' section
559+
logging.warning("{} {}".format(match_format("Parameters"), e))
560+
return None
561+
562+
headings = ["Attributes", "Notes", "See also", "Note", "References"]
563+
for h in headings:
564+
try:
565+
# to find end of Parameters section
566+
index2 = s.index(match_format(h))
567+
break
568+
except ValueError:
569+
logging.warning("{} not available in docstring".format(h))
570+
continue
571+
else:
572+
# in the case only 'Parameters' exist, trim till end of docstring
573+
index2 = len(s)
574+
s = s[index1:index2]
575+
return s.strip()
576+
577+
def _extract_sklearn_param_info(self, model, char_lim=1024) -> Union[None, Dict]:
578+
'''Parses parameter type and description from sklearn dosctring
579+
580+
Parameters
581+
----------
582+
model : sklearn model
583+
char_lim : int
584+
Specifying the max length of the returned string.
585+
OpenML servers have a constraint of 1024 characters string fields.
586+
587+
Returns
588+
-------
589+
Dict, or None
590+
'''
591+
docstring = self._extract_sklearn_parameter_docstring(model)
592+
if docstring is None:
593+
# when sklearn docstring has no 'Parameters' section
594+
return None
595+
596+
n = re.compile("[.]*\n", flags=IGNORECASE)
597+
lines = n.split(docstring)
598+
p = re.compile("[a-z0-9_ ]+ : [a-z0-9_']+[a-z0-9_ ]*", flags=IGNORECASE)
599+
# The above regular expression is designed to detect sklearn parameter names and type
600+
# in the format of [variable_name][space]:[space][type]
601+
# The expectation is that the parameter description for this detected parameter will
602+
# be all the lines in the docstring till the regex finds another parameter match
603+
604+
# collecting parameters and their descriptions
605+
description = [] # type: List
606+
for i, s in enumerate(lines):
607+
param = p.findall(s)
608+
if param != []:
609+
# a parameter definition is found by regex
610+
# creating placeholder when parameter found which will be a list of strings
611+
# string descriptions will be appended in subsequent iterations
612+
# till another parameter is found and a new placeholder is created
613+
placeholder = [''] # type: List[str]
614+
description.append(placeholder)
615+
else:
616+
if len(description) > 0: # description=[] means no parameters found yet
617+
# appending strings to the placeholder created when parameter found
618+
description[-1].append(s)
619+
for i in range(len(description)):
620+
# concatenating parameter description strings
621+
description[i] = '\n'.join(description[i]).strip()
622+
# limiting all parameter descriptions to accepted OpenML string length
623+
if len(description[i]) > char_lim:
624+
description[i] = "{}...".format(description[i][:char_lim - 3])
625+
626+
# collecting parameters and their types
627+
parameter_docs = OrderedDict() # type: Dict
628+
matches = p.findall(docstring)
629+
for i, param in enumerate(matches):
630+
key, value = str(param).split(':')
631+
parameter_docs[key.strip()] = [value.strip(), description[i]]
632+
633+
# to avoid KeyError for missing parameters
634+
param_list_true = list(model.get_params().keys())
635+
param_list_found = list(parameter_docs.keys())
636+
for param in list(set(param_list_true) - set(param_list_found)):
637+
parameter_docs[param] = [None, None]
638+
639+
return parameter_docs
640+
479641
def _serialize_model(self, model: Any) -> OpenMLFlow:
480642
"""Create an OpenMLFlow.
481643
@@ -534,10 +696,12 @@ def _serialize_model(self, model: Any) -> OpenMLFlow:
534696

535697
sklearn_version = self._format_external_version('sklearn', sklearn.__version__)
536698
sklearn_version_formatted = sklearn_version.replace('==', '_')
699+
700+
sklearn_description = self._get_sklearn_description(model)
537701
flow = OpenMLFlow(name=name,
538702
class_name=class_name,
539703
custom_name=short_name,
540-
description='Automatically created scikit-learn flow.',
704+
description=sklearn_description,
541705
model=model,
542706
components=subcomponents,
543707
parameters=parameters,
@@ -623,6 +787,7 @@ def _extract_information_from_model(
623787
sub_components_explicit = set()
624788
parameters = OrderedDict() # type: OrderedDict[str, Optional[str]]
625789
parameters_meta_info = OrderedDict() # type: OrderedDict[str, Optional[Dict]]
790+
parameters_docs = self._extract_sklearn_param_info(model)
626791

627792
model_parameters = model.get_params(deep=False)
628793
for k, v in sorted(model_parameters.items(), key=lambda t: t[0]):
@@ -743,7 +908,12 @@ def flatten_all(list_):
743908
else:
744909
parameters[k] = None
745910

746-
parameters_meta_info[k] = OrderedDict((('description', None), ('data_type', None)))
911+
if parameters_docs is not None:
912+
data_type, description = parameters_docs[k]
913+
parameters_meta_info[k] = OrderedDict((('description', description),
914+
('data_type', data_type)))
915+
else:
916+
parameters_meta_info[k] = OrderedDict((('description', None), ('data_type', None)))
747917

748918
return parameters, parameters_meta_info, sub_components, sub_components_explicit
749919

openml/flows/functions.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,8 @@ def _check_flow_for_server_id(flow: OpenMLFlow) -> None:
308308
def assert_flows_equal(flow1: OpenMLFlow, flow2: OpenMLFlow,
309309
ignore_parameter_values_on_older_children: str = None,
310310
ignore_parameter_values: bool = False,
311-
ignore_custom_name_if_none: bool = False) -> None:
311+
ignore_custom_name_if_none: bool = False,
312+
check_description: bool = True) -> None:
312313
"""Check equality of two flows.
313314
314315
Two flows are equal if their all keys which are not set by the server
@@ -327,8 +328,11 @@ def assert_flows_equal(flow1: OpenMLFlow, flow2: OpenMLFlow,
327328
ignore_parameter_values : bool
328329
Whether to ignore parameter values when comparing flows.
329330
330-
ignore_custom_name_if_none : bool
331+
ignore_custom_name_if_none : bool
331332
Whether to ignore the custom name field if either flow has `custom_name` equal to `None`.
333+
334+
check_description : bool
335+
Whether to ignore matching of flow descriptions.
332336
"""
333337
if not isinstance(flow1, OpenMLFlow):
334338
raise TypeError('Argument 1 must be of type OpenMLFlow, but is %s' %
@@ -366,6 +370,10 @@ def assert_flows_equal(flow1: OpenMLFlow, flow2: OpenMLFlow,
366370
ignore_custom_name_if_none)
367371
elif key == '_extension':
368372
continue
373+
elif check_description and key == 'description':
374+
# to ignore matching of descriptions since sklearn based flows may have
375+
# altering docstrings and is not guaranteed to be consistent
376+
continue
369377
else:
370378
if key == 'parameters':
371379
if ignore_parameter_values or \
@@ -397,6 +405,35 @@ def assert_flows_equal(flow1: OpenMLFlow, flow2: OpenMLFlow,
397405
# Helps with backwards compatibility as `custom_name` is now auto-generated, but
398406
# before it used to be `None`.
399407
continue
408+
elif key == 'parameters_meta_info':
409+
# this value is a dictionary where each key is a parameter name, containing another
410+
# dictionary with keys specifying the parameter's 'description' and 'data_type'
411+
# checking parameter descriptions can be ignored since that might change
412+
# data type check can also be ignored if one of them is not defined, i.e., None
413+
params1 = set(flow1.parameters_meta_info.keys())
414+
params2 = set(flow2.parameters_meta_info.keys())
415+
if params1 != params2:
416+
raise ValueError('Parameter list in meta info for parameters differ '
417+
'in the two flows.')
418+
# iterating over the parameter's meta info list
419+
for param in params1:
420+
if isinstance(flow1.parameters_meta_info[param], Dict) and \
421+
isinstance(flow2.parameters_meta_info[param], Dict) and \
422+
'data_type' in flow1.parameters_meta_info[param] and \
423+
'data_type' in flow2.parameters_meta_info[param]:
424+
value1 = flow1.parameters_meta_info[param]['data_type']
425+
value2 = flow2.parameters_meta_info[param]['data_type']
426+
else:
427+
value1 = flow1.parameters_meta_info[param]
428+
value2 = flow2.parameters_meta_info[param]
429+
if value1 is None or value2 is None:
430+
continue
431+
elif value1 != value2:
432+
raise ValueError("Flow {}: data type for parameter {} in {} differ "
433+
"as {}\nvs\n{}".format(flow1.name, param, key,
434+
value1, value2))
435+
# the continue is to avoid the 'attr != attr2' check at end of function
436+
continue
400437

401438
if attr1 != attr2:
402439
raise ValueError("Flow %s: values for attribute '%s' differ: "

0 commit comments

Comments
 (0)