Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,71 @@ private static String qualifiedMethodName(Symbol.MethodSymbol sym) {

private class StaticMethodCollector extends TreeScanner {
final List<Procedure> procedures = new ArrayList<>();
private int labelCounter = 0;

/** Label stack entry for break/continue resolution. */
private record LabelEntry(String javaLabel, String breakLabel, String continueLabel) {}
private final Deque<LabelEntry> labelStack = new ArrayDeque<>();
private String pendingLabel = null;

/** Check if a statement tree contains break or continue (at the current loop level). */
private boolean containsBreakOrContinue(JCTree.JCStatement stmt) {
boolean[] found = {false};
stmt.accept(new TreeScanner() {
@Override public void visitBreak(JCTree.JCBreak tree) { found[0] = true; }
@Override public void visitContinue(JCTree.JCContinue tree) { found[0] = true; }
// Don't descend into nested loops — their break/continue don't target us
@Override public void visitWhileLoop(JCTree.JCWhileLoop tree) {}
@Override public void visitForLoop(JCTree.JCForLoop tree) {}
@Override public void visitDoLoop(JCTree.JCDoWhileLoop tree) {}
});
return found[0];
}

private String freshLabel(String prefix) {
return prefix + "_" + (labelCounter++);
}

/** Set up loop labels, push to stack, execute body, pop stack. */
private StmtExpr withLoopLabels(JCTree.JCStatement loopBody, java.util.function.BiFunction<String, String, StmtExpr> body) {
boolean needsLabels = containsBreakOrContinue(loopBody);
String breakLbl = needsLabels ? freshLabel("loop_break") : null;
String continueLbl = needsLabels ? freshLabel("loop_continue") : null;
String javaLabel = pendingLabel;
pendingLabel = null;
labelStack.push(new LabelEntry(javaLabel, breakLbl, continueLbl));
try {
return body.apply(breakLbl, continueLbl);
} finally {
labelStack.pop();
}
}

private String resolveBreakLabel(JCTree.JCBreak brk) {
if (brk.label != null) {
String target = brk.label.toString();
for (var entry : labelStack) {
if (target.equals(entry.javaLabel)) return entry.breakLabel;
}
} else if (!labelStack.isEmpty()) {
return labelStack.peek().breakLabel;
}
throw new JavaViolationException("break target not found");
}

private String resolveContinueLabel(JCTree.JCContinue cont) {
if (cont.label != null) {
String target = cont.label.toString();
for (var entry : labelStack) {
if (target.equals(entry.javaLabel) && entry.continueLabel != null) return entry.continueLabel;
}
} else {
for (var entry : labelStack) {
if (entry.continueLabel != null) return entry.continueLabel;
}
}
throw new JavaViolationException("continue target not found");
}

@Override
public void visitMethodDef(JCTree.JCMethodDecl method) {
Expand Down Expand Up @@ -233,6 +298,64 @@ private StmtExpr convertBlock(JCTree.JCBlock blk) {
return block(toSourceRange(blk), statements);
}

/**
* Emit a while loop with optional break/continue label wrapping.
*
* When labels are present, produces (for a for-loop with break/continue):
* <pre>
* labelledBlock(breakLbl, {
* preamble... // e.g. init for for-loops, sentinel decl for do-while
* while (cond)
* invariants: [...]
* {
* labelledBlock(continueLbl, { body });
* step; // null for while/do-while
* }
* })
* </pre>
* break emits as exit(breakLbl), continue as exit(continueLbl).
* The step is outside the continue label so continue skips the body but still runs the step.
* @param sr source range
* @param cond loop condition
* @param invariants loop invariants
* @param loopBody the translated loop body
* @param step optional step expression (for-loops); null for while/do-while
* @param preamble statements to emit before the while (e.g. init, sentinel decl)
* @param needsLabels whether break/continue is present
* @param breakLbl break label (null if !needsLabels)
* @param continueLbl continue label (null if !needsLabels)
*/
private StmtExpr emitLoop(SourceRange sr, StmtExpr cond, List<InvariantClause> invariants,
StmtExpr loopBody, StmtExpr step, List<StmtExpr> preamble,
String breakLbl, String continueLbl) {
if (breakLbl != null) {
StmtExpr wrappedBody = labelledBlock(sr, List.of(loopBody), continueLbl);
StmtExpr whileBody;
if (step != null) {
whileBody = block(sr, List.of(wrappedBody, step));
} else {
whileBody = wrappedBody;
}
StmtExpr whileNode = while_(sr, cond, invariants, whileBody);
List<StmtExpr> outerStmts = new ArrayList<>(preamble);
outerStmts.add(whileNode);
return labelledBlock(sr, outerStmts, breakLbl);
} else {
if (step != null) {
// Use forLoop IR node for simple for-loops without break/continue
StmtExpr init = preamble.isEmpty() ? block(sr, List.of()) : preamble.getFirst();
return forLoop(sr, init, cond, step, invariants, loopBody);
}
StmtExpr whileNode = while_(sr, cond, invariants, loopBody);
if (preamble.isEmpty()) {
return whileNode;
}
List<StmtExpr> stmts = new ArrayList<>(preamble);
stmts.add(whileNode);
return block(sr, stmts);
}
}

private StmtExpr convertStatement(JCTree.JCStatement statement) {
return switch (statement) {
case JCTree.JCAssert assertStmt ->
Expand Down Expand Up @@ -266,34 +389,77 @@ yield varDecl(toSourceRange(varDecl), varDecl.name.toString(),
yield null;
}
case JCTree.JCWhileLoop whileStmt -> {
StmtExpr cond = convertExpression(whileStmt.cond);
if (whileStmt.body instanceof JCTree.JCBlock loopBlock) {
var parts = extractLoopParts(loopBlock);
yield while_(toSourceRange(whileStmt), cond, parts.invariants, parts.body);
}
yield while_(toSourceRange(whileStmt), cond, List.of(), convertStatement(whileStmt.body));
var sr = toSourceRange(whileStmt);
yield withLoopLabels(whileStmt.body, (breakLbl, continueLbl) -> {
StmtExpr cond = convertExpression(whileStmt.cond);
var parts = extractLoopParts(whileStmt.body);
return emitLoop(sr, cond, parts.invariants, parts.body, null, List.of(),
breakLbl, continueLbl);
});
}
case JCTree.JCForLoop forLoop -> {
if (forLoop.init.size() > 1 || forLoop.step.size() > 1)
throw new JavaViolationException("Multi-init or multi-step for loops are not supported");
StmtExpr init = forLoop.init.isEmpty() ? block(toSourceRange(forLoop), List.of())
: convertStatement(forLoop.init.getFirst());
StmtExpr cond = forLoop.cond != null
? convertExpression(forLoop.cond)
: literalBool(toSourceRange(forLoop), true);
StmtExpr step = forLoop.step.isEmpty() ? block(toSourceRange(forLoop), List.of())
: convertStatement(forLoop.step.getFirst());
List<InvariantClause> invariants = List.of();
StmtExpr loopBody;
if (forLoop.body instanceof JCTree.JCBlock loopBlock) {
var parts = extractLoopParts(loopBlock);
invariants = parts.invariants;
loopBody = parts.body;
var sr = toSourceRange(forLoop);
yield withLoopLabels(forLoop.body, (breakLbl, continueLbl) -> {
StmtExpr init = forLoop.init.isEmpty() ? block(sr, List.of())
: convertStatement(forLoop.init.getFirst());
StmtExpr cond = forLoop.cond != null
? convertExpression(forLoop.cond)
: literalBool(sr, true);
StmtExpr step = forLoop.step.isEmpty() ? block(sr, List.of())
: convertStatement(forLoop.step.getFirst());
var parts = extractLoopParts(forLoop.body);
return emitLoop(sr, cond, parts.invariants, parts.body, step, List.of(init),
breakLbl, continueLbl);
});
}
case JCTree.JCDoWhileLoop doWhileStmt -> {
var sr = toSourceRange(doWhileStmt);
yield withLoopLabels(doWhileStmt.body, (breakLbl, continueLbl) -> {
// Desugar: var __first_N = true; while(__first_N || cond) { __first_N = false; body }
String sentinel = "__first_" + labelCounter;
StmtExpr sentinelDecl = varDecl(sr, sentinel,
Optional.of(typeAnnotation(boolType())),
Optional.of(initializer(literalBool(sr, true))));
StmtExpr cond = orElse(sr, identifier(sr, sentinel), convertExpression(doWhileStmt.cond));
StmtExpr sentinelReset = assign(sr, identifier(sr, sentinel), literalBool(sr, false));

var parts = extractLoopParts(doWhileStmt.body);
StmtExpr loopBody = block(sr, List.of(sentinelReset, parts.body));
return emitLoop(sr, cond, parts.invariants, loopBody, null, List.of(sentinelDecl),
breakLbl, continueLbl);
});
}
case JCTree.JCBreak breakStmt -> {
yield exit(toSourceRange(breakStmt), resolveBreakLabel(breakStmt));
}
case JCTree.JCContinue contStmt -> {
yield exit(toSourceRange(contStmt), resolveContinueLabel(contStmt));
}
case JCTree.JCLabeledStatement labeledStmt -> {
// Set the pending label so the next loop picks it up
pendingLabel = labeledStmt.label.toString();
// If the body is a loop, it will consume pendingLabel
// If not, wrap in a labelledBlock for break-out-of-block
if (labeledStmt.body instanceof JCTree.JCWhileLoop
|| labeledStmt.body instanceof JCTree.JCForLoop
|| labeledStmt.body instanceof JCTree.JCDoWhileLoop) {
yield convertStatement(labeledStmt.body);
} else {
loopBody = convertStatement(forLoop.body);
// Non-loop labeled statement: only break is valid (no continue)
String breakLbl = freshLabel("label_break");
labelStack.push(new LabelEntry(pendingLabel, breakLbl, null));
pendingLabel = null;
try {
StmtExpr body = convertStatement(labeledStmt.body);
yield labelledBlock(toSourceRange(labeledStmt), List.of(body), breakLbl);
} finally {
labelStack.pop();
}
}
yield forLoop(toSourceRange(forLoop), init, cond, step, invariants, loopBody);
}
case JCTree.JCSkip ignored -> null;
default -> throw new JavaViolationException("Unsupported statement: " + statement.getClass().getSimpleName());
};
}
Expand All @@ -316,19 +482,23 @@ private StmtExpr convertLambdaBody(JCTree.JCLambda lambda, Map<String, String> r

private record LoopParts(List<InvariantClause> invariants, StmtExpr body) {}

private LoopParts extractLoopParts(JCTree.JCBlock loopBlock) {
MethodOrLoopContract loopContract = contractCompiler.getContract(loopBlock);
List<InvariantClause> invariants = new ArrayList<>();
for (var inv : loopContract.loopInvariants()) {
invariants.add(invariantClause(convertExpression(inv.get())));
}
var implStatements = MethodOrLoopContractCompiler.getImplementationStatements(loopBlock);
List<StmtExpr> stmts = new ArrayList<>();
for (var s : implStatements) {
StmtExpr converted = convertStatement(s);
if (converted != null) stmts.add(converted);
private LoopParts extractLoopParts(JCTree.JCStatement body) {
if (body instanceof JCTree.JCBlock loopBlock) {
MethodOrLoopContract loopContract = contractCompiler.getContract(loopBlock);
List<InvariantClause> invariants = new ArrayList<>();
for (var inv : loopContract.loopInvariants()) {
invariants.add(invariantClause(convertExpression(inv.get())));
}
var implStatements = MethodOrLoopContractCompiler.getImplementationStatements(loopBlock);
List<StmtExpr> stmts = new ArrayList<>();
for (var s : implStatements) {
StmtExpr converted = convertStatement(s);
if (converted != null) stmts.add(converted);
}
return new LoopParts(invariants, block(toSourceRange(loopBlock), stmts));
} else {
return new LoopParts(List.of(), convertStatement(body));
}
return new LoopParts(invariants, block(toSourceRange(loopBlock), stmts));
}

private StmtExpr convertExpression(JCTree.JCExpression expr, Map<String, String> renames) {
Expand Down
Loading
Loading