-
Notifications
You must be signed in to change notification settings - Fork 62
automatic generation of type checks for overload #911
base: numba_typing
Are you sure you want to change the base?
Changes from 4 commits
8bd7bc0
d3f4a5d
b54da17
05c745b
b7446ca
09289bf
1cb60da
967ae29
716402c
2096e94
5ec33ac
7da564b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,183 @@ | ||
| import numpy | ||
| import numba | ||
| from numba import types | ||
| from numba import typeof | ||
| from numba.extending import overload | ||
| from type_annotations import product_annotations, get_func_annotations | ||
| from numba import njit | ||
| import typing | ||
| from numba import NumbaDeprecationWarning, NumbaPendingDeprecationWarning | ||
| import warnings | ||
| from numba.typed import List, Dict | ||
| from inspect import getfullargspec | ||
|
|
||
|
|
||
| warnings.simplefilter('ignore', category=NumbaDeprecationWarning) | ||
| warnings.simplefilter('ignore', category=NumbaPendingDeprecationWarning) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is this needed? |
||
|
|
||
|
|
||
| def overload_list(orig_func): | ||
| def overload_inner(ovld_list): | ||
| def wrapper(*args): | ||
| func_list = ovld_list() | ||
| sig_list = [] | ||
| for func in func_list: | ||
| sig_list.append((product_annotations( | ||
| get_func_annotations(func)), func)) | ||
| args_orig_func = getfullargspec(orig_func) | ||
| values_dict = {name: typ for name, typ in zip(args_orig_func.args, args)} | ||
| defaults_dict = {} | ||
| if args_orig_func.defaults: | ||
| defaults_dict = {name: value for name, value in zip( | ||
| args_orig_func.args[::-1], args_orig_func.defaults[::-1])} | ||
| result = choose_func_by_sig(sig_list, values_dict, defaults_dict) | ||
|
|
||
| if result is None: | ||
| raise numba.TypingError(f'Unsupported types a={a}, b={b}') | ||
|
|
||
| return result | ||
|
|
||
| return overload(orig_func, strict=False)(wrapper) | ||
|
|
||
| return overload_inner | ||
|
|
||
|
|
||
| def check_int_type(n_type): | ||
| return isinstance(n_type, types.Integer) | ||
|
|
||
|
|
||
| def check_float_type(n_type): | ||
| return isinstance(n_type, types.Float) | ||
|
|
||
|
|
||
| def check_bool_type(n_type): | ||
| return isinstance(n_type, types.Boolean) | ||
|
|
||
|
|
||
| def check_str_type(n_type): | ||
| return isinstance(n_type, types.UnicodeType) | ||
|
|
||
|
|
||
| def check_list_type(self, p_type, n_type): | ||
| res = isinstance(n_type, types.List) or isinstance(n_type, types.ListType) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. res = isinstance(n_type, (types.List, types.ListType)) |
||
| if isinstance(p_type, type): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. isinstance(p_type, (list, typing.List))? |
||
| return res | ||
| else: | ||
| return res and self.match(p_type.__args__[0], n_type.dtype) | ||
|
|
||
|
|
||
| def check_tuple_type(self, p_type, n_type): | ||
| res = False | ||
| if isinstance(n_type, types.Tuple): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Instead of using Something like: if not isinstance(n_type, types.Tuple, types.UniTuple):
return False
for p_val, n_val in zip(p_type.__args__, n_type.types):
if not self.match(p_val, n_val):
return False
return TrueAnd btw you need to check that size of |
||
| res = True | ||
| if isinstance(p_type, type): | ||
| return res | ||
| for p_val, n_val in zip(p_type.__args__, n_type.key): | ||
| res = res and self.match(p_val, n_val) | ||
| if isinstance(n_type, types.UniTuple): | ||
| res = True | ||
| if isinstance(p_type, type): | ||
| return res | ||
| for p_val in p_type.__args__: | ||
| res = res and self.match(p_val, n_type.key[0]) | ||
| return res | ||
|
|
||
|
|
||
| def check_dict_type(self, p_type, n_type): | ||
| res = False | ||
| if isinstance(n_type, types.DictType): | ||
| res = True | ||
| if isinstance(p_type, type): | ||
| return res | ||
| for p_val, n_val in zip(p_type.__args__, n_type.keyvalue_type): | ||
| res = res and self.match(p_val, n_val) | ||
| return res | ||
|
|
||
|
|
||
| class TypeChecker: | ||
|
|
||
| _types_dict = {int: check_int_type, float: check_float_type, bool: check_bool_type, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd prefer this checks to be added using |
||
| str: check_str_type, list: check_list_type, | ||
| tuple: check_tuple_type, dict: check_dict_type} | ||
|
|
||
| def __init__(self): | ||
| self._typevars_dict = {} | ||
|
|
||
| def clear_typevars_dict(self): | ||
| self._typevars_dict.clear() | ||
|
|
||
| def add_type_check(self, type_check, func): | ||
| self._types_dict[type_check] = func | ||
|
|
||
| def _is_generic(self, p_obj): | ||
| if isinstance(p_obj, typing._GenericAlias): | ||
| return True | ||
|
|
||
| if isinstance(p_obj, typing._SpecialForm): | ||
| return p_obj not in {typing.Any} | ||
|
|
||
| return False | ||
|
|
||
| def _get_origin(self, p_obj): | ||
| return p_obj.__origin__ | ||
|
|
||
| def match(self, p_type, n_type): | ||
| try: | ||
| if p_type == typing.Any: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's do it like this: if p_type == typing.Any:
return True
if self._is_generic(p_type):
origin_type = self._get_origin(p_type)
if origin_type == typing.Generic:
return self.match_generic(p_type, n_type)
return self._types_dict[origin_type](self, p_type, n_type)
if isinstance(p_type, typing.TypeVar):
return self.match_typevar(p_type, n_type)
if p_type in (list, tuple):
return self._types_dict[p_type](self, p_type, n_type)
return self._types_dict[p_type](n_type) |
||
| return True | ||
| elif self._is_generic(p_type): | ||
| origin_type = self._get_origin(p_type) | ||
| if origin_type == typing.Generic: | ||
| return self.match_generic(p_type, n_type) | ||
| else: | ||
| return self._types_dict[origin_type](self, p_type, n_type) | ||
| elif isinstance(p_type, typing.TypeVar): | ||
| return self.match_typevar(p_type, n_type) | ||
| else: | ||
| if p_type in (list, tuple): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't dict be here too? |
||
| return self._types_dict[p_type](self, p_type, n_type) | ||
| return self._types_dict[p_type](n_type) | ||
| except KeyError: | ||
| print((f'A check for the {p_type} was not found. {n_type}')) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So we will (implicitly) return |
||
|
|
||
| def match_typevar(self, p_type, n_type): | ||
| if not self._typevars_dict.get(p_type) and n_type not in self._typevars_dict.values(): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do you need condition |
||
| self._typevars_dict[p_type] = n_type | ||
| return True | ||
| return self._typevars_dict.get(p_type) == n_type | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I believe it should be |
||
|
|
||
| def match_generic(self, p_type, n_type): | ||
| res = True | ||
| for arg in p_type.__args__: | ||
| res = res and self.match(arg, n_type) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What's doesn't feel right. Do we have any test for this case? |
||
| return res | ||
|
|
||
|
|
||
| def choose_func_by_sig(sig_list, values_dict, defaults_dict={}): | ||
| checker = TypeChecker() | ||
| for sig in sig_list: # sig = (Signature,func) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Instead of accessing |
||
| for param in sig[0].parameters: # param = {'a':int,'b':int} | ||
| full_match = True | ||
| for name, typ in values_dict.items(): # name,type = 'a',int64 | ||
| if isinstance(typ, types.Literal): | ||
|
|
||
| full_match = full_match and checker.match( | ||
| param[name], typ.literal_type) | ||
|
|
||
| if sig[0].defaults.get(name, False): | ||
| full_match = full_match and sig[0].defaults[name] == typ.literal_value | ||
| else: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need this |
||
| full_match = full_match and checker.match(param[name], typ) | ||
|
|
||
| if not full_match: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do we really need this |
||
| break | ||
|
|
||
| for name, val in defaults_dict.items(): | ||
| if sig[0].defaults.get(name) != None: | ||
| full_match = full_match and sig[0].defaults[name] == val | ||
|
|
||
| checker.clear_typevars_dict() | ||
| if full_match: | ||
| return sig[1] | ||
|
|
||
| return None | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some unused imports here, like
numpy,typeofandnjit, please clean them up.