Skip to content

Commit 5b65cd7

Browse files
committed
GROOVY-11890: groovy-contracts could support "@decreases"
1 parent 308270c commit 5b65cd7

4 files changed

Lines changed: 526 additions & 0 deletions

File tree

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package groovy.contracts;
20+
21+
import groovy.lang.annotation.ExtendedElementType;
22+
import groovy.lang.annotation.ExtendedTarget;
23+
import org.apache.groovy.lang.annotation.Incubating;
24+
import org.codehaus.groovy.transform.GroovyASTTransformationClass;
25+
26+
import java.lang.annotation.ElementType;
27+
import java.lang.annotation.Retention;
28+
import java.lang.annotation.RetentionPolicy;
29+
import java.lang.annotation.Target;
30+
31+
/**
32+
* Specifies a termination measure for a loop. The closure must return a
33+
* {@link Comparable} value that <em>strictly decreases</em> on every
34+
* iteration and remains non-negative (i.e., {@code >= 0} for numeric types).
35+
* <p>
36+
* At runtime, the expression is evaluated at the start and end of each
37+
* iteration. A {@link org.apache.groovy.contracts.LoopVariantViolation
38+
* LoopVariantViolation} is thrown if:
39+
* <ul>
40+
* <li>the value did not decrease, or</li>
41+
* <li>the value became negative.</li>
42+
* </ul>
43+
* <p>
44+
* Example:
45+
* <pre>
46+
* int n = 10
47+
* {@code @Decreases}({ n })
48+
* while (n &gt; 0) {
49+
* n--
50+
* }
51+
* </pre>
52+
*
53+
* @since 6.0.0
54+
* @see Invariant
55+
*/
56+
@Retention(RetentionPolicy.RUNTIME)
57+
@Target(ElementType.TYPE)
58+
@ExtendedTarget(ExtendedElementType.LOOP)
59+
@Incubating
60+
@GroovyASTTransformationClass("org.apache.groovy.contracts.ast.LoopVariantASTTransformation")
61+
public @interface Decreases {
62+
Class value();
63+
}
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.apache.groovy.contracts;
20+
21+
/**
22+
* Thrown whenever a loop variant (decreases/increases) violation occurs.
23+
*
24+
* @see AssertionViolation
25+
* @since 6.0.0
26+
*/
27+
public class LoopVariantViolation extends AssertionViolation {
28+
29+
public LoopVariantViolation() {
30+
}
31+
32+
public LoopVariantViolation(Object o) {
33+
super(o);
34+
}
35+
36+
public LoopVariantViolation(boolean b) {
37+
super(b);
38+
}
39+
40+
public LoopVariantViolation(char c) {
41+
super(c);
42+
}
43+
44+
public LoopVariantViolation(int i) {
45+
super(i);
46+
}
47+
48+
public LoopVariantViolation(long l) {
49+
super(l);
50+
}
51+
52+
public LoopVariantViolation(float f) {
53+
super(f);
54+
}
55+
56+
public LoopVariantViolation(double d) {
57+
super(d);
58+
}
59+
}
Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,233 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.apache.groovy.contracts.ast;
20+
21+
import groovy.contracts.Decreases;
22+
import org.apache.groovy.contracts.LoopVariantViolation;
23+
import org.codehaus.groovy.ast.ASTNode;
24+
import org.codehaus.groovy.ast.AnnotationNode;
25+
import org.codehaus.groovy.ast.ClassHelper;
26+
import org.codehaus.groovy.ast.expr.BooleanExpression;
27+
import org.codehaus.groovy.ast.expr.ClosureExpression;
28+
import org.codehaus.groovy.ast.expr.Expression;
29+
import org.codehaus.groovy.ast.stmt.BlockStatement;
30+
import org.codehaus.groovy.ast.stmt.ExpressionStatement;
31+
import org.codehaus.groovy.ast.stmt.LoopingStatement;
32+
import org.codehaus.groovy.ast.stmt.Statement;
33+
import org.codehaus.groovy.control.CompilePhase;
34+
import org.codehaus.groovy.control.SourceUnit;
35+
import org.codehaus.groovy.transform.ASTTransformation;
36+
import org.codehaus.groovy.transform.GroovyASTTransformation;
37+
38+
import java.util.List;
39+
import java.util.Objects;
40+
import java.util.concurrent.atomic.AtomicLong;
41+
42+
import static org.codehaus.groovy.ast.tools.GeneralUtils.args;
43+
import static org.codehaus.groovy.ast.tools.GeneralUtils.assignS;
44+
import static org.codehaus.groovy.ast.tools.GeneralUtils.block;
45+
import static org.codehaus.groovy.ast.tools.GeneralUtils.boolX;
46+
import static org.codehaus.groovy.ast.tools.GeneralUtils.callX;
47+
import static org.codehaus.groovy.ast.tools.GeneralUtils.constX;
48+
import static org.codehaus.groovy.ast.tools.GeneralUtils.ctorX;
49+
import static org.codehaus.groovy.ast.tools.GeneralUtils.declS;
50+
import static org.codehaus.groovy.ast.tools.GeneralUtils.geX;
51+
import static org.codehaus.groovy.ast.tools.GeneralUtils.localVarX;
52+
import static org.codehaus.groovy.ast.tools.GeneralUtils.ltX;
53+
import static org.codehaus.groovy.ast.tools.GeneralUtils.stmt;
54+
import static org.codehaus.groovy.ast.tools.GeneralUtils.throwS;
55+
import static org.codehaus.groovy.ast.tools.GeneralUtils.varX;
56+
57+
/**
58+
* Handles {@link Decreases} annotations placed on loop statements ({@code for},
59+
* {@code while}, {@code do-while}). The closure must return a value that
60+
* strictly decreases on every iteration and remains non-negative.
61+
* <p>
62+
* The transformation injects code to:
63+
* <ol>
64+
* <li>Save the expression value at the start of each iteration.</li>
65+
* <li>Re-evaluate it at the end of the iteration.</li>
66+
* <li>Assert the value has strictly decreased.</li>
67+
* <li>Assert the value is non-negative.</li>
68+
* </ol>
69+
* <p>
70+
* Example:
71+
* <pre>
72+
* int n = 10
73+
* {@code @Decreases}({ n })
74+
* while (n &gt; 0) {
75+
* n--
76+
* }
77+
* </pre>
78+
*
79+
* @since 6.0.0
80+
* @see Decreases
81+
* @see LoopVariantViolation
82+
*/
83+
@GroovyASTTransformation(phase = CompilePhase.SEMANTIC_ANALYSIS)
84+
public class LoopVariantASTTransformation implements ASTTransformation {
85+
86+
private static final AtomicLong COUNTER = new AtomicLong();
87+
88+
@Override
89+
public void visit(final ASTNode[] nodes, final SourceUnit source) {
90+
if (nodes.length != 2) return;
91+
if (!(nodes[0] instanceof AnnotationNode annotation)) return;
92+
if (!(nodes[1] instanceof LoopingStatement loopStatement)) return;
93+
94+
Expression value = annotation.getMember("value");
95+
if (!(value instanceof ClosureExpression closureExpression)) return;
96+
97+
Expression variantExpression = extractExpression(closureExpression);
98+
if (variantExpression == null) return;
99+
100+
String suffix = Long.toString(COUNTER.getAndIncrement());
101+
String prevVarName = "$_gc_decreases_prev_" + suffix;
102+
String currVarName = "$_gc_decreases_curr_" + suffix;
103+
104+
// At start of iteration: def prevVar = <expression>
105+
Statement savePrev = declS(localVarX(prevVarName, ClassHelper.dynamicType()), variantExpression);
106+
savePrev.setSourcePosition(annotation);
107+
108+
// At end of iteration: def currVar = <expression copy>
109+
// We need a fresh copy of the expression for re-evaluation
110+
Expression variantCopy = copyExpression(closureExpression);
111+
Statement saveCurr = declS(localVarX(currVarName, ClassHelper.dynamicType()), variantCopy);
112+
saveCurr.setSourcePosition(annotation);
113+
114+
// Assert: currVar < prevVar (must strictly decrease)
115+
Statement decreaseCheck = stmt(
116+
callX(
117+
ClassHelper.makeWithoutCaching(LoopVariantASTTransformation.class),
118+
"checkDecreased",
119+
args(varX(prevVarName), varX(currVarName))
120+
)
121+
);
122+
decreaseCheck.setSourcePosition(annotation);
123+
124+
// Inject: save at start, check at end
125+
injectAtLoopBodyStartAndEnd(loopStatement, savePrev, block(saveCurr, decreaseCheck));
126+
}
127+
128+
/**
129+
* Runtime check called from generated code. Throws {@link LoopVariantViolation}
130+
* if the variant did not strictly decrease or became negative.
131+
* <p>
132+
* If both values are {@link List}s, they are compared lexicographically:
133+
* the first position where values differ must show a strict decrease;
134+
* all earlier positions must be equal. If all positions are equal, the
135+
* variant has not decreased and a violation is thrown.
136+
*/
137+
public static void checkDecreased(Object prev, Object curr) {
138+
if (prev instanceof List<?> prevList && curr instanceof List<?> currList) {
139+
checkDecreasedLexicographic(prevList, currList);
140+
} else if (prev instanceof Comparable && curr instanceof Comparable) {
141+
checkDecreasedScalar(prev, curr);
142+
} else {
143+
throw new LoopVariantViolation(
144+
"<groovy.contracts.Decreases> loop variant is not Comparable: prev=" + prev + ", curr=" + curr);
145+
}
146+
}
147+
148+
@SuppressWarnings("unchecked")
149+
private static void checkDecreasedScalar(Object prev, Object curr) {
150+
Comparable<Object> prevComp = (Comparable<Object>) prev;
151+
if (prevComp.compareTo(curr) <= 0) {
152+
throw new LoopVariantViolation(
153+
"<groovy.contracts.Decreases> loop variant did not decrease: was " + prev + ", now " + curr);
154+
}
155+
if (curr instanceof Number && ((Number) curr).doubleValue() < 0) {
156+
throw new LoopVariantViolation(
157+
"<groovy.contracts.Decreases> loop variant became negative: " + curr);
158+
}
159+
}
160+
161+
@SuppressWarnings("unchecked")
162+
private static void checkDecreasedLexicographic(List<?> prev, List<?> curr) {
163+
int size = Math.min(prev.size(), curr.size());
164+
for (int i = 0; i < size; i++) {
165+
Object p = prev.get(i);
166+
Object c = curr.get(i);
167+
if (!(p instanceof Comparable) || !(c instanceof Comparable)) {
168+
throw new LoopVariantViolation(
169+
"<groovy.contracts.Decreases> loop variant element at position " + i
170+
+ " is not Comparable: prev=" + p + ", curr=" + c);
171+
}
172+
int cmp = ((Comparable<Object>) p).compareTo(c);
173+
if (cmp > 0) {
174+
// This element decreased — lexicographic comparison satisfied
175+
return;
176+
}
177+
if (cmp < 0) {
178+
throw new LoopVariantViolation(
179+
"<groovy.contracts.Decreases> loop variant increased at position " + i
180+
+ ": was " + prev + ", now " + curr);
181+
}
182+
// cmp == 0: equal at this position, check next
183+
}
184+
// All compared positions are equal — no progress
185+
throw new LoopVariantViolation(
186+
"<groovy.contracts.Decreases> loop variant did not decrease: was " + prev + ", now " + curr);
187+
}
188+
189+
private static Expression extractExpression(ClosureExpression closureExpression) {
190+
BlockStatement block = (BlockStatement) closureExpression.getCode();
191+
List<Statement> statements = block.getStatements();
192+
if (statements.size() != 1) return null;
193+
Statement stmt = statements.get(0);
194+
if (stmt instanceof ExpressionStatement) {
195+
return ((ExpressionStatement) stmt).getExpression();
196+
}
197+
return null;
198+
}
199+
200+
private static Expression copyExpression(ClosureExpression closureExpression) {
201+
// Re-extract from the closure to get a fresh AST node
202+
// (the original is consumed by the first injection point)
203+
BlockStatement block = (BlockStatement) closureExpression.getCode();
204+
List<Statement> statements = block.getStatements();
205+
if (statements.size() != 1) return null;
206+
Statement stmt = statements.get(0);
207+
if (stmt instanceof ExpressionStatement exprStmt) {
208+
// Use transformExpression to get a deep copy
209+
return exprStmt.getExpression().transformExpression(expr -> expr);
210+
}
211+
return null;
212+
}
213+
214+
private static void injectAtLoopBodyStartAndEnd(LoopingStatement loopStatement,
215+
Statement startCheck, Statement endCheck) {
216+
Statement loopBody = loopStatement.getLoopBlock();
217+
BlockStatement newBody;
218+
if (loopBody instanceof BlockStatement block) {
219+
// Prepend save at start
220+
block.getStatements().add(0, startCheck);
221+
// Append checks at end
222+
block.getStatements().addAll(((BlockStatement) endCheck).getStatements());
223+
newBody = block;
224+
} else {
225+
newBody = new BlockStatement();
226+
newBody.addStatement(startCheck);
227+
newBody.addStatement(loopBody);
228+
newBody.addStatements(((BlockStatement) endCheck).getStatements());
229+
newBody.setSourcePosition(loopBody);
230+
loopStatement.setLoopBlock(newBody);
231+
}
232+
}
233+
}

0 commit comments

Comments
 (0)