@@ -492,17 +492,15 @@ def _get_sklearn_description(self, model: Any, char_lim: int = 1024) -> str:
492492 '''
493493 def match_format (s ):
494494 return "{}\n {}\n " .format (s , len (s ) * '-' )
495- s1 = "Parameters"
496- # p = re.compile("[a-z0-9_ ]+ : [a-z0-9_]+[a-z0-9_ ]*", flags=IGNORECASE)
497- # t = p.findall(d)
498- # s2 = "Attributes"
499- # s3 = "See also"
500- # s4 = "Notes"
501495 s = inspect .getdoc (model )
502496 if len (s ) <= char_lim :
503497 return s
504- index = s .index (match_format (s1 ))
505- # captures description till start of 'Parameters\n----------\n', excluding it
498+ try :
499+ pattern = "Read more in the :ref:" # "Parameters"
500+ index = s .index (pattern )
501+ except ValueError :
502+ pattern = "Parameters"
503+ index = s .index (match_format (pattern ))
506504 s = s [:index ]
507505 if len (s ) > char_lim :
508506 s = "{}..." .format (s [:char_lim - 3 ])
@@ -636,15 +634,33 @@ def _check_multiple_occurence_of_component_in_flow(
636634 known_sub_components .add (visitee .name )
637635 to_visit_stack .extend (visitee .components .values ())
638636
639- def _extract_sklearn_param_info (self , model ):
637+ def _extract_sklearn_parameter_docstring (self , model ):
640638 def match_format (s ):
641639 return "{}\n {}\n " .format (s , len (s ) * '-' )
642- s1 = "Parameters"
643- s2 = "Attributes"
644640 s = inspect .getdoc (model )
645- index1 = s .index (match_format (s1 ))
646- index2 = s .index (match_format (s2 ))
647- docstring = s [index1 :index2 ]
641+ s1 = "Parameters"
642+ s2 = ["Attributes" , "See also" , "Note" , "References" ]
643+ try :
644+ index1 = s .index (match_format (s1 ))
645+ except ValueError as e :
646+ print ("Parameter {}" .format (e ))
647+ # returns the whole sklearn docstring available
648+ return s
649+ for h in s2 :
650+ try :
651+ index2 = s .index (match_format (h ))
652+ break
653+ except ValueError :
654+ print ("{} not available in docstring" .format (h ))
655+ continue
656+ else :
657+ # in the case only 'Parameters' exist
658+ index2 = len (s )
659+ s = s [index1 :index2 ]
660+ return s
661+
662+ def _extract_sklearn_param_info (self , model ):
663+ docstring = self ._extract_sklearn_parameter_docstring (model )
648664 n = re .compile ("[.]*\n " , flags = IGNORECASE )
649665 lines = n .split (docstring )
650666 p = re .compile ("[a-z0-9_ ]+ : [a-z0-9_]+[a-z0-9_ ]*" , flags = IGNORECASE )
@@ -656,12 +672,12 @@ def match_format(s):
656672 param = p .findall (s )
657673 if param != []:
658674 if len (description ) > 0 :
659- description [- 1 ] = '\n ' .join (description [- 1 ])
675+ description [- 1 ] = '\n ' .join (description [- 1 ]). strip ()
660676 description .append ([])
661677 else :
662678 if len (description ) > 0 :
663679 description [- 1 ].append (s )
664- description [- 1 ] = '\n ' .join (description [- 1 ])
680+ description [- 1 ] = '\n ' .join (description [- 1 ]). strip ()
665681
666682 # collecting parameters and their types
667683 matches = p .findall (docstring )
0 commit comments