Skip to content

Commit 0b5137f

Browse files
committed
Extracting parameter type and descriptions
1 parent f1919e1 commit 0b5137f

1 file changed

Lines changed: 43 additions & 1 deletion

File tree

openml/extensions/sklearn/extension.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import json
77
import logging
88
import re
9+
from re import IGNORECASE
910
import sys
1011
import time
1112
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
@@ -492,6 +493,8 @@ def _get_sklearn_description(self, model: Any, char_lim: int = 1024) -> str:
492493
def match_format(s):
493494
return "{}\n{}\n".format(s, len(s) * '-')
494495
s1 = "Parameters"
496+
# p = re.compile("[a-z0-9_ ]+ : [a-z0-9_]+[a-z0-9_ ]*", flags=IGNORECASE)
497+
# t = p.findall(d)
495498
# s2 = "Attributes"
496499
# s3 = "See also"
497500
# s4 = "Notes"
@@ -633,6 +636,42 @@ def _check_multiple_occurence_of_component_in_flow(
633636
known_sub_components.add(visitee.name)
634637
to_visit_stack.extend(visitee.components.values())
635638

639+
def _extract_sklearn_param_info(self, model):
640+
def match_format(s):
641+
return "{}\n{}\n".format(s, len(s) * '-')
642+
s1 = "Parameters"
643+
s2 = "Attributes"
644+
s = inspect.getdoc(model)
645+
index1 = s.index(match_format(s1))
646+
index2 = s.index(match_format(s2))
647+
docstring = s[index1:index2]
648+
n = re.compile("[.]*\n", flags=IGNORECASE)
649+
lines = n.split(docstring)
650+
p = re.compile("[a-z0-9_ ]+ : [a-z0-9_]+[a-z0-9_ ]*", flags=IGNORECASE)
651+
parameter_docs = OrderedDict()
652+
description = []
653+
654+
# collecting parameters and their descriptions
655+
for i, s in enumerate(lines):
656+
param = p.findall(s)
657+
if param != []:
658+
if len(description) > 0:
659+
description[-1] = '\n'.join(description[-1])
660+
description.append([])
661+
else:
662+
if len(description) > 0:
663+
description[-1].append(s)
664+
description[-1] = '\n'.join(description[-1])
665+
666+
# collecting parameters and their types
667+
matches = p.findall(docstring)
668+
parameter_docs = OrderedDict()
669+
for i, param in enumerate(matches):
670+
key, value = param.split(':')
671+
parameter_docs[key.strip()] = [value.strip(), description[i]]
672+
673+
return parameter_docs
674+
636675
def _extract_information_from_model(
637676
self,
638677
model: Any,
@@ -654,6 +693,7 @@ def _extract_information_from_model(
654693
sub_components_explicit = set()
655694
parameters = OrderedDict() # type: OrderedDict[str, Optional[str]]
656695
parameters_meta_info = OrderedDict() # type: OrderedDict[str, Optional[Dict]]
696+
parameters_docs = self._extract_sklearn_param_info(model)
657697

658698
model_parameters = model.get_params(deep=False)
659699
for k, v in sorted(model_parameters.items(), key=lambda t: t[0]):
@@ -774,7 +814,9 @@ def flatten_all(list_):
774814
else:
775815
parameters[k] = None
776816

777-
parameters_meta_info[k] = OrderedDict((('description', None), ('data_type', None)))
817+
data_type, description = parameters_docs[k]
818+
parameters_meta_info[k] = OrderedDict((('description', description),
819+
('data_type', data_type)))
778820

779821
return parameters, parameters_meta_info, sub_components, sub_components_explicit
780822

0 commit comments

Comments
 (0)