|
6 | 6 | import json |
7 | 7 | import logging |
8 | 8 | import re |
| 9 | +from re import IGNORECASE |
9 | 10 | import sys |
10 | 11 | import time |
11 | 12 | from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union |
@@ -476,6 +477,167 @@ def _is_sklearn_flow(cls, flow: OpenMLFlow) -> bool: |
476 | 477 | or ',sklearn==' in flow.external_version |
477 | 478 | ) |
478 | 479 |
|
| 480 | + def _get_sklearn_description(self, model: Any, char_lim: int = 1024) -> str: |
| 481 | + '''Fetches the sklearn function docstring for the flow description |
| 482 | +
|
| 483 | + Retrieves the sklearn docstring available and does the following: |
| 484 | + * If length of docstring <= char_lim, then returns the complete docstring |
| 485 | + * Else, trims the docstring till it encounters a 'Read more in the :ref:' |
| 486 | + * Or till it encounters a 'Parameters\n----------\n' |
| 487 | + The final string returned is at most of length char_lim with leading and |
| 488 | + trailing whitespaces removed. |
| 489 | +
|
| 490 | + Parameters |
| 491 | + ---------- |
| 492 | + model : sklearn model |
| 493 | + char_lim : int |
| 494 | + Specifying the max length of the returned string. |
| 495 | + OpenML servers have a constraint of 1024 characters for the 'description' field. |
| 496 | +
|
| 497 | + Returns |
| 498 | + ------- |
| 499 | + str |
| 500 | + ''' |
| 501 | + def match_format(s): |
| 502 | + return "{}\n{}\n".format(s, len(s) * '-') |
| 503 | + s = inspect.getdoc(model) |
| 504 | + if s is None: |
| 505 | + return '' |
| 506 | + try: |
| 507 | + # trim till 'Read more' |
| 508 | + pattern = "Read more in the :ref:" |
| 509 | + index = s.index(pattern) |
| 510 | + s = s[:index] |
| 511 | + # trimming docstring to be within char_lim |
| 512 | + if len(s) > char_lim: |
| 513 | + s = "{}...".format(s[:char_lim - 3]) |
| 514 | + return s.strip() |
| 515 | + except ValueError: |
| 516 | + logging.warning("'Read more' not found in descriptions. " |
| 517 | + "Trying to trim till 'Parameters' if available in docstring.") |
| 518 | + pass |
| 519 | + try: |
| 520 | + # if 'Read more' doesn't exist, trim till 'Parameters' |
| 521 | + pattern = "Parameters" |
| 522 | + index = s.index(match_format(pattern)) |
| 523 | + except ValueError: |
| 524 | + # returning full docstring |
| 525 | + logging.warning("'Parameters' not found in docstring. Omitting docstring trimming.") |
| 526 | + index = len(s) |
| 527 | + s = s[:index] |
| 528 | + # trimming docstring to be within char_lim |
| 529 | + if len(s) > char_lim: |
| 530 | + s = "{}...".format(s[:char_lim - 3]) |
| 531 | + return s.strip() |
| 532 | + |
| 533 | + def _extract_sklearn_parameter_docstring(self, model) -> Union[None, str]: |
| 534 | + '''Extracts the part of sklearn docstring containing parameter information |
| 535 | +
|
| 536 | + Fetches the entire docstring and trims just the Parameter section. |
| 537 | + The assumption is that 'Parameters' is the first section in sklearn docstrings, |
| 538 | + followed by other sections titled 'Attributes', 'See also', 'Note', 'References', |
| 539 | + appearing in that order if defined. |
| 540 | + Returns a None if no section with 'Parameters' can be found in the docstring. |
| 541 | +
|
| 542 | + Parameters |
| 543 | + ---------- |
| 544 | + model : sklearn model |
| 545 | +
|
| 546 | + Returns |
| 547 | + ------- |
| 548 | + str, or None |
| 549 | + ''' |
| 550 | + def match_format(s): |
| 551 | + return "{}\n{}\n".format(s, len(s) * '-') |
| 552 | + s = inspect.getdoc(model) |
| 553 | + if s is None: |
| 554 | + return None |
| 555 | + try: |
| 556 | + index1 = s.index(match_format("Parameters")) |
| 557 | + except ValueError as e: |
| 558 | + # when sklearn docstring has no 'Parameters' section |
| 559 | + logging.warning("{} {}".format(match_format("Parameters"), e)) |
| 560 | + return None |
| 561 | + |
| 562 | + headings = ["Attributes", "Notes", "See also", "Note", "References"] |
| 563 | + for h in headings: |
| 564 | + try: |
| 565 | + # to find end of Parameters section |
| 566 | + index2 = s.index(match_format(h)) |
| 567 | + break |
| 568 | + except ValueError: |
| 569 | + logging.warning("{} not available in docstring".format(h)) |
| 570 | + continue |
| 571 | + else: |
| 572 | + # in the case only 'Parameters' exist, trim till end of docstring |
| 573 | + index2 = len(s) |
| 574 | + s = s[index1:index2] |
| 575 | + return s.strip() |
| 576 | + |
| 577 | + def _extract_sklearn_param_info(self, model, char_lim=1024) -> Union[None, Dict]: |
| 578 | + '''Parses parameter type and description from sklearn dosctring |
| 579 | +
|
| 580 | + Parameters |
| 581 | + ---------- |
| 582 | + model : sklearn model |
| 583 | + char_lim : int |
| 584 | + Specifying the max length of the returned string. |
| 585 | + OpenML servers have a constraint of 1024 characters string fields. |
| 586 | +
|
| 587 | + Returns |
| 588 | + ------- |
| 589 | + Dict, or None |
| 590 | + ''' |
| 591 | + docstring = self._extract_sklearn_parameter_docstring(model) |
| 592 | + if docstring is None: |
| 593 | + # when sklearn docstring has no 'Parameters' section |
| 594 | + return None |
| 595 | + |
| 596 | + n = re.compile("[.]*\n", flags=IGNORECASE) |
| 597 | + lines = n.split(docstring) |
| 598 | + p = re.compile("[a-z0-9_ ]+ : [a-z0-9_']+[a-z0-9_ ]*", flags=IGNORECASE) |
| 599 | + # The above regular expression is designed to detect sklearn parameter names and type |
| 600 | + # in the format of [variable_name][space]:[space][type] |
| 601 | + # The expectation is that the parameter description for this detected parameter will |
| 602 | + # be all the lines in the docstring till the regex finds another parameter match |
| 603 | + |
| 604 | + # collecting parameters and their descriptions |
| 605 | + description = [] # type: List |
| 606 | + for i, s in enumerate(lines): |
| 607 | + param = p.findall(s) |
| 608 | + if param != []: |
| 609 | + # a parameter definition is found by regex |
| 610 | + # creating placeholder when parameter found which will be a list of strings |
| 611 | + # string descriptions will be appended in subsequent iterations |
| 612 | + # till another parameter is found and a new placeholder is created |
| 613 | + placeholder = [''] # type: List[str] |
| 614 | + description.append(placeholder) |
| 615 | + else: |
| 616 | + if len(description) > 0: # description=[] means no parameters found yet |
| 617 | + # appending strings to the placeholder created when parameter found |
| 618 | + description[-1].append(s) |
| 619 | + for i in range(len(description)): |
| 620 | + # concatenating parameter description strings |
| 621 | + description[i] = '\n'.join(description[i]).strip() |
| 622 | + # limiting all parameter descriptions to accepted OpenML string length |
| 623 | + if len(description[i]) > char_lim: |
| 624 | + description[i] = "{}...".format(description[i][:char_lim - 3]) |
| 625 | + |
| 626 | + # collecting parameters and their types |
| 627 | + parameter_docs = OrderedDict() # type: Dict |
| 628 | + matches = p.findall(docstring) |
| 629 | + for i, param in enumerate(matches): |
| 630 | + key, value = str(param).split(':') |
| 631 | + parameter_docs[key.strip()] = [value.strip(), description[i]] |
| 632 | + |
| 633 | + # to avoid KeyError for missing parameters |
| 634 | + param_list_true = list(model.get_params().keys()) |
| 635 | + param_list_found = list(parameter_docs.keys()) |
| 636 | + for param in list(set(param_list_true) - set(param_list_found)): |
| 637 | + parameter_docs[param] = [None, None] |
| 638 | + |
| 639 | + return parameter_docs |
| 640 | + |
479 | 641 | def _serialize_model(self, model: Any) -> OpenMLFlow: |
480 | 642 | """Create an OpenMLFlow. |
481 | 643 |
|
@@ -534,10 +696,12 @@ def _serialize_model(self, model: Any) -> OpenMLFlow: |
534 | 696 |
|
535 | 697 | sklearn_version = self._format_external_version('sklearn', sklearn.__version__) |
536 | 698 | sklearn_version_formatted = sklearn_version.replace('==', '_') |
| 699 | + |
| 700 | + sklearn_description = self._get_sklearn_description(model) |
537 | 701 | flow = OpenMLFlow(name=name, |
538 | 702 | class_name=class_name, |
539 | 703 | custom_name=short_name, |
540 | | - description='Automatically created scikit-learn flow.', |
| 704 | + description=sklearn_description, |
541 | 705 | model=model, |
542 | 706 | components=subcomponents, |
543 | 707 | parameters=parameters, |
@@ -623,6 +787,7 @@ def _extract_information_from_model( |
623 | 787 | sub_components_explicit = set() |
624 | 788 | parameters = OrderedDict() # type: OrderedDict[str, Optional[str]] |
625 | 789 | parameters_meta_info = OrderedDict() # type: OrderedDict[str, Optional[Dict]] |
| 790 | + parameters_docs = self._extract_sklearn_param_info(model) |
626 | 791 |
|
627 | 792 | model_parameters = model.get_params(deep=False) |
628 | 793 | for k, v in sorted(model_parameters.items(), key=lambda t: t[0]): |
@@ -743,7 +908,12 @@ def flatten_all(list_): |
743 | 908 | else: |
744 | 909 | parameters[k] = None |
745 | 910 |
|
746 | | - parameters_meta_info[k] = OrderedDict((('description', None), ('data_type', None))) |
| 911 | + if parameters_docs is not None: |
| 912 | + data_type, description = parameters_docs[k] |
| 913 | + parameters_meta_info[k] = OrderedDict((('description', description), |
| 914 | + ('data_type', data_type))) |
| 915 | + else: |
| 916 | + parameters_meta_info[k] = OrderedDict((('description', None), ('data_type', None))) |
747 | 917 |
|
748 | 918 | return parameters, parameters_meta_info, sub_components, sub_components_explicit |
749 | 919 |
|
|
0 commit comments