Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.

Commit e3f53bc

Browse files
Hardcode84 and shssf
authored and committed
WIP: Port to numba master (#338)
* remove boost.regex dependency
* adapt for get_parfor_reductions interface changes
* disable tests due to dead-code parfor regression
* lambda type_infer quickfix
* quick fix
* fix define sig
* old style: fixes for lambda inlining
* fix series combine
* some work on df.apply
* fix rolling
* remove commented code
* expected failures
* style
1 parent 10dd390 commit e3f53bc

13 files changed

Lines changed: 24 additions & 63 deletions

sdc/distributed.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1770,7 +1770,7 @@ def _gen_parfor_reductions(self, parfor, namevar_table):
17701770
_, reductions = get_parfor_reductions(
17711771
parfor, parfor.params, self.state.calltypes)
17721772

1773-
for reduce_varname, (init_val, reduce_nodes) in reductions.items():
1773+
for reduce_varname, (init_val, reduce_nodes, _) in reductions.items():
17741774
reduce_op = guard(self._get_reduce_op, reduce_nodes)
17751775
# TODO: initialize reduction vars (arrays)
17761776
reduce_var = namevar_table[reduce_varname]

sdc/hiframes/aggregate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ def aggregate_array_analysis(aggregate_node, equiv_set, typemap,
438438
equiv_set.insert_equiv(col_var, shape)
439439
post.extend(c_post)
440440
all_shapes.append(shape[0])
441-
equiv_set.define(col_var)
441+
equiv_set.define(col_var, {})
442442

443443
if len(all_shapes) > 1:
444444
equiv_set.insert_equiv(*all_shapes)

sdc/hiframes/dataframe_pass.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,8 @@ def run_pass(self):
140140
out_nodes = [inst]
141141

142142
if isinstance(inst, ir.Assign):
143-
self.state.func_ir._definitions[inst.target.name].remove(inst.value)
143+
if inst.value in self.state.func_ir._definitions[inst.target.name]:
144+
self.state.func_ir._definitions[inst.target.name].remove(inst.value)
144145
out_nodes = self._run_assign(inst)
145146
elif isinstance(inst, (ir.SetItem, ir.StaticSetItem)):
146147
out_nodes = self._run_setitem(inst)

sdc/hiframes/filter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def filter_array_analysis(filter_node, equiv_set, typemap, array_analysis):
100100
equiv_set.insert_equiv(col_var, shape)
101101
post.extend(c_post)
102102
all_shapes.append(shape[0])
103-
equiv_set.define(col_var)
103+
equiv_set.define(col_var, {})
104104

105105
if len(all_shapes) > 1:
106106
equiv_set.insert_equiv(*all_shapes)

sdc/hiframes/hiframes_typed.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1212,10 +1212,7 @@ def _handle_series_map(self, assign, lhs, rhs, series_var):
12121212
# error checking: make sure there is function input only
12131213
if len(rhs.args) != 1:
12141214
raise ValueError("map expects 1 argument")
1215-
func = guard(get_definition, self.state.func_ir, rhs.args[0])
1216-
if func is None or not (isinstance(func, ir.Expr)
1217-
and func.op == 'make_function'):
1218-
raise ValueError("lambda for map not found")
1215+
func = guard(get_definition, self.state.func_ir, rhs.args[0]).value.py_func
12191216

12201217
dtype = self.state.typemap[series_var.name].dtype
12211218
nodes = []
@@ -1382,11 +1379,7 @@ def _handle_series_combine(self, assign, lhs, rhs, series_var):
13821379
raise ValueError("not enough arguments in call to combine")
13831380
if len(rhs.args) > 3:
13841381
raise ValueError("too many arguments in call to combine")
1385-
func = guard(get_definition, self.state.func_ir, rhs.args[1])
1386-
if func is None or not (isinstance(func, ir.Expr)
1387-
and func.op == 'make_function'):
1388-
raise ValueError("lambda for combine not found")
1389-
1382+
func = guard(get_definition, self.state.func_ir, rhs.args[1]).value.py_func
13901383
out_typ = self.state.typemap[lhs.name].dtype
13911384
other = rhs.args[0]
13921385
nodes = []
@@ -1533,19 +1526,16 @@ def f(arr, w, center): # pragma: no cover
15331526
def _handle_rolling_apply_func(self, func_node, dtype, out_dtype):
15341527
if func_node is None:
15351528
raise ValueError("cannot find kernel function for rolling.apply() call")
1529+
func_node = func_node.value.py_func
15361530
# TODO: more error checking on the kernel to make sure it doesn't
15371531
# use global/closure variables
1538-
if func_node.closure is not None:
1539-
raise ValueError("rolling apply kernel functions cannot have closure variables")
1540-
if func_node.defaults is not None:
1541-
raise ValueError("rolling apply kernel functions cannot have default arguments")
15421532
# create a function from the code object
15431533
glbs = self.state.func_ir.func_id.func.__globals__
15441534
lcs = {}
15451535
exec("def f(A): return A", glbs, lcs)
15461536
kernel_func = lcs['f']
1547-
kernel_func.__code__ = func_node.code
1548-
kernel_func.__name__ = func_node.code.co_name
1537+
kernel_func.__code__ = func_node.__code__
1538+
kernel_func.__name__ = func_node.__code__.co_name
15491539
# use hpat's sequential pipeline to enable pandas operations
15501540
# XXX seq pipeline used since dist pass causes a hang
15511541
m = numba.ir_utils._max_label

sdc/hiframes/join.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def join_array_analysis(join_node, equiv_set, typemap, array_analysis):
131131
equiv_set.insert_equiv(col_var, shape)
132132
post.extend(c_post)
133133
all_shapes.append(shape[0])
134-
equiv_set.define(col_var)
134+
equiv_set.define(col_var, {})
135135

136136
if len(all_shapes) > 1:
137137
equiv_set.insert_equiv(*all_shapes)

sdc/hiframes/pd_dataframe_ext.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -142,10 +142,6 @@ def resolve_values(self, ary):
142142
def resolve_apply(self, df, args, kws):
143143
kws = dict(kws)
144144
func = args[0] if len(args) > 0 else kws.get('func', None)
145-
# check lambda
146-
if not isinstance(func, types.MakeFunctionLiteral):
147-
raise ValueError("df.apply(): lambda not found")
148-
149145
# check axis
150146
axis = args[1] if len(args) > 1 else kws.get('axis', None)
151147
if (axis is None or not isinstance(axis, types.IntegerLiteral)
@@ -165,12 +161,8 @@ def resolve_apply(self, df, args, kws):
165161
dtypes.append(el_typ)
166162

167163
row_typ = types.NamedTuple(dtypes, Row)
168-
code = func.literal_value.code
169-
f_ir = numba.ir_utils.get_ir_of_code({'np': np}, code)
170-
_, f_return_type, _ = numba.typed_passes.type_inference_stage(
171-
self.context, f_ir, (row_typ,), None)
172-
173-
return signature(SeriesType(f_return_type), *args)
164+
t = func.get_call_type(self.context, (row_typ,), {})
165+
return signature(SeriesType(t.return_type), *args)
174166

175167
@bound_function("df.describe")
176168
def resolve_describe(self, df, args, kws):

sdc/hiframes/pd_series_ext.py

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -564,18 +564,8 @@ def _resolve_map_func(self, ary, args, kws):
564564
# getitem returns Timestamp for dt_index and series(dt64)
565565
if dtype == types.NPDatetime('ns'):
566566
dtype = pandas_timestamp_type
567-
code = args[0].literal_value.code
568-
_globals = {'np': np}
569-
# XXX hack in hiframes_typed to make globals available
570-
if hasattr(args[0].literal_value, 'globals'):
571-
# TODO: use code.co_names to find globals actually used?
572-
_globals = args[0].literal_value.globals
573-
574-
f_ir = numba.ir_utils.get_ir_of_code(_globals, code)
575-
f_typemap, f_return_type, f_calltypes = numba.typed_passes.type_inference_stage(
576-
self.context, f_ir, (dtype,), None)
577-
578-
return signature(SeriesType(f_return_type), *args)
567+
t = args[0].get_call_type(self.context, (dtype,), {})
568+
return signature(SeriesType(t.return_type), *args)
579569

580570
@bound_function("series.map")
581571
def resolve_map(self, ary, args, kws):
@@ -594,11 +584,8 @@ def _resolve_combine_func(self, ary, args, kws):
594584
dtype2 = args[0].dtype
595585
if dtype2 == types.NPDatetime('ns'):
596586
dtype2 = pandas_timestamp_type
597-
code = args[1].literal_value.code
598-
f_ir = numba.ir_utils.get_ir_of_code({'np': np}, code)
599-
f_typemap, f_return_type, f_calltypes = numba.typed_passes.type_inference_stage(
600-
self.context, f_ir, (dtype1, dtype2,), None)
601-
return signature(SeriesType(f_return_type), *args)
587+
t = args[1].get_call_type(self.context, (dtype1, dtype2,), {})
588+
return signature(SeriesType(t.return_type), *args)
602589

603590
@bound_function("series.combine")
604591
def resolve_combine(self, ary, args, kws):

sdc/io/csv_ext.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def csv_array_analysis(csv_node, equiv_set, typemap, array_analysis):
9393
equiv_set.insert_equiv(col_var, shape)
9494
post.extend(c_post)
9595
all_shapes.append(shape[0])
96-
equiv_set.define(col_var)
96+
equiv_set.define(col_var, {})
9797

9898
if len(all_shapes) > 1:
9999
equiv_set.insert_equiv(*all_shapes)

sdc/tests/test_basic.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -327,8 +327,7 @@ def test_array_reduce(self):
327327
self.assertEqual(count_array_OneDs(), 0)
328328
self.assertEqual(count_parfor_OneDs(), 1)
329329

330-
@unittest.skipIf(check_numba_version('0.46.0'),
331-
"Broken in numba 0.46.0. https://github.com/numba/numba/issues/4690")
330+
@unittest.expectedFailure # https://github.com/numba/numba/issues/4690
332331
def test_dist_return(self):
333332
def test_impl(N):
334333
A = np.arange(N)
@@ -345,8 +344,7 @@ def test_impl(N):
345344
self.assertEqual(count_array_OneDs(), 1)
346345
self.assertEqual(count_parfor_OneDs(), 1)
347346

348-
@unittest.skipIf(check_numba_version('0.46.0'),
349-
"Broken in numba 0.46.0. https://github.com/numba/numba/issues/4690")
347+
@unittest.expectedFailure # https://github.com/numba/numba/issues/4690
350348
def test_dist_return_tuple(self):
351349
def test_impl(N):
352350
A = np.arange(N)
@@ -375,8 +373,7 @@ def test_impl(A):
375373
np.testing.assert_allclose(hpat_func(arr) / self.num_ranks, test_impl(arr))
376374
self.assertEqual(count_array_OneDs(), 1)
377375

378-
@unittest.skipIf(check_numba_version('0.46.0'),
379-
"Broken in numba 0.46.0. https://github.com/numba/numba/issues/4690")
376+
@unittest.expectedFailure # https://github.com/numba/numba/issues/4690
380377
def test_rebalance(self):
381378
def test_impl(N):
382379
A = np.arange(n)
@@ -394,8 +391,7 @@ def test_impl(N):
394391
finally:
395392
sdc.distributed_analysis.auto_rebalance = False
396393

397-
@unittest.skipIf(check_numba_version('0.46.0'),
398-
"Broken in numba 0.46.0. https://github.com/numba/numba/issues/4690")
394+
@unittest.expectedFailure # https://github.com/numba/numba/issues/4690
399395
def test_rebalance_loop(self):
400396
def test_impl(N):
401397
A = np.arange(n)

0 commit comments

Comments (0)