@@ -480,31 +480,48 @@ def _is_sklearn_flow(cls, flow: OpenMLFlow) -> bool:
480480 def _get_sklearn_description (self , model : Any , char_lim : int = 1024 ) -> str :
481481 '''Fetches the sklearn function docstring for the flow description
482482
483+ Retrieves the sklearn docstring available and does the following:
484+ * If length of docstring <= char_lim, then returns the complete docstring
485+ * Else, trims the docstring till it encounters a 'Read more in the :ref:'
486+ * Or till it encounters a 'Parameters\n ----------\n '
487+ The final string returned is at most of length char_lim with leading and
488+ trailing whitespaces removed.
489+
483490 Parameters
484491 ----------
485- model: The sklearn model object
486- char_lim: int, specifying the max length of the returned string
492+ model : sklearn model
493+ char_lim : int
494+ Specifying the max length of the returned string
487495 OpenML servers have a constraint of 1024 characters for the 'description' field.
488496
489497 Returns
490498 -------
491- string of length <= char_lim
499+ str
492500 '''
493501 def match_format (s ):
494502 return "{}\n {}\n " .format (s , len (s ) * '-' )
495503 s = inspect .getdoc (model )
496504 if len (s ) <= char_lim :
497- return s
505+ # if the fetched docstring is smaller than char_lim, no trimming required
506+ return s .strip ()
498507 try :
499- pattern = "Read more in the :ref:" # "Parameters"
508+ # trim till 'Read more'
509+ pattern = "Read more in the :ref:"
500510 index = s .index (pattern )
501511 except ValueError :
512+ pass
513+ try :
514+ # if 'Read more' doesn't exist, trim till 'Parameters'
502515 pattern = "Parameters"
503516 index = s .index (match_format (pattern ))
517+ except ValueError :
518+ # returning full docstring
519+ index = len (s )
504520 s = s [:index ]
521+ # trimming docstring to be within char_lim
505522 if len (s ) > char_lim :
506523 s = "{}..." .format (s [:char_lim - 3 ])
507- return s
524+ return s . strip ()
508525
509526 def _serialize_model (self , model : Any ) -> OpenMLFlow :
510527 """Create an OpenMLFlow.
@@ -634,38 +651,69 @@ def _check_multiple_occurence_of_component_in_flow(
634651 known_sub_components .add (visitee .name )
635652 to_visit_stack .extend (visitee .components .values ())
636653
637- def _extract_sklearn_parameter_docstring (self , model ):
654+ def _extract_sklearn_parameter_docstring (self , model ) -> Union [None , str ]:
655+ '''Extracts the part of sklearn docstring containing parameter information
656+
657+ Fetches the entire docstring and trims just the Parameter section.
658+ The assumption is that 'Parameters' is the first section in sklearn docstrings,
659+ followed by other sections titled 'Attributes', 'See also', 'Note', 'References',
660+ appearing in that order if defined.
661+ Returns a None if no section with 'Parameters' can be found in the docstring.
662+
663+ Parameters
664+ ----------
665+ model : sklearn model
666+
667+ Returns
668+ -------
669+ str, or None
670+ '''
638671 def match_format (s ):
639672 return "{}\n {}\n " .format (s , len (s ) * '-' )
640673 s = inspect .getdoc (model )
641- s1 = "Parameters"
642- s2 = ["Attributes" , "See also" , "Note" , "References" ]
643674 try :
644- index1 = s .index (match_format (s1 ))
675+ index1 = s .index (match_format ("Parameters" ))
645676 except ValueError as e :
646- print ("Parameter {}" .format (e ))
647- # returns the whole sklearn docstring available
648- return s
649- for h in s2 :
677+ # when sklearn docstring has no 'Parameters' section
678+ print ("{} {}" .format (match_format ("Parameters" ), e ))
679+ return None
680+
681+ headings = ["Attributes" , "See also" , "Note" , "References" ]
682+ for h in headings :
650683 try :
684+ # to find end of Parameters section
651685 index2 = s .index (match_format (h ))
652686 break
653687 except ValueError :
654688 print ("{} not available in docstring" .format (h ))
655689 continue
656690 else :
657- # in the case only 'Parameters' exist
691+ # in the case only 'Parameters' exist, trim till end of docstring
658692 index2 = len (s )
659693 s = s [index1 :index2 ]
660- return s
694+ return s .strip ()
695+
696+ def _extract_sklearn_param_info (self , model ) -> Union [None , Dict ]:
697+ '''Parses parameter type and description from sklearn docstring
698+
699+ Parameters
700+ ----------
701+ model : sklearn model
661702
662- def _extract_sklearn_param_info (self , model ):
703+ Returns
704+ -------
705+ Dict, or None
706+ '''
663707 docstring = self ._extract_sklearn_parameter_docstring (model )
708+ if docstring is None :
709+ # when sklearn docstring has no 'Parameters' section
710+ return None
711+
664712 n = re .compile ("[.]*\n " , flags = IGNORECASE )
665713 lines = n .split (docstring )
666714 p = re .compile ("[a-z0-9_ ]+ : [a-z0-9_]+[a-z0-9_ ]*" , flags = IGNORECASE )
667- parameter_docs = OrderedDict ()
668- description = []
715+ parameter_docs = OrderedDict () # type: Dict
716+ description = [] # type: List
669717
670718 # collecting parameters and their descriptions
671719 for i , s in enumerate (lines ):
@@ -681,7 +729,6 @@ def _extract_sklearn_param_info(self, model):
681729
682730 # collecting parameters and their types
683731 matches = p .findall (docstring )
684- parameter_docs = OrderedDict ()
685732 for i , param in enumerate (matches ):
686733 key , value = param .split (':' )
687734 parameter_docs [key .strip ()] = [value .strip (), description [i ]]
@@ -830,9 +877,12 @@ def flatten_all(list_):
830877 else :
831878 parameters [k ] = None
832879
833- data_type , description = parameters_docs [k ]
834- parameters_meta_info [k ] = OrderedDict ((('description' , description ),
835- ('data_type' , data_type )))
880+ if parameters_docs is not None :
881+ data_type , description = parameters_docs [k ]
882+ parameters_meta_info [k ] = OrderedDict ((('description' , description ),
883+ ('data_type' , data_type )))
884+ else :
885+ parameters_meta_info [k ] = OrderedDict ((('description' , None ), ('data_type' , None )))
836886
837887 return parameters , parameters_meta_info , sub_components , sub_components_explicit
838888
0 commit comments