66import json
77import logging
88import re
9+ from re import IGNORECASE
910import sys
1011import time
1112from typing import Any , Callable , Dict , List , Optional , Set , Tuple , Union
@@ -492,6 +493,8 @@ def _get_sklearn_description(self, model: Any, char_lim: int = 1024) -> str:
492493 def match_format (s ):
493494 return "{}\n {}\n " .format (s , len (s ) * '-' )
494495 s1 = "Parameters"
496+ # p = re.compile("[a-z0-9_ ]+ : [a-z0-9_]+[a-z0-9_ ]*", flags=IGNORECASE)
497+ # t = p.findall(d)
495498 # s2 = "Attributes"
496499 # s3 = "See also"
497500 # s4 = "Notes"
@@ -633,6 +636,42 @@ def _check_multiple_occurence_of_component_in_flow(
633636 known_sub_components .add (visitee .name )
634637 to_visit_stack .extend (visitee .components .values ())
635638
639+ def _extract_sklearn_param_info (self , model ):
640+ def match_format (s ):
641+ return "{}\n {}\n " .format (s , len (s ) * '-' )
642+ s1 = "Parameters"
643+ s2 = "Attributes"
644+ s = inspect .getdoc (model )
645+ index1 = s .index (match_format (s1 ))
646+ index2 = s .index (match_format (s2 ))
647+ docstring = s [index1 :index2 ]
648+ n = re .compile ("[.]*\n " , flags = IGNORECASE )
649+ lines = n .split (docstring )
650+ p = re .compile ("[a-z0-9_ ]+ : [a-z0-9_]+[a-z0-9_ ]*" , flags = IGNORECASE )
651+ parameter_docs = OrderedDict ()
652+ description = []
653+
654+ # collecting parameters and their descriptions
655+ for i , s in enumerate (lines ):
656+ param = p .findall (s )
657+ if param != []:
658+ if len (description ) > 0 :
659+ description [- 1 ] = '\n ' .join (description [- 1 ])
660+ description .append ([])
661+ else :
662+ if len (description ) > 0 :
663+ description [- 1 ].append (s )
664+ description [- 1 ] = '\n ' .join (description [- 1 ])
665+
666+ # collecting parameters and their types
667+ matches = p .findall (docstring )
668+ parameter_docs = OrderedDict ()
669+ for i , param in enumerate (matches ):
670+ key , value = param .split (':' )
671+ parameter_docs [key .strip ()] = [value .strip (), description [i ]]
672+
673+ return parameter_docs
674+
636675 def _extract_information_from_model (
637676 self ,
638677 model : Any ,
@@ -654,6 +693,7 @@ def _extract_information_from_model(
654693 sub_components_explicit = set ()
655694 parameters = OrderedDict () # type: OrderedDict[str, Optional[str]]
656695 parameters_meta_info = OrderedDict () # type: OrderedDict[str, Optional[Dict]]
696+ parameters_docs = self ._extract_sklearn_param_info (model )
657697
658698 model_parameters = model .get_params (deep = False )
659699 for k , v in sorted (model_parameters .items (), key = lambda t : t [0 ]):
@@ -774,7 +814,9 @@ def flatten_all(list_):
774814 else :
775815 parameters [k ] = None
776816
777- parameters_meta_info [k ] = OrderedDict ((('description' , None ), ('data_type' , None )))
817+ data_type , description = parameters_docs [k ]
818+ parameters_meta_info [k ] = OrderedDict ((('description' , description ),
819+ ('data_type' , data_type )))
778820
779821 return parameters , parameters_meta_info , sub_components , sub_components_explicit
780822
0 commit comments