|
1 | 1 | from xdsl.builder import Builder |
2 | | -from xdsl.dialects.builtin import IntegerAttr, IntegerType, MemRefType, TensorType, f32 |
| 2 | +from xdsl.dialects.builtin import ( |
| 3 | + IntAttr, |
| 4 | + IntegerAttr, |
| 5 | + IntegerType, |
| 6 | + MemRefType, |
| 7 | + TensorType, |
| 8 | + f32, |
| 9 | +) |
3 | 10 | from xdsl.dialects.csl.csl_stencil import AccessOp, ApplyOp |
4 | 11 | from xdsl.dialects.stencil import IndexAttr, TempType |
5 | 12 | from xdsl.ir import Region, SSAValue |
6 | 13 | from xdsl.utils.test_value import create_ssa_value |
7 | 14 |
|
8 | 15 |
|
9 | 16 | def test_access_patterns(): |
10 | | - temp_t = TempType(5, f32) |
| 17 | + temp_t = TempType(IntAttr(5), f32) |
11 | 18 | temp = create_ssa_value(temp_t) |
12 | 19 | mref = create_ssa_value(mref_t := MemRefType(tens_t := TensorType(f32, (5,)), (4,))) |
13 | 20 |
|
14 | 21 | @Builder.implicit_region((mref_t, temp_t)) |
15 | 22 | def region0(args: tuple[SSAValue, ...]): |
16 | 23 | t0, t1 = args |
17 | 24 | for x in (-1, 1): |
18 | | - AccessOp(t0, IndexAttr.get(x, 0), tens_t) |
| 25 | + AccessOp(t0, IndexAttr.from_indices(x, 0), tens_t) |
19 | 26 | for y in (-1, 1): |
20 | | - AccessOp(t0, IndexAttr.get(0, y), tens_t) |
| 27 | + AccessOp(t0, IndexAttr.from_indices(0, y), tens_t) |
21 | 28 |
|
22 | | - AccessOp(t1, IndexAttr.get(1, 1), tens_t) |
23 | | - AccessOp(t1, IndexAttr.get(-1, -1), tens_t) |
| 29 | + AccessOp(t1, IndexAttr.from_indices(1, 1), tens_t) |
| 30 | + AccessOp(t1, IndexAttr.from_indices(-1, -1), tens_t) |
24 | 31 |
|
25 | 32 | @Builder.implicit_region((temp_t, temp_t)) |
26 | 33 | def region1(args: tuple[SSAValue, ...]): |
27 | 34 | t0, t1 = args |
28 | | - AccessOp(t0, IndexAttr.get(0, 0), tens_t) |
29 | | - AccessOp(t1, IndexAttr.get(0, 0), tens_t) |
| 35 | + AccessOp(t0, IndexAttr.from_indices(0, 0), tens_t) |
| 36 | + AccessOp(t1, IndexAttr.from_indices(0, 0), tens_t) |
30 | 37 |
|
31 | 38 | apply = ApplyOp( |
32 | 39 | operands=[temp, mref, [], [], []], |
|
0 commit comments