Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions Ix/Aiur/Compiler/Simple.lean
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,24 @@ Only valid when it doesn't shadow any other binding. Public so proofs in other
modules can cite the definition (e.g., via unfolding). -/
abbrev tmpVar : Local := .idx 0

/-- Build `let pat = v in b`, but first float any leading `let`s out of `v`:
`let pat = (let w = e in rest) in b` ⤳ `let w = e in let pat = rest in b`.

The match compiler hoists a non-variable `match` scrutinee into a `let`
(`MatchCompiler.switch`), so `let x = match foo(bar) {..}` simplifies to
`let x = (let w = foo(bar) in match w {..})`. That buries the `match` one
`let` deep, where `Lower`'s non-tail-match detector (which only fires when a
`match` is the *immediate* `letVar`/`letWild` RHS) can't see it. Floating the
hoisted `let`s outward restores the invariant: the `match` becomes the direct
RHS again. The hoisted `w`s are fresh match-compiler locals, so widening their
scope over `b` cannot capture anything. -/
def mkLetFloating (τ : Typ) (e : Bool) (pat : Pattern) (v b : Term) : Term :=
match v with
| .let τ' e' pat' v' rest => .let τ' e' pat' v' (mkLetFloating τ e pat rest b)
| _ => .let τ e pat v b
termination_by sizeOf v
decreasing_by decreasing_tactic

/-- `simplifyTypedTerm` walks a typed term, producing a term of the same shape
whose `match`es have been pre-compiled down to the decision-tree form, and whose
`let`s bind only simple locals or wildcards. It operates in the `CheckError`
Expand All @@ -38,11 +56,11 @@ def simplifyTypedTerm (decls : Source.Decls) : Term → Except CheckError Term
| .let τ e (.var x) v b => do
let v' ← simplifyTypedTerm decls v
let b' ← simplifyTypedTerm decls b
pure (.let τ e (.var x) v' b')
pure (mkLetFloating τ e (.var x) v' b')
| .let τ e .wildcard v b => do
let v' ← simplifyTypedTerm decls v
let b' ← simplifyTypedTerm decls b
pure (.let τ e .wildcard v' b')
pure (mkLetFloating τ e .wildcard v' b')
| .let τ e pat v b => do
let v' ← simplifyTypedTerm decls v
let b' ← simplifyTypedTerm decls b
Expand All @@ -52,7 +70,7 @@ def simplifyTypedTerm (decls : Source.Decls) : Term → Except CheckError Term
match MatchCompiler.decisionToTyped b'.typ tmp.typ tree with
| some rewrite => rewrite
| none => .match b'.typ b'.escapes tmp [(pat, b')]
pure (.let τ e (.var tmpVar) v' body)
pure (mkLetFloating τ e (.var tmpVar) v' body)
| .match τ e scrut branches => do
let scrut' ← simplifyTypedTerm decls scrut
let branches' ← branches.attach.mapM fun pb =>
Expand Down
17 changes: 14 additions & 3 deletions Tests/Aiur/Aiur.lean
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,15 @@ def toplevel := ⟦
x + 1
}

-- Non-tail match whose scrutinee is a function call (`let x = match foo(bar) {...}`).
-- The scrutinee is hoisted into a fresh let by the match compiler; the
-- continuation must still reach `matchContinue`. ntm_helper(x) = x*x+1,
-- so a=0 -> 1 -> 100 -> 101.
fn ntm_match_on_call(a: G) -> G {
let x = match ntm_helper(a) { 1 => 100, 5 => 200, _ => a * a, };
x + 1
}

-- Non-tail match with store/load in branches (lookup gating)
fn ntm_store_load(a: G) -> G {
let x = match a { 0 => load(store(42)), _ => load(store(a)), };
Expand Down Expand Up @@ -777,8 +786,10 @@ def toplevel := ⟦
let r19 = ntm_recursive_test();
-- Nested early return (yields 0, sum unchanged)
let r20 = ntm_nested(0, 0);
-- Function-call scrutinee: 101 + 201 + 10 = 312
let r21 = ntm_match_on_call(0) + ntm_match_on_call(2) + ntm_match_on_call(3);
r1 + r2 + r3 + r4 + r5 + r6 + r7 + r8 + r9 + r10
+ r11 + r12 + r13 + r14 + r15 + r16 + r17 + r18 + r19 + r20
+ r11 + r12 + r13 + r14 + r15 + r16 + r17 + r18 + r19 + r20 + r21
}

Expand Down Expand Up @@ -935,8 +946,8 @@ def aiurTestCases : List AiurTestCase := [
.noIO `template_pair #[] #[10, 20],
.noIO `template_nested #[] #[7],

-- Non-tail match: all patterns in one proof
.noIO `non_tail_match #[] #[2281],
-- Non-tail match: all patterns in one proof (incl. function-call scrutinee)
.noIO `non_tail_match #[] #[2593],
]

end
13 changes: 12 additions & 1 deletion Tests/Aiur/Cross.lean
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,15 @@ def toplevel : Source.Toplevel := ⟦
x + 1
}

-- Non-tail match whose scrutinee is a function call (`let x = match foo(bar) {...}`).
-- The scrutinee is hoisted into a fresh let by the match compiler; the
-- continuation must still reach `matchContinue`. ntm_helper(x) = x*x+1,
-- so a=0 -> 101, a=2 -> 201, a=3 -> 10.
fn ntm_match_on_call(a: G) -> G {
let x = match ntm_helper(a) { 1 => 100, 5 => 200, _ => a * a, };
x + 1
}

-- Pre-branch constant multiplied in a branch (no default)
pub fn ntm_const_mul(a: G) -> G {
let c = 5;
Expand Down Expand Up @@ -1058,8 +1067,10 @@ def toplevel : Source.Toplevel := ⟦
let r18 = ntm_refutable_let_in_match(0);
let r19 = ntm_recursive_test();
let r20 = ntm_nested(0, 0);
-- Function-call scrutinee: 101 + 201 + 10 = 312
let r21 = ntm_match_on_call(0) + ntm_match_on_call(2) + ntm_match_on_call(3);
r1 + r2 + r3 + r4 + r5 + r6 + r7 + r8 + r9 + r10
+ r11 + r12 + r13 + r14 + r15 + r16 + r17 + r18 + r19 + r20
+ r11 + r12 + r13 + r14 + r15 + r16 + r17 + r18 + r19 + r20 + r21
}

Expand Down