diff --git a/Ix/Aiur/Compiler/Simple.lean b/Ix/Aiur/Compiler/Simple.lean index cac488c8..c2cf00be 100644 --- a/Ix/Aiur/Compiler/Simple.lean +++ b/Ix/Aiur/Compiler/Simple.lean @@ -27,6 +27,24 @@ Only valid when it doesn't shadow any other binding. Public so proofs in other modules can cite the definition (e.g., via unfolding). -/ abbrev tmpVar : Local := .idx 0 +/-- Build `let pat = v in b`, but first float any leading `let`s out of `v`: +`let pat = (let w = e in rest) in b` ⤳ `let w = e in let pat = rest in b`. + +The match compiler hoists a non-variable `match` scrutinee into a `let` +(`MatchCompiler.switch`), so `let x = match foo(bar) {..}` simplifies to +`let x = (let w = foo(bar) in match w {..})`. That buries the `match` one +`let` deep, where `Lower`'s non-tail-match detector (which only fires when a +`match` is the *immediate* `letVar`/`letWild` RHS) can't see it. Floating the +hoisted `let`s outward restores the invariant: the `match` becomes the direct +RHS again. The hoisted `w`s are fresh match-compiler locals, so widening their +scope over `b` cannot capture anything. -/ +def mkLetFloating (τ : Typ) (e : Bool) (pat : Pattern) (v b : Term) : Term := + match v with + | .let τ' e' pat' v' rest => .let τ' e' pat' v' (mkLetFloating τ e pat rest b) + | _ => .let τ e pat v b +termination_by sizeOf v +decreasing_by decreasing_tactic + /-- `simplifyTypedTerm` walks a typed term, producing a term of the same shape whose `match`es have been pre-compiled down to the decision-tree form, and whose `let`s bind only simple locals or wildcards. It operates in the `CheckError` @@ -38,11 +56,11 @@ def simplifyTypedTerm (decls : Source.Decls) : Term → Except CheckError Term | .let τ e (.var x) v b => do let v' ← simplifyTypedTerm decls v let b' ← simplifyTypedTerm decls b - pure (.let τ e (.var x) v' b') + pure (mkLetFloating τ e (.var x) v' b') | .let τ e .wildcard v b => do let v' ← simplifyTypedTerm decls v let b' ← simplifyTypedTerm decls b - pure (.let τ e .wildcard v' b') + pure (mkLetFloating τ e .wildcard v' b') | .let τ e pat v b => do let v' ← simplifyTypedTerm decls v let b' ← simplifyTypedTerm decls b @@ -52,7 +70,7 @@ def simplifyTypedTerm (decls : Source.Decls) : Term → Except CheckError Term match MatchCompiler.decisionToTyped b'.typ tmp.typ tree with | some rewrite => rewrite | none => .match b'.typ b'.escapes tmp [(pat, b')] - pure (.let τ e (.var tmpVar) v' body) + pure (mkLetFloating τ e (.var tmpVar) v' body) | .match τ e scrut branches => do let scrut' ← simplifyTypedTerm decls scrut let branches' ← branches.attach.mapM fun pb => diff --git a/Tests/Aiur/Aiur.lean b/Tests/Aiur/Aiur.lean index c7f39915..e735cd12 100644 --- a/Tests/Aiur/Aiur.lean +++ b/Tests/Aiur/Aiur.lean @@ -552,6 +552,15 @@ def toplevel := ⟦ x + 1 } + -- Non-tail match whose scrutinee is a function call (`let x = match foo(bar) {...}`). + -- The scrutinee is hoisted into a fresh let by the match compiler; the + -- continuation must still reach `matchContinue`. ntm_helper(x) = x*x+1, + -- so a=0 -> 1 -> 100 -> 101. + fn ntm_match_on_call(a: G) -> G { + let x = match ntm_helper(a) { 1 => 100, 5 => 200, _ => a * a, }; + x + 1 + } + -- Non-tail match with store/load in branches (lookup gating) fn ntm_store_load(a: G) -> G { let x = match a { 0 => load(store(42)), _ => load(store(a)), }; @@ -777,8 +786,10 @@ def toplevel := ⟦ let r19 = ntm_recursive_test(); -- Nested early return (yields 0, sum unchanged) let r20 = ntm_nested(0, 0); + -- Function-call scrutinee: 101 + 201 + 10 = 312 + let r21 = ntm_match_on_call(0) + ntm_match_on_call(2) + ntm_match_on_call(3); r1 + r2 + r3 + r4 + r5 + r6 + r7 + r8 + r9 + r10 - + r11 + r12 + r13 + r14 + r15 + r16 + r17 + r18 + r19 + r20 + + r11 + r12 + r13 + r14 + r15 + r16 + r17 + r18 + r19 + r20 + r21 } ⟧ @@ -935,8 +946,8 @@ def aiurTestCases : List AiurTestCase := [ .noIO `template_pair #[] #[10, 20], .noIO `template_nested #[] #[7], - -- Non-tail match: all patterns in one proof - .noIO `non_tail_match #[] #[2281], + -- Non-tail match: all patterns in one proof (incl. function-call scrutinee) + .noIO `non_tail_match #[] #[2593], ] end diff --git a/Tests/Aiur/Cross.lean b/Tests/Aiur/Cross.lean index fab13912..13488057 100644 --- a/Tests/Aiur/Cross.lean +++ b/Tests/Aiur/Cross.lean @@ -541,6 +541,15 @@ def toplevel : Source.Toplevel := ⟦ x + 1 } + -- Non-tail match whose scrutinee is a function call (`let x = match foo(bar) {...}`). + -- The scrutinee is hoisted into a fresh let by the match compiler; the + -- continuation must still reach `matchContinue`. ntm_helper(x) = x*x+1, + -- so a=0 -> 101, a=2 -> 201, a=3 -> 10. + fn ntm_match_on_call(a: G) -> G { + let x = match ntm_helper(a) { 1 => 100, 5 => 200, _ => a * a, }; + x + 1 + } + -- Pre-branch constant multiplied in a branch (no default) pub fn ntm_const_mul(a: G) -> G { let c = 5; @@ -1058,8 +1067,10 @@ def toplevel : Source.Toplevel := ⟦ let r18 = ntm_refutable_let_in_match(0); let r19 = ntm_recursive_test(); let r20 = ntm_nested(0, 0); + -- Function-call scrutinee: 101 + 201 + 10 = 312 + let r21 = ntm_match_on_call(0) + ntm_match_on_call(2) + ntm_match_on_call(3); r1 + r2 + r3 + r4 + r5 + r6 + r7 + r8 + r9 + r10 - + r11 + r12 + r13 + r14 + r15 + r16 + r17 + r18 + r19 + r20 + + r11 + r12 + r13 + r14 + r15 + r16 + r17 + r18 + r19 + r20 + r21 } ⟧