Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.

Commit 2e26f6f

Browse files
author
Ivan Butygin
authored
Revert "WIP: Port to numba master (#338)" (#342)
This reverts commit e3f53bc.
1 parent 7cc8733 commit 2e26f6f

13 files changed

Lines changed: 63 additions & 24 deletions

sdc/distributed.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1770,7 +1770,7 @@ def _gen_parfor_reductions(self, parfor, namevar_table):
17701770
_, reductions = get_parfor_reductions(
17711771
parfor, parfor.params, self.state.calltypes)
17721772

1773-
for reduce_varname, (init_val, reduce_nodes, _) in reductions.items():
1773+
for reduce_varname, (init_val, reduce_nodes) in reductions.items():
17741774
reduce_op = guard(self._get_reduce_op, reduce_nodes)
17751775
# TODO: initialize reduction vars (arrays)
17761776
reduce_var = namevar_table[reduce_varname]

sdc/hiframes/aggregate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ def aggregate_array_analysis(aggregate_node, equiv_set, typemap,
438438
equiv_set.insert_equiv(col_var, shape)
439439
post.extend(c_post)
440440
all_shapes.append(shape[0])
441-
equiv_set.define(col_var, {})
441+
equiv_set.define(col_var)
442442

443443
if len(all_shapes) > 1:
444444
equiv_set.insert_equiv(*all_shapes)

sdc/hiframes/dataframe_pass.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,7 @@ def run_pass(self):
140140
out_nodes = [inst]
141141

142142
if isinstance(inst, ir.Assign):
143-
if inst.value in self.state.func_ir._definitions[inst.target.name]:
144-
self.state.func_ir._definitions[inst.target.name].remove(inst.value)
143+
self.state.func_ir._definitions[inst.target.name].remove(inst.value)
145144
out_nodes = self._run_assign(inst)
146145
elif isinstance(inst, (ir.SetItem, ir.StaticSetItem)):
147146
out_nodes = self._run_setitem(inst)

sdc/hiframes/filter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def filter_array_analysis(filter_node, equiv_set, typemap, array_analysis):
100100
equiv_set.insert_equiv(col_var, shape)
101101
post.extend(c_post)
102102
all_shapes.append(shape[0])
103-
equiv_set.define(col_var, {})
103+
equiv_set.define(col_var)
104104

105105
if len(all_shapes) > 1:
106106
equiv_set.insert_equiv(*all_shapes)

sdc/hiframes/hiframes_typed.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1212,7 +1212,10 @@ def _handle_series_map(self, assign, lhs, rhs, series_var):
12121212
# error checking: make sure there is function input only
12131213
if len(rhs.args) != 1:
12141214
raise ValueError("map expects 1 argument")
1215-
func = guard(get_definition, self.state.func_ir, rhs.args[0]).value.py_func
1215+
func = guard(get_definition, self.state.func_ir, rhs.args[0])
1216+
if func is None or not (isinstance(func, ir.Expr)
1217+
and func.op == 'make_function'):
1218+
raise ValueError("lambda for map not found")
12161219

12171220
dtype = self.state.typemap[series_var.name].dtype
12181221
nodes = []
@@ -1379,7 +1382,11 @@ def _handle_series_combine(self, assign, lhs, rhs, series_var):
13791382
raise ValueError("not enough arguments in call to combine")
13801383
if len(rhs.args) > 3:
13811384
raise ValueError("too many arguments in call to combine")
1382-
func = guard(get_definition, self.state.func_ir, rhs.args[1]).value.py_func
1385+
func = guard(get_definition, self.state.func_ir, rhs.args[1])
1386+
if func is None or not (isinstance(func, ir.Expr)
1387+
and func.op == 'make_function'):
1388+
raise ValueError("lambda for combine not found")
1389+
13831390
out_typ = self.state.typemap[lhs.name].dtype
13841391
other = rhs.args[0]
13851392
nodes = []
@@ -1526,16 +1533,19 @@ def f(arr, w, center): # pragma: no cover
15261533
def _handle_rolling_apply_func(self, func_node, dtype, out_dtype):
15271534
if func_node is None:
15281535
raise ValueError("cannot find kernel function for rolling.apply() call")
1529-
func_node = func_node.value.py_func
15301536
# TODO: more error checking on the kernel to make sure it doesn't
15311537
# use global/closure variables
1538+
if func_node.closure is not None:
1539+
raise ValueError("rolling apply kernel functions cannot have closure variables")
1540+
if func_node.defaults is not None:
1541+
raise ValueError("rolling apply kernel functions cannot have default arguments")
15321542
# create a function from the code object
15331543
glbs = self.state.func_ir.func_id.func.__globals__
15341544
lcs = {}
15351545
exec("def f(A): return A", glbs, lcs)
15361546
kernel_func = lcs['f']
1537-
kernel_func.__code__ = func_node.__code__
1538-
kernel_func.__name__ = func_node.__code__.co_name
1547+
kernel_func.__code__ = func_node.code
1548+
kernel_func.__name__ = func_node.code.co_name
15391549
# use hpat's sequential pipeline to enable pandas operations
15401550
# XXX seq pipeline used since dist pass causes a hang
15411551
m = numba.ir_utils._max_label

sdc/hiframes/join.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def join_array_analysis(join_node, equiv_set, typemap, array_analysis):
131131
equiv_set.insert_equiv(col_var, shape)
132132
post.extend(c_post)
133133
all_shapes.append(shape[0])
134-
equiv_set.define(col_var, {})
134+
equiv_set.define(col_var)
135135

136136
if len(all_shapes) > 1:
137137
equiv_set.insert_equiv(*all_shapes)

sdc/hiframes/pd_dataframe_ext.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,10 @@ def resolve_values(self, ary):
142142
def resolve_apply(self, df, args, kws):
143143
kws = dict(kws)
144144
func = args[0] if len(args) > 0 else kws.get('func', None)
145+
# check lambda
146+
if not isinstance(func, types.MakeFunctionLiteral):
147+
raise ValueError("df.apply(): lambda not found")
148+
145149
# check axis
146150
axis = args[1] if len(args) > 1 else kws.get('axis', None)
147151
if (axis is None or not isinstance(axis, types.IntegerLiteral)
@@ -161,8 +165,12 @@ def resolve_apply(self, df, args, kws):
161165
dtypes.append(el_typ)
162166

163167
row_typ = types.NamedTuple(dtypes, Row)
164-
t = func.get_call_type(self.context, (row_typ,), {})
165-
return signature(SeriesType(t.return_type), *args)
168+
code = func.literal_value.code
169+
f_ir = numba.ir_utils.get_ir_of_code({'np': np}, code)
170+
_, f_return_type, _ = numba.typed_passes.type_inference_stage(
171+
self.context, f_ir, (row_typ,), None)
172+
173+
return signature(SeriesType(f_return_type), *args)
166174

167175
@bound_function("df.describe")
168176
def resolve_describe(self, df, args, kws):

sdc/hiframes/pd_series_ext.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -564,8 +564,18 @@ def _resolve_map_func(self, ary, args, kws):
564564
# getitem returns Timestamp for dt_index and series(dt64)
565565
if dtype == types.NPDatetime('ns'):
566566
dtype = pandas_timestamp_type
567-
t = args[0].get_call_type(self.context, (dtype,), {})
568-
return signature(SeriesType(t.return_type), *args)
567+
code = args[0].literal_value.code
568+
_globals = {'np': np}
569+
# XXX hack in hiframes_typed to make globals available
570+
if hasattr(args[0].literal_value, 'globals'):
571+
# TODO: use code.co_names to find globals actually used?
572+
_globals = args[0].literal_value.globals
573+
574+
f_ir = numba.ir_utils.get_ir_of_code(_globals, code)
575+
f_typemap, f_return_type, f_calltypes = numba.typed_passes.type_inference_stage(
576+
self.context, f_ir, (dtype,), None)
577+
578+
return signature(SeriesType(f_return_type), *args)
569579

570580
@bound_function("series.map")
571581
def resolve_map(self, ary, args, kws):
@@ -584,8 +594,11 @@ def _resolve_combine_func(self, ary, args, kws):
584594
dtype2 = args[0].dtype
585595
if dtype2 == types.NPDatetime('ns'):
586596
dtype2 = pandas_timestamp_type
587-
t = args[1].get_call_type(self.context, (dtype1, dtype2,), {})
588-
return signature(SeriesType(t.return_type), *args)
597+
code = args[1].literal_value.code
598+
f_ir = numba.ir_utils.get_ir_of_code({'np': np}, code)
599+
f_typemap, f_return_type, f_calltypes = numba.typed_passes.type_inference_stage(
600+
self.context, f_ir, (dtype1, dtype2,), None)
601+
return signature(SeriesType(f_return_type), *args)
589602

590603
@bound_function("series.combine")
591604
def resolve_combine(self, ary, args, kws):

sdc/io/csv_ext.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def csv_array_analysis(csv_node, equiv_set, typemap, array_analysis):
9393
equiv_set.insert_equiv(col_var, shape)
9494
post.extend(c_post)
9595
all_shapes.append(shape[0])
96-
equiv_set.define(col_var, {})
96+
equiv_set.define(col_var)
9797

9898
if len(all_shapes) > 1:
9999
equiv_set.insert_equiv(*all_shapes)

sdc/tests/test_basic.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,8 @@ def test_array_reduce(self):
327327
self.assertEqual(count_array_OneDs(), 0)
328328
self.assertEqual(count_parfor_OneDs(), 1)
329329

330-
@unittest.expectedFailure # https://github.com/numba/numba/issues/4690
330+
@unittest.skipIf(check_numba_version('0.46.0'),
331+
"Broken in numba 0.46.0. https://github.com/numba/numba/issues/4690")
331332
def test_dist_return(self):
332333
def test_impl(N):
333334
A = np.arange(N)
@@ -344,7 +345,8 @@ def test_impl(N):
344345
self.assertEqual(count_array_OneDs(), 1)
345346
self.assertEqual(count_parfor_OneDs(), 1)
346347

347-
@unittest.expectedFailure # https://github.com/numba/numba/issues/4690
348+
@unittest.skipIf(check_numba_version('0.46.0'),
349+
"Broken in numba 0.46.0. https://github.com/numba/numba/issues/4690")
348350
def test_dist_return_tuple(self):
349351
def test_impl(N):
350352
A = np.arange(N)
@@ -373,7 +375,8 @@ def test_impl(A):
373375
np.testing.assert_allclose(hpat_func(arr) / self.num_ranks, test_impl(arr))
374376
self.assertEqual(count_array_OneDs(), 1)
375377

376-
@unittest.expectedFailure # https://github.com/numba/numba/issues/4690
378+
@unittest.skipIf(check_numba_version('0.46.0'),
379+
"Broken in numba 0.46.0. https://github.com/numba/numba/issues/4690")
377380
def test_rebalance(self):
378381
def test_impl(N):
379382
A = np.arange(n)
@@ -391,7 +394,8 @@ def test_impl(N):
391394
finally:
392395
sdc.distributed_analysis.auto_rebalance = False
393396

394-
@unittest.expectedFailure # https://github.com/numba/numba/issues/4690
397+
@unittest.skipIf(check_numba_version('0.46.0'),
398+
"Broken in numba 0.46.0. https://github.com/numba/numba/issues/4690")
395399
def test_rebalance_loop(self):
396400
def test_impl(N):
397401
A = np.arange(n)

0 commit comments

Comments (0)