3838from numba import types
3939from numba .special import literally
4040from sdc .hiframes .pd_dataframe_ext import DataFrameType
41- from sdc .utilities .sdc_typing_utils import TypeChecker
41+ from sdc .hiframes .pd_series_type import SeriesType
42+ from sdc .utilities .sdc_typing_utils import (TypeChecker , check_index_is_numeric ,
43+ check_types_comparable ,
44+ gen_df_impl_generator )
4245from sdc .str_arr_ext import StringArrayType
4346
4447from sdc .hiframes .pd_dataframe_type import DataFrameType
@@ -929,6 +932,32 @@ def sdc_pandas_dataframe_drop_impl(df, _func_name, args, columns):
929932 return sdc_pandas_dataframe_drop_impl (df , _func_name , args , columns )
930933
931934
935+ def df_getitem_bool_series_idx_main_codelines (self , idx ):
936+ """Generate main code lines for df.getitem"""
937+ func_lines = [' self_length = len(get_dataframe_data(self, 0))' ,
938+ ' trimmed_idx_data = idx._data[:self_length]' ]
939+
940+ if isinstance (self .index , types .NoneType ):
941+ func_lines += [' self_index = numpy.arange(self_length)' ]
942+ else :
943+ func_lines += [' self_index = self._index' ]
944+
945+ results = []
946+ for i , col in enumerate (self .columns ):
947+ res_data = f'res_data_{ i } '
948+ func_lines += [
949+ f' data_{ i } = get_dataframe_data(self, { i } )' ,
950+ f' series = pandas.Series(data_{ i } , index=self_index, name="{ col } ")' ,
951+ f' { res_data } = series[trimmed_idx_data]' ,
952+ ]
953+ results .append ((col , res_data ))
954+
955+ data = ', ' .join (f'"{ col } ": { data } ' for col , data in results )
956+ func_lines += [f' return pandas.DataFrame({{{ data } }}, index=self_index[trimmed_idx_data])' ]
957+
958+ return func_lines
959+
960+
932961def df_index_codelines (self ):
933962 """Generate code lines to get or create index of DF"""
934963 if isinstance (self .index , types .NoneType ):
@@ -941,6 +970,11 @@ def df_index_codelines(self):
941970 return func_lines
942971
943972
973+ def df_getitem_key_error_codelines ():
974+ """Generate code lines to raise KeyError"""
975+ return [' raise KeyError("Column is not in the DataFrame")' ]
976+
977+
944978def df_getitem_slice_idx_main_codelines (self , idx ):
945979 """Generate main code lines for df.getitem with idx of slice"""
946980 results = []
@@ -978,6 +1012,35 @@ def df_getitem_tuple_idx_main_codelines(self, literal_idx):
9781012 return func_lines
9791013
9801014
1015+ def df_getitem_bool_series_codegen (self , idx ):
1016+ """
1017+ Example of generated implementation with provided index:
1018+ def _df_getitem_bool_series_idx_impl(self, idx):
1019+ self_length = len(get_dataframe_data(self, 0))
1020+ trimmed_idx_data = idx._data[:self_length]
1021+ self_index = self._index
1022+ data_0 = get_dataframe_data(self, 0)
1023+ series = pandas.Series(data_0, index=self_index, name="A")
1024+ res_data_0 = series[trimmed_idx_data]
1025+ data_1 = get_dataframe_data(self, 1)
1026+ series = pandas.Series(data_1, index=self_index, name="B")
1027+ res_data_1 = series[trimmed_idx_data]
1028+ return pandas.DataFrame({"A": res_data_0, "B": res_data_1}, index=self_index[trimmed_idx_data])
1029+ """
1030+ func_lines = ['def _df_getitem_bool_series_idx_impl(self, idx):' ]
1031+ if self .columns :
1032+ func_lines += df_getitem_bool_series_idx_main_codelines (self , idx )
1033+ else :
1034+ # raise KeyError if input DF is empty
1035+ func_lines += df_getitem_key_error_codelines ()
1036+
1037+ func_text = '\n ' .join (func_lines )
1038+ global_vars = {'pandas' : pandas , 'numpy' : numpy ,
1039+ 'get_dataframe_data' : get_dataframe_data }
1040+
1041+ return func_text , global_vars
1042+
1043+
9811044def df_getitem_slice_idx_codegen (self , idx ):
9821045 """
9831046 Example of generated implementation with provided index:
@@ -994,7 +1057,7 @@ def _df_getitem_slice_idx_impl(self, idx)
9941057 func_lines += df_getitem_slice_idx_main_codelines (self , idx )
9951058 else :
9961059 # raise KeyError if input DF is empty
997- func_lines += [ ' raise KeyError' ]
1060+ func_lines += df_getitem_key_error_codelines ()
9981061
9991062 func_text = '\n ' .join (func_lines )
10001063 global_vars = {'pandas' : pandas , 'numpy' : numpy ,
@@ -1022,7 +1085,7 @@ def _df_getitem_tuple_idx_impl(self, idx)
10221085 func_lines += df_getitem_tuple_idx_main_codelines (self , literal_idx )
10231086 else :
10241087 # raise KeyError if input DF is empty or idx is invalid
1025- func_lines += [ ' raise KeyError' ]
1088+ func_lines += df_getitem_key_error_codelines ()
10261089
10271090 func_text = '\n ' .join (func_lines )
10281091 global_vars = {'pandas' : pandas , 'numpy' : numpy ,
@@ -1031,28 +1094,17 @@ def _df_getitem_tuple_idx_impl(self, idx)
10311094 return func_text , global_vars
10321095
10331096
1034- def gen_df_getitem_impl_generator (codegen , impl_name ):
1035- """Generate generator of df.getitem"""
1036- def _df_getitem_impl_generator (self , idx ):
1037- func_text , global_vars = codegen (self , idx )
1038-
1039- loc_vars = {}
1040- exec (func_text , global_vars , loc_vars )
1041- _impl = loc_vars [impl_name ]
1042-
1043- return _impl
1044-
1045- return _df_getitem_impl_generator
1046-
1047-
1048- gen_df_getitem_slice_idx_impl = gen_df_getitem_impl_generator (
1097+ gen_df_getitem_slice_idx_impl = gen_df_impl_generator (
10491098 df_getitem_slice_idx_codegen , '_df_getitem_slice_idx_impl' )
1050- gen_df_getitem_tuple_idx_impl = gen_df_getitem_impl_generator (
1099+ gen_df_getitem_tuple_idx_impl = gen_df_impl_generator (
10511100 df_getitem_tuple_idx_codegen , '_df_getitem_tuple_idx_impl' )
1101+ gen_df_getitem_bool_series_idx_impl = gen_df_impl_generator (
1102+ df_getitem_bool_series_codegen , '_df_getitem_bool_series_idx_impl' )
10521103
10531104
10541105@sdc_overload (operator .getitem )
10551106def sdc_pandas_dataframe_getitem (self , idx ):
1107+ ty_checker = TypeChecker ('Operator getitem().' )
10561108
10571109 if not isinstance (self , DataFrameType ):
10581110 return None
@@ -1069,7 +1121,7 @@ def _df_getitem_str_literal_idx_impl(self, idx):
10691121 data = get_dataframe_data (self , col_idx )
10701122 return pandas .Series (data , index = self ._index , name = idx )
10711123 else :
1072- raise KeyError
1124+ raise KeyError ( 'Column is not in the DataFrame' )
10731125
10741126 return _df_getitem_str_literal_idx_impl
10751127
@@ -1082,12 +1134,30 @@ def _df_getitem_unicode_idx_impl(self, idx):
10821134 return _df_getitem_unicode_idx_impl
10831135
10841136 if isinstance (idx , types .Tuple ):
1085- return gen_df_getitem_tuple_idx_impl (self , idx )
1137+ if all ([isinstance (item , types .StringLiteral ) for item in idx ]):
1138+ return gen_df_getitem_tuple_idx_impl (self , idx )
10861139
10871140 if isinstance (idx , types .SliceType ):
10881141 return gen_df_getitem_slice_idx_impl (self , idx )
10891142
1090- ty_checker = TypeChecker ('Operator getitem().' )
1143+ if isinstance (idx , SeriesType ) and isinstance (idx .dtype , types .Boolean ):
1144+ self_index_is_none = isinstance (self .index , types .NoneType )
1145+ idx_index_is_none = isinstance (idx .index , types .NoneType )
1146+
1147+ if self_index_is_none and not idx_index_is_none :
1148+ if not check_index_is_numeric (idx ):
1149+ ty_checker .raise_exc (idx .index .dtype , 'number' , 'idx.index.dtype' )
1150+
1151+ if not self_index_is_none and idx_index_is_none :
1152+ if not check_index_is_numeric (self ):
1153+ ty_checker .raise_exc (idx .index .dtype , self .index .dtype , 'idx.index.dtype' )
1154+
1155+ if not self_index_is_none and not idx_index_is_none :
1156+ if not check_types_comparable (self .index , idx .index ):
1157+ ty_checker .raise_exc (idx .index .dtype , self .index .dtype , 'idx.index.dtype' )
1158+
1159+ return gen_df_getitem_bool_series_idx_impl (self , idx )
1160+
10911161 ty_checker .raise_exc (idx , 'str' , 'idx' )
10921162
10931163
0 commit comments