Skip to content

Commit b495b99

Browse files
authored
transformations: (convert-pdl-to-pdl-interp) Insert dedups for operations returned by native rewrites (#5699)
1 parent 583e55b commit b495b99

2 files changed

Lines changed: 47 additions & 1 deletion

File tree

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// RUN: xdsl-opt %s -p convert-pdl-to-pdl-interp{optimize_for_eqsat=true} | filecheck %s
2+
3+
4+
// CHECK: pdl_interp.func @matcher(%0 : !pdl.operation) {
5+
6+
//...
7+
8+
9+
// CHECK: builtin.module @rewriters {
10+
// CHECK-NEXT: pdl_interp.func @pdl_generated_rewriter(%0 : !pdl.operation, %1 : !pdl.value) {
11+
// CHECK-NEXT: %2, %3, %4 = pdl_interp.apply_rewrite "constraint_returning_op"(%0 : !pdl.operation) : !pdl.operation, !pdl.type, !pdl.operation
12+
// CHECK-NEXT: %5 = ematch.dedup %2
13+
// CHECK-NEXT: %6 = ematch.dedup %4
14+
// CHECK-NEXT: %7 = ematch.get_class_result %1
15+
// CHECK-NEXT: %8 = pdl_interp.create_range %7 : !pdl.value
16+
// CHECK-NEXT: ematch.union %0 : !pdl.operation, %8 : !pdl.range<value>
17+
// CHECK-NEXT: pdl_interp.finalize
18+
// CHECK-NEXT: }
19+
// CHECK-NEXT: }
20+
21+
22+
pdl.pattern : benefit(1) {
23+
%x = pdl.operand
24+
%type = pdl.type
25+
%one = pdl.attribute = 1 : i32
26+
%constop = pdl.operation "arith.constant" {"value" = %one} -> (%type : !pdl.type)
27+
%const = pdl.result 0 of %constop
28+
%mulop = pdl.operation "arith.muli" (%x, %const : !pdl.value, !pdl.value) -> (%type : !pdl.type)
29+
pdl.rewrite %mulop {
30+
%op1, %t, %op2 = pdl.apply_native_rewrite "constraint_returning_op"(%mulop : !pdl.operation) : !pdl.operation, !pdl.type, !pdl.operation
31+
pdl.replace %mulop with (%x : !pdl.value)
32+
}
33+
}

xdsl/transforms/convert_pdl_to_pdl_interp/conversion.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2192,7 +2192,20 @@ def _generate_rewriter_for_apply_native_rewrite(
21922192
op.constraint_name, arguments, result_types
21932193
)
21942194
self.rewriter_builder.insert(interp_op)
2195-
for old_res, new_res in zip(op.results, interp_op.results, strict=True):
2195+
new_results = interp_op.results
2196+
if self.optimize_for_eqsat:
2197+
# In order for equality saturation to work correctly, operations muste be deduplicated.
2198+
# This includes operations created by native rewrites. Whenever a native rewrite
2199+
# creates an operation, it should be returned by the native rewrite. Here we insert
2200+
# dedup ops for each operation that is returned by the native rewrite:
2201+
new_results = [
2202+
self.rewriter_builder.insert(ematch.DedupOp(new_res)).result_op
2203+
if isinstance(new_res.type, pdl.OperationType)
2204+
else new_res
2205+
for new_res in new_results
2206+
]
2207+
2208+
for old_res, new_res in zip(op.results, new_results, strict=True):
21962209
rewrite_values[old_res] = new_res
21972210

21982211
def _generate_rewriter_for_attribute(

0 commit comments

Comments
 (0)