From 5b872eed03ab2ee56852f9b8e3c2f60c0742a1ce Mon Sep 17 00:00:00 2001 From: Nada Amin Date: Wed, 24 Jun 2026 01:02:12 -0400 Subject: [PATCH 1/2] revive contributed example from #77; will need to be brought up to date and fixed on local mac, reason it was deleted in 1bf5f93 --- examples/countBadPairs.dfy | 607 +++++++++++++++++++++++++++++++++ examples/countBadPairs.dfy.gen | 154 +++++++++ examples/countBadPairs.ts | 165 +++++++++ 3 files changed, 926 insertions(+) create mode 100644 examples/countBadPairs.dfy create mode 100644 examples/countBadPairs.dfy.gen create mode 100644 examples/countBadPairs.ts diff --git a/examples/countBadPairs.dfy b/examples/countBadPairs.dfy new file mode 100644 index 0000000..fa66ff3 --- /dev/null +++ b/examples/countBadPairs.dfy @@ -0,0 +1,607 @@ +// Generated by lsc from countBadPairs.ts +// +// Hand-written proof scaffolding for the strong post-condition +// `count == |badPairsImpl(nums)|`. Adapted from +// https://github.com/hath995/Dafny4.4/blob/main/leetcode/CountBadPairs.dfy +// +// Pair-set predicates are defined in terms of the user-visible +// `Pair { i, j }` datatype (gen'd from the TS interface), so the proof +// directly references the same type used in the impl functions. +// Equivalence lemmas bridge `xImpl(nums) == X(nums)` since each impl's +// spec body delegates to the matching hand-written predicate. + +datatype Pair = Pair(i: nat, j: nat) + +// ═════════════════════ Iterative impls (function by method) ═════════════════════ +// Spec body delegates to the matching hand-written predicate, so the +// `by method` block verifies the imperative impl returns that set. + +function allPairsImpl(nums: seq): set +{ + AllPairs(nums) +} +by method { + var result: set := {}; + var i := 0; + while (i < |nums|) + invariant (0 <= i) + invariant (i <= |nums|) + invariant forall p: Pair :: ((p in result) ==> (((p.i < i) && (p.j < |nums|)) && (p.i < p.j))) + invariant forall x: nat, y: nat :: ((x < i) ==> (y < |nums|) ==> (x < y) ==> (Pair(x, y) in result)) + decreases (|nums| - i) + { + var j := (i + 1); + while (j < |nums|) + invariant ((i + 1) <= j) + invariant (j <= |nums|) + invariant forall p: Pair :: ((p in result) ==> ((((p.i < i) && (p.j < |nums|)) && (p.i < p.j)) || (((p.i == i) && (p.i < p.j)) && (p.j < j)))) + invariant forall x: nat, y: nat :: ((x < i) ==> (y < |nums|) ==> (x < y) ==> (Pair(x, y) in result)) + invariant forall y: nat :: ((i < y) ==> (y < j) ==> (Pair(i, y) in result)) + decreases (|nums| - j) + { + result := (result + {Pair(i, j)}); + j := (j + 1); + } + i := (i + 1); + } + return result; +} + +function goodPairsImpl(nums: seq): set +{ + GoodPairs(nums) +} +by method { + var result: set := {}; + var i := 0; + while (i < |nums|) + invariant (0 <= i) + invariant (i <= |nums|) + invariant forall p: Pair :: ((p in result) ==> ((((p.i < i) && (p.j < |nums|)) && (p.i < p.j)) && ((nums[p.j] - nums[p.i]) == (p.j - p.i)))) + invariant forall x: nat, y: nat :: ((x < i) ==> (y < |nums|) ==> (x < y) ==> ((nums[y] - nums[x]) == (y - x)) ==> (Pair(x, y) in result)) + decreases (|nums| - i) + { + var j := (i + 1); + while (j < |nums|) + invariant ((i + 1) <= j) + invariant (j <= |nums|) + invariant forall p: Pair :: ((p in result) ==> (((((p.i < i) && (p.j < |nums|)) && (p.i < p.j)) && ((nums[p.j] - nums[p.i]) == (p.j - p.i))) || ((((p.i == i) && (p.i < p.j)) && (p.j < j)) && ((nums[p.j] - nums[p.i]) == (p.j - p.i))))) + invariant forall x: nat, y: nat :: ((x < i) ==> (y < |nums|) ==> (x < y) ==> ((nums[y] - nums[x]) == (y - x)) ==> (Pair(x, y) in result)) + invariant forall y: nat :: ((i < y) ==> (y < j) ==> ((nums[y] - nums[i]) == (y - i)) ==> (Pair(i, y) in result)) + decreases (|nums| - j) + { + if ((nums[j] - nums[i]) == (j - i)) { + result := (result + {Pair(i, j)}); + } + j := (j + 1); + } + i := (i + 1); + } + return result; +} + +function badPairsImpl(nums: seq): set +{ + BadPairs(nums) +} +by method { + var result: set := {}; + var i := 0; + while (i < |nums|) + invariant (0 <= i) + invariant (i <= |nums|) + invariant forall p: Pair :: ((p in result) ==> ((((p.i < i) && (p.j < |nums|)) && (p.i < p.j)) && ((nums[p.j] - nums[p.i]) != (p.j - p.i)))) + invariant forall x: nat, y: nat :: ((x < i) ==> (y < |nums|) ==> (x < y) ==> ((nums[y] - nums[x]) != (y - x)) ==> (Pair(x, y) in result)) + decreases (|nums| - i) + { + var j := (i + 1); + while (j < |nums|) + invariant ((i + 1) <= j) + invariant (j <= |nums|) + invariant forall p: Pair :: ((p in result) ==> (((((p.i < i) && (p.j < |nums|)) && (p.i < p.j)) && ((nums[p.j] - nums[p.i]) != (p.j - p.i))) || ((((p.i == i) && (p.i < p.j)) && (p.j < j)) && ((nums[p.j] - nums[p.i]) != (p.j - p.i))))) + invariant forall x: nat, y: nat :: ((x < i) ==> (y < |nums|) ==> (x < y) ==> ((nums[y] - nums[x]) != (y - x)) ==> (Pair(x, y) in result)) + invariant forall y: nat :: ((i < y) ==> (y < j) ==> ((nums[y] - nums[i]) != (y - i)) ==> (Pair(i, y) in result)) + decreases (|nums| - j) + { + if ((nums[j] - nums[i]) != (j - i)) { + result := (result + {Pair(i, j)}); + } + j := (j + 1); + } + i := (i + 1); + } + return result; +} + +// ═════════════════════ Pair-set predicates ═════════════════════ + +function AllPairs(nums: seq): set +{ + set i: nat, j: nat | i < |nums| && j < |nums| && i < j :: Pair(i, j) +} + +function IncrementPairs(pairs: set): set +{ + set p | p in pairs :: Pair(p.i + 1, p.j + 1) +} + +function ZeroPairs(length: nat): set +{ + set y: nat | 1 <= y <= length :: Pair(0, y) +} + +function GoodPairs(nums: seq): set +{ + set i: nat, j: nat | i < |nums| && j < |nums| && i < j && j - i == nums[j] - nums[i] :: Pair(i, j) +} + +function GoodPairsI(nums: seq, k: nat): set + requires k <= |nums| +{ + set x: nat, y: nat | x < k && y < |nums| && x < y && y - x == nums[y] - nums[x] :: Pair(x, y) +} + +function GoodPairsIK(nums: seq, idx: nat, k: nat): set + requires idx < |nums| + requires idx < k <= |nums| +{ + set y: nat | y < k && idx < y && y - idx == nums[y] - nums[idx] :: Pair(idx, y) +} + +function GoodPairsII(nums: seq, k: nat): set + requires k <= |nums| +{ + set x: nat, y: nat | x < k && y < k && x < y && y - x == nums[y] - nums[x] :: Pair(x, y) +} + +function BadPairs(nums: seq): set +{ + set i: nat, j: nat | i < |nums| && j < |nums| && i < j && j - i != nums[j] - nums[i] :: Pair(i, j) +} + +function DiffsSet(nums: seq): set +{ + set x: nat | x < |nums| :: nums[x] - x +} + +function IndicesCoset(nums: seq, diff: int): set + ensures forall i :: i in IndicesCoset(nums, diff) ==> 0 <= i < |nums| +{ + set i: nat | i < |nums| && nums[i] - i == diff :: i +} + +function CosetToPairInPlusOne(coset: set, nums: seq, i: nat): set + requires forall x :: x in coset ==> x < i + ensures forall p :: p in CosetToPairInPlusOne(coset, nums, i) ==> p.i < p.j +{ + set x: nat | x in coset && x < i :: Pair(x, i) +} + +// ═════════════════════ Impl ≡ predicate (trivial by spec-body delegation) ═════════════════════ + +lemma AllPairsImplEqAllPairs(nums: seq) + ensures allPairsImpl(nums) == AllPairs(nums) +{ } + +lemma GoodPairsImplEqGoodPairs(nums: seq) + ensures goodPairsImpl(nums) == GoodPairs(nums) +{ } + +lemma BadPairsImplEqBadPairs(nums: seq) + ensures badPairsImpl(nums) == BadPairs(nums) +{ } + +// ═════════════════════ Supporting lemmas ═════════════════════ + +lemma IncrementPairsSize(pairs: set) + ensures |IncrementPairs(pairs)| == |pairs| +{ + if pairs == {} { + } else { + var x :| x in pairs; + IncrementPairsSize(pairs - {x}); + assert IncrementPairs(pairs) == IncrementPairs(pairs - {x}) + {Pair(x.i + 1, x.j + 1)}; + } +} + +lemma ZeroPairsSize(length: nat) + ensures |ZeroPairs(length)| == length +{ + if length == 0 { + } else if length == 1 { + assert ZeroPairs(length) == {Pair(0, 1)}; + } else { + ZeroPairsSize(length - 1); + assert ZeroPairs(length) == ZeroPairs(length - 1) + {Pair(0, length)}; + } +} + +lemma AllPairsEqual(nums: seq) + requires |nums| > 0 + ensures AllPairs(nums) == ZeroPairs(|nums| - 1) + IncrementPairs(AllPairs(nums[1..])) +{ + if |nums| == 1 { + } else { + forall p | p in AllPairs(nums) + ensures p in ZeroPairs(|nums| - 1) || p in IncrementPairs(AllPairs(nums[1..])) + { + if p.i == 0 { + } else { + assert Pair(p.i - 1, p.j - 1) in AllPairs(nums[1..]); + } + } + } +} + +lemma SetSizes(s1: set, s2: set) + requires s1 <= s2 + ensures |s1| <= |s2| +{ + if s1 != {} { + var x :| x in s1; + SetSizes(s1 - {x}, s2 - {x}); + } +} + +lemma {:vcs_split_on_every_assert} AllPairsSize(nums: seq) + ensures |AllPairs(nums)| == |nums| * (|nums| - 1) / 2 +{ + if |nums| <= 1 { + assert |AllPairs(nums)| == 0; + } else { + assert nums == [nums[0]] + nums[1..]; + AllPairsSize(nums[1..]); + AllPairsEqual(nums); + IncrementPairsSize(AllPairs(nums[1..])); + ZeroPairsSize(|nums| - 1); + assert ZeroPairs(|nums| - 1) !! IncrementPairs(AllPairs(nums[1..])); + calc { + |AllPairs(nums)|; + |ZeroPairs(|nums| - 1)| + |IncrementPairs(AllPairs(nums[1..]))|; + |nums| - 1 + |IncrementPairs(AllPairs(nums[1..]))|; + |nums| - 1 + (|nums| - 1) * (|nums| - 2) / 2; + |nums| - 1 + (|nums| * |nums| - 3 * |nums| + 2) / 2; + 2 * (|nums| - 1) / 2 + (|nums| * |nums| - 3 * |nums| + 2) / 2; + (2 * |nums| - 2) / 2 + (|nums| * |nums| - 3 * |nums| + 2) / 2; + (2 * |nums| - 2 + |nums| * |nums| - 3 * |nums| + 2) / 2; + (|nums| * |nums| - 1 * |nums|) / 2; + |nums| * (|nums| - 1) / 2; + } + } +} + +lemma GoodPairsIkNextPos(nums: seq, idx: nat, k: nat) + requires idx < |nums| + requires idx < k < |nums| + requires nums[k] - nums[idx] == k - idx + ensures GoodPairsIK(nums, idx, k + 1) == GoodPairsIK(nums, idx, k) + {Pair(idx, k)} +{ } + +lemma GoodPairsIkNextNeg(nums: seq, idx: nat, k: nat) + requires idx < |nums| + requires idx < k < |nums| + requires nums[k] - nums[idx] != k - idx + ensures GoodPairsIK(nums, idx, k + 1) == GoodPairsIK(nums, idx, k) +{ } + +lemma GoodPairsInext(nums: seq, idx: nat) + requires idx < |nums| + ensures GoodPairsI(nums, idx + 1) == GoodPairsI(nums, idx) + GoodPairsIK(nums, idx, |nums|) +{ } + +lemma GoodPairsIIEqGoodPairsAtN(nums: seq) + ensures GoodPairsII(nums, |nums|) == GoodPairs(nums) +{ } + +lemma BadPairsEqualsAllMinusGood(nums: seq) + requires |nums| > 0 + ensures BadPairs(nums) == AllPairs(nums) - GoodPairs(nums) +{ } + +lemma BadPairsSize(nums: seq) + requires |nums| > 0 + ensures |BadPairs(nums)| == |AllPairs(nums)| - |GoodPairs(nums)| +{ + BadPairsEqualsAllMinusGood(nums); +} + +lemma GoodPairsLessThanAll(nums: seq) + requires |nums| > 0 + ensures GoodPairs(nums) <= AllPairs(nums) +{ } + +lemma GoodPairsSize(nums: seq) + requires |nums| > 0 + ensures |GoodPairs(nums)| <= |AllPairs(nums)| +{ + GoodPairsLessThanAll(nums); + SetSizes(GoodPairs(nums), AllPairs(nums)); +} + +// ─── Diff-bucket scaffolding ─── + +lemma nextIndices(nums: seq, diff: int, i: nat) + requires i < |nums| + requires diff == nums[i] - i + requires diff !in DiffsSet(nums[0..i]) + ensures forall xdiff :: xdiff in DiffsSet(nums[0..i]) ==> + IndicesCoset(nums[0..i + 1], xdiff) == IndicesCoset(nums[0..i], xdiff) +{ + forall xdiff | xdiff in DiffsSet(nums[0..i]) + ensures IndicesCoset(nums[0..i + 1], xdiff) == IndicesCoset(nums[0..i], xdiff) + { } +} + +lemma nextIndicesIn(nums: seq, diff: int, i: nat) + requires i < |nums| + requires diff == nums[i] - i + requires diff in DiffsSet(nums[0..i]) + ensures forall xdiff :: xdiff in DiffsSet(nums[0..i]) && xdiff != diff ==> + IndicesCoset(nums[0..i + 1], xdiff) == IndicesCoset(nums[0..i], xdiff) +{ + forall xdiff | xdiff in DiffsSet(nums[0..i]) && xdiff != diff + ensures IndicesCoset(nums[0..i + 1], xdiff) == IndicesCoset(nums[0..i], xdiff) + { } +} + +lemma NotIndices(nums: seq, diff: int, i: nat) + requires i < |nums| + requires diff == nums[i] - i + requires diff !in DiffsSet(nums[0..i]) + ensures IndicesCoset(nums[0..i], diff) == {} +{ } + +lemma IndicesCosetElementsLessThanI(nums: seq, i: nat, diff: int) + requires i <= |nums| + ensures forall x :: x in IndicesCoset(nums[0..i], diff) ==> 0 <= x < i <= |nums| +{ } + +lemma CosetToPairPlusOne(coset: set, nums: seq, i: nat) + requires forall x :: x in coset ==> x < i < |nums| + ensures |CosetToPairInPlusOne(coset, nums, i)| == |coset| +{ + if coset != {} { + var x :| x in coset; + CosetToPairPlusOne(coset - {x}, nums, i); + assert Pair(x, i) !in CosetToPairInPlusOne(coset - {x}, nums, i); + assert CosetToPairInPlusOne(coset, nums, i) == + CosetToPairInPlusOne(coset - {x}, nums, i) + {Pair(x, i)}; + } +} + +lemma DiffMapKeysPos(nums: seq, diffMap: map, i: nat) + requires |nums| > 0 + requires i < |nums| + requires diffMap.Keys == DiffsSet(nums[0..i]) + requires nums[i] - i in diffMap + ensures diffMap[(nums[i] - i) := diffMap[nums[i] - i] + 1].Keys == DiffsSet(nums[0..i + 1]) +{ + var k :| k in diffMap.Keys && k == nums[i] - i; + assert k in DiffsSet(nums[0..i]); +} + +lemma DiffMapKeysNeg(nums: seq, diffMap: map, i: nat) + requires |nums| > 0 + requires i < |nums| + requires diffMap.Keys == DiffsSet(nums[0..i]) + requires nums[i] - i !in diffMap + ensures diffMap[(nums[i] - i) := 1].Keys == DiffsSet(nums[0..i + 1]) +{ } + +lemma IndicesCosetsContinuedNeg(nums: seq, i: nat, + diffCosets: map>, diffMap: map, diff: int) + requires i < |nums| + requires diff == nums[i] - i + requires forall d :: d in DiffsSet(nums[0..i]) ==> d in diffMap + requires forall d :: d in DiffsSet(nums[0..i]) ==> + diffMap[d] == |IndicesCoset(nums[0..i], d)| + requires forall d :: d in DiffsSet(nums[0..i]) ==> d in diffCosets + requires forall d :: d in DiffsSet(nums[0..i]) ==> + diffCosets[d] == IndicesCoset(nums[0..i], d) + requires diff !in diffMap + ensures forall ldiff {:trigger ldiff in DiffsSet(nums[0..i + 1])} :: + ldiff in DiffsSet(nums[0..i + 1]) ==> ldiff in diffCosets[diff := {i}] + ensures forall ldiff {:trigger IndicesCoset(nums[0..i + 1], ldiff)} :: + ldiff in DiffsSet(nums[0..i + 1]) ==> + diffCosets[diff := {i}][ldiff] == IndicesCoset(nums[0..i + 1], ldiff) + ensures forall ldiff {:trigger ldiff in diffMap[diff := 1]} :: + ldiff in DiffsSet(nums[0..i + 1]) ==> ldiff in diffMap[diff := 1] + ensures forall ldiff {:trigger diffMap[diff := 1][ldiff]} :: + ldiff in DiffsSet(nums[0..i + 1]) ==> + diffMap[diff := 1][ldiff] == |IndicesCoset(nums[0..i + 1], ldiff)| +{ + nextIndices(nums, diff, i); +} + +lemma IndicesCosetsContinuedPos(nums: seq, i: nat, + diffCosets: map>, diffMap: map, diff: int) + requires i < |nums| + requires diff == nums[i] - i + requires diff in DiffsSet(nums[..i]) + requires forall d :: d in DiffsSet(nums[0..i]) ==> d in diffMap + requires forall d :: d in DiffsSet(nums[0..i]) ==> + diffMap[d] == |IndicesCoset(nums[0..i], d)| + requires forall d :: d in DiffsSet(nums[0..i]) ==> d in diffCosets + requires forall d :: d in DiffsSet(nums[0..i]) ==> + diffCosets[d] == IndicesCoset(nums[0..i], d) + requires diff in diffMap + requires diff in diffCosets + ensures forall ldiff {:trigger ldiff in DiffsSet(nums[0..i + 1])} :: + ldiff in DiffsSet(nums[0..i + 1]) ==> + ldiff in diffCosets[diff := diffCosets[diff] + {i}] + ensures forall ldiff {:trigger IndicesCoset(nums[0..i + 1], ldiff)} :: + ldiff in DiffsSet(nums[0..i + 1]) ==> + diffCosets[diff := diffCosets[diff] + {i}][ldiff] == IndicesCoset(nums[0..i + 1], ldiff) + ensures forall ldiff {:trigger ldiff in diffMap[diff := diffMap[diff] + 1]} :: + ldiff in DiffsSet(nums[0..i + 1]) ==> + ldiff in diffMap[diff := diffMap[diff] + 1] + ensures forall ldiff {:trigger diffMap[diff := diffMap[diff] + 1][ldiff]} :: + ldiff in DiffsSet(nums[0..i + 1]) ==> + diffMap[diff := diffMap[diff] + 1][ldiff] == |IndicesCoset(nums[0..i + 1], ldiff)| +{ + nextIndicesIn(nums, diff, i); + assert IndicesCoset(nums[0..i + 1], diff) == IndicesCoset(nums[0..i], diff) + {i}; + forall ldiff {:trigger IndicesCoset(nums[0..i + 1], ldiff)} | ldiff in DiffsSet(nums[0..i + 1]) + ensures diffCosets[diff := diffCosets[diff] + {i}][ldiff] == + IndicesCoset(nums[0..i + 1], ldiff) + { + if ldiff == diff { + assert diffCosets[diff] + {i} == IndicesCoset(nums[0..i], diff) + {i}; + } + } +} + +lemma goodPairsIINegContinued(nums: seq, i: nat, diffMap: map, + diffCosets: map>, diff: int, goodPairs: set) + requires i < |nums| + requires diff == nums[i] - i + requires diff !in DiffsSet(nums[0..i]) + requires diff !in diffMap + requires diff !in diffCosets + ensures GoodPairsII(nums, i) == GoodPairsII(nums, i + 1) +{ } + +lemma GoodPairsIIPosContinued(nums: seq, i: nat, goodCount: int, + diffMap: map, diffCosets: map>, diff: int, + goodPairs: set) + requires i < |nums| + requires diff == nums[i] - i + requires diff in DiffsSet(nums[0..i]) + requires diff in diffMap + requires diff in diffCosets + requires DiffsSet(nums[0..i]) == diffMap.Keys + requires diffMap.Keys == diffCosets.Keys + requires forall d :: d in DiffsSet(nums[0..i]) ==> + diffCosets[d] == IndicesCoset(nums[0..i], d) + requires forall d :: d in DiffsSet(nums[0..i]) ==> + diffMap[d] == |IndicesCoset(nums[0..i], d)| + requires goodCount == |GoodPairsII(nums, i)| + requires forall x :: x in diffCosets[diff] ==> x < i + ensures GoodPairsII(nums, i + 1) == GoodPairsII(nums, i) + + CosetToPairInPlusOne(diffCosets[diff], nums, i) + ensures goodCount + diffMap[diff] == |GoodPairsII(nums, i + 1)| +{ + CosetToPairPlusOne(diffCosets[diff], nums, i); + assert diffMap[diff] == |CosetToPairInPlusOne(diffCosets[diff], nums, i)|; +} + +// ═════════════════════ Generated methods (hand-annotated) ═════════════════════ + +method countBadPairsNaive(nums: seq) returns (res: int) + requires (|nums| > 0) + ensures (res >= 0) + ensures (res == |badPairsImpl(nums)|) +{ + var count := 0; + var n := |nums|; + ghost var pairsI: set := {}; + var i := 0; + while (i < (n - 1)) + invariant (0 <= i) + invariant (i <= (n - 1)) + invariant (count >= 0) + invariant pairsI == GoodPairsI(nums, i) + invariant count == |pairsI| + decreases ((n - 1) - i) + { + var iprime := (nums[i] - i); + ghost var oldCount := count; + ghost var pairsIK: set := {}; + var k := (i + 1); + while (k < n) + invariant ((i + 1) <= k) + invariant (k <= n) + invariant (count >= 0) + invariant pairsIK == GoodPairsIK(nums, i, k) + invariant count == oldCount + |pairsIK| + decreases (n - k) + { + if ((nums[k] - k) == iprime) { + GoodPairsIkNextPos(nums, i, k); + pairsIK := pairsIK + {Pair(i, k)}; + count := (count + 1); + } + if ((nums[k] - k) != iprime) { + GoodPairsIkNextNeg(nums, i, k); + } + k := (k + 1); + } + pairsI := pairsI + pairsIK; + GoodPairsInext(nums, i); + i := (i + 1); + } + assert i == n - 1; + assert GoodPairsIK(nums, n - 1, n) == {}; + GoodPairsInext(nums, n - 1); + assert GoodPairsI(nums, n) == GoodPairsI(nums, n - 1); + GoodPairsIIEqGoodPairsAtN(nums); + assert GoodPairs(nums) == GoodPairsI(nums, n); + BadPairsSize(nums); + GoodPairsSize(nums); + AllPairsSize(nums); + BadPairsImplEqBadPairs(nums); + var pairs := ((((n - 1) * n) / 2) - count); + return pairs; +} + +method countBadPairs(nums: seq) returns (res: int) + requires (|nums| > 0) + ensures (res >= 0) + ensures (res == |badPairsImpl(nums)|) +{ + var goodCount := 0; + var n := |nums|; + var diffMap: map := map[]; + ghost var goodPairs: set := {}; + ghost var diffCosets: map> := map[]; + var i := 0; + while (i < n) + invariant (0 <= i) + invariant (i <= n) + invariant (goodCount >= 0) + invariant DiffsSet(nums[0..i]) == diffMap.Keys + invariant diffCosets.Keys == diffMap.Keys + invariant forall d :: d in DiffsSet(nums[0..i]) ==> + diffCosets[d] == IndicesCoset(nums[0..i], d) + invariant forall d :: d in DiffsSet(nums[0..i]) ==> + diffMap[d] == |IndicesCoset(nums[0..i], d)| + invariant goodPairs == GoodPairsII(nums, i) + invariant goodCount == |goodPairs| + decreases (n - i) + { + var diff := (nums[i] - i); + ghost var diffInOld := diff in diffMap; + ghost var oldDiffCosets := diffCosets; + ghost var oldGoodPairs := goodPairs; + if diffInOld { + DiffMapKeysPos(nums, diffMap, i); + IndicesCosetsContinuedPos(nums, i, diffCosets, diffMap, diff); + IndicesCosetElementsLessThanI(nums, i, diff); + GoodPairsIIPosContinued(nums, i, goodCount, diffMap, diffCosets, diff, goodPairs); + } else { + DiffMapKeysNeg(nums, diffMap, i); + IndicesCosetsContinuedNeg(nums, i, diffCosets, diffMap, diff); + goodPairsIINegContinued(nums, i, diffMap, diffCosets, diff, goodPairs); + } + ghost var npair: set := + if diffInOld then CosetToPairInPlusOne(oldDiffCosets[diff], nums, i) else {}; + var count := (if (diff in diffMap) then (var i_oc0_val := diffMap[diff]; i_oc0_val) else 0); + goodCount := (goodCount + count); + diffMap := diffMap[diff := (count + 1)]; + if diffInOld { + diffCosets := oldDiffCosets[diff := oldDiffCosets[diff] + {i}]; + } else { + diffCosets := oldDiffCosets[diff := {i}]; + } + goodPairs := oldGoodPairs + npair; + i := (i + 1); + } + assert nums[0..n] == nums; + GoodPairsIIEqGoodPairsAtN(nums); + assert goodCount == |GoodPairs(nums)|; + BadPairsSize(nums); + GoodPairsSize(nums); + AllPairsSize(nums); + BadPairsImplEqBadPairs(nums); + return ((((n - 1) * n) / 2) - goodCount); +} diff --git a/examples/countBadPairs.dfy.gen b/examples/countBadPairs.dfy.gen new file mode 100644 index 0000000..515e6dc --- /dev/null +++ b/examples/countBadPairs.dfy.gen @@ -0,0 +1,154 @@ +// Generated by lsc from countBadPairs.ts + +datatype Pair = Pair(i: nat, j: nat) + +function allPairsImpl(nums: seq): set +{ +} +by method { + var result: set := {}; + var i := 0; + while (i < |nums|) + invariant (0 <= i) + invariant (i <= |nums|) + invariant forall p: Pair :: ((p in result) ==> (((p.i < i) && (p.j < |nums|)) && (p.i < p.j))) + invariant forall x: nat, y: nat :: ((x < i) ==> (y < |nums|) ==> (x < y) ==> (Pair(x, y) in result)) + decreases (|nums| - i) + { + var j := (i + 1); + while (j < |nums|) + invariant ((i + 1) <= j) + invariant (j <= |nums|) + invariant forall p: Pair :: ((p in result) ==> ((((p.i < i) && (p.j < |nums|)) && (p.i < p.j)) || (((p.i == i) && (p.i < p.j)) && (p.j < j)))) + invariant forall x: nat, y: nat :: ((x < i) ==> (y < |nums|) ==> (x < y) ==> (Pair(x, y) in result)) + invariant forall y: nat :: ((i < y) ==> (y < j) ==> (Pair(i, y) in result)) + decreases (|nums| - j) + { + result := (result + {Pair(i, j)}); + j := (j + 1); + } + i := (i + 1); + } + return result; +} + +function goodPairsImpl(nums: seq): set +{ +} +by method { + var result: set := {}; + var i := 0; + while (i < |nums|) + invariant (0 <= i) + invariant (i <= |nums|) + invariant forall p: Pair :: ((p in result) ==> ((((p.i < i) && (p.j < |nums|)) && (p.i < p.j)) && ((nums[p.j] - nums[p.i]) == (p.j - p.i)))) + invariant forall x: nat, y: nat :: ((x < i) ==> (y < |nums|) ==> (x < y) ==> ((nums[y] - nums[x]) == (y - x)) ==> (Pair(x, y) in result)) + decreases (|nums| - i) + { + var j := (i + 1); + while (j < |nums|) + invariant ((i + 1) <= j) + invariant (j <= |nums|) + invariant forall p: Pair :: ((p in result) ==> (((((p.i < i) && (p.j < |nums|)) && (p.i < p.j)) && ((nums[p.j] - nums[p.i]) == (p.j - p.i))) || ((((p.i == i) && (p.i < p.j)) && (p.j < j)) && ((nums[p.j] - nums[p.i]) == (p.j - p.i))))) + invariant forall x: nat, y: nat :: ((x < i) ==> (y < |nums|) ==> (x < y) ==> ((nums[y] - nums[x]) == (y - x)) ==> (Pair(x, y) in result)) + invariant forall y: nat :: ((i < y) ==> (y < j) ==> ((nums[y] - nums[i]) == (y - i)) ==> (Pair(i, y) in result)) + decreases (|nums| - j) + { + if ((nums[j] - nums[i]) == (j - i)) { + result := (result + {Pair(i, j)}); + } + j := (j + 1); + } + i := (i + 1); + } + return result; +} + +function badPairsImpl(nums: seq): set +{ +} +by method { + var result: set := {}; + var i := 0; + while (i < |nums|) + invariant (0 <= i) + invariant (i <= |nums|) + invariant forall p: Pair :: ((p in result) ==> ((((p.i < i) && (p.j < |nums|)) && (p.i < p.j)) && ((nums[p.j] - nums[p.i]) != (p.j - p.i)))) + invariant forall x: nat, y: nat :: ((x < i) ==> (y < |nums|) ==> (x < y) ==> ((nums[y] - nums[x]) != (y - x)) ==> (Pair(x, y) in result)) + decreases (|nums| - i) + { + var j := (i + 1); + while (j < |nums|) + invariant ((i + 1) <= j) + invariant (j <= |nums|) + invariant forall p: Pair :: ((p in result) ==> (((((p.i < i) && (p.j < |nums|)) && (p.i < p.j)) && ((nums[p.j] - nums[p.i]) != (p.j - p.i))) || ((((p.i == i) && (p.i < p.j)) && (p.j < j)) && ((nums[p.j] - nums[p.i]) != (p.j - p.i))))) + invariant forall x: nat, y: nat :: ((x < i) ==> (y < |nums|) ==> (x < y) ==> ((nums[y] - nums[x]) != (y - x)) ==> (Pair(x, y) in result)) + invariant forall y: nat :: ((i < y) ==> (y < j) ==> ((nums[y] - nums[i]) != (y - i)) ==> (Pair(i, y) in result)) + decreases (|nums| - j) + { + if ((nums[j] - nums[i]) != (j - i)) { + result := (result + {Pair(i, j)}); + } + j := (j + 1); + } + i := (i + 1); + } + return result; +} + +method countBadPairsNaive(nums: seq) returns (res: int) + requires (|nums| > 0) + ensures (res >= 0) + ensures (res == |badPairsImpl(nums)|) +{ + var count := 0; + var n := |nums|; + var i := 0; + while (i < (n - 1)) + invariant (0 <= i) + invariant (i <= (n - 1)) + invariant (count >= 0) + decreases ((n - 1) - i) + { + var iprime := (nums[i] - i); + var k := (i + 1); + while (k < n) + invariant ((i + 1) <= k) + invariant (k <= n) + invariant (count >= 0) + decreases (n - k) + { + if ((nums[k] - k) == iprime) { + count := (count + 1); + } + k := (k + 1); + } + i := (i + 1); + } + var pairs := ((((n - 1) * n) / 2) - count); + return pairs; +} + +method countBadPairs(nums: seq) returns (res: int) + requires (|nums| > 0) + ensures (res >= 0) + ensures (res == |badPairsImpl(nums)|) +{ + var goodCount := 0; + var n := |nums|; + var diffMap: map := map[]; + var i := 0; + while (i < n) + invariant (0 <= i) + invariant (i <= n) + invariant (goodCount >= 0) + decreases (n - i) + { + var diff := (nums[i] - i); + var count := (if (diff in diffMap) then (var i_oc0_val := diffMap[diff]; i_oc0_val) else 0); + goodCount := (goodCount + count); + diffMap := diffMap[diff := (count + 1)]; + i := (i + 1); + } + return ((((n - 1) * n) / 2) - goodCount); +} diff --git a/examples/countBadPairs.ts b/examples/countBadPairs.ts new file mode 100644 index 0000000..1c1ede4 --- /dev/null +++ b/examples/countBadPairs.ts @@ -0,0 +1,165 @@ +//@ backend dafny + +/** + * Count Bad Pairs (LeetCode 2364). + * + * A "good" pair (i, j) with i < j satisfies j - i == nums[j] - nums[i], + * equivalently nums[j] - j == nums[i] - i. A bad pair is any other + * ordered pair with i < j. There are n*(n-1)/2 ordered pairs total, + * so badPairs = total - goodPairs. + * + * The file contains: + * 1. Pair — record type for an (i, j) pair of nat indices. + * 2. allPairsImpl / goodPairsImpl / badPairsImpl — iterative + * reference specifications that build the sets explicitly. + * Marked //@ pure so they can appear in `ensures` clauses; + * LemmaScript emits each as a Dafny `function by method`, + * with the spec body filled in (in countBadPairs.dfy) as the + * matching set comprehension. Dafny verifies the imperative + * body produces the comprehension. + * 3. countBadPairsNaive (O(n^2)) and countBadPairs (O(n) via diffMap), + * both with `\result === badPairsImpl(nums).size` postcondition. + * The deep proof connecting the counter to |BadPairs(nums)| + * lives in countBadPairs.dfy. + */ + +interface Pair { + i: number //@ type nat + j: number //@ type nat +} + +//@ pure +export function allPairsImpl(nums: number[]): Set { + //@ type i nat + //@ type j nat + let result = new Set(); + let i = 0; + while (i < nums.length) { + //@ invariant 0 <= i && i <= nums.length + //@ invariant forall(p: Pair, p in result ==> p.i < i && p.j < nums.length && p.i < p.j) + //@ invariant forall(x: nat, forall(y: nat, x < i && y < nums.length && x < y ==> {i: x, j: y} in result)) + //@ decreases nums.length - i + let j = i + 1; + while (j < nums.length) { + //@ invariant i + 1 <= j && j <= nums.length + //@ invariant forall(p: Pair, p in result ==> (p.i < i && p.j < nums.length && p.i < p.j) || (p.i === i && p.i < p.j && p.j < j)) + //@ invariant forall(x: nat, forall(y: nat, x < i && y < nums.length && x < y ==> {i: x, j: y} in result)) + //@ invariant forall(y: nat, i < y && y < j ==> {i: i, j: y} in result) + //@ decreases nums.length - j + result.add({ i: i, j: j }); + j = j + 1; + } + i = i + 1; + } + return result; +} + +//@ pure +export function goodPairsImpl(nums: number[]): Set { + //@ type i nat + //@ type j nat + let result = new Set(); + let i = 0; + while (i < nums.length) { + //@ invariant 0 <= i && i <= nums.length + //@ invariant forall(p: Pair, p in result ==> p.i < i && p.j < nums.length && p.i < p.j && nums[p.j] - nums[p.i] === p.j - p.i) + //@ invariant forall(x: nat, forall(y: nat, x < i && y < nums.length && x < y && nums[y] - nums[x] === y - x ==> {i: x, j: y} in result)) + //@ decreases nums.length - i + let j = i + 1; + while (j < nums.length) { + //@ invariant i + 1 <= j && j <= nums.length + //@ invariant forall(p: Pair, p in result ==> (p.i < i && p.j < nums.length && p.i < p.j && nums[p.j] - nums[p.i] === p.j - p.i) || (p.i === i && p.i < p.j && p.j < j && nums[p.j] - nums[p.i] === p.j - p.i)) + //@ invariant forall(x: nat, forall(y: nat, x < i && y < nums.length && x < y && nums[y] - nums[x] === y - x ==> {i: x, j: y} in result)) + //@ invariant forall(y: nat, i < y && y < j && nums[y] - nums[i] === y - i ==> {i: i, j: y} in result) + //@ decreases nums.length - j + if (nums[j] - nums[i] === j - i) { + result.add({ i: i, j: j }); + } + j = j + 1; + } + i = i + 1; + } + return result; +} + +//@ pure +export function badPairsImpl(nums: number[]): Set { + //@ type i nat + //@ type j nat + let result = new Set(); + let i = 0; + while (i < nums.length) { + //@ invariant 0 <= i && i <= nums.length + //@ invariant forall(p: Pair, p in result ==> p.i < i && p.j < nums.length && p.i < p.j && nums[p.j] - nums[p.i] !== p.j - p.i) + //@ invariant forall(x: nat, forall(y: nat, x < i && y < nums.length && x < y && nums[y] - nums[x] !== y - x ==> {i: x, j: y} in result)) + //@ decreases nums.length - i + let j = i + 1; + while (j < nums.length) { + //@ invariant i + 1 <= j && j <= nums.length + //@ invariant forall(p: Pair, p in result ==> (p.i < i && p.j < nums.length && p.i < p.j && nums[p.j] - nums[p.i] !== p.j - p.i) || (p.i === i && p.i < p.j && p.j < j && nums[p.j] - nums[p.i] !== p.j - p.i)) + //@ invariant forall(x: nat, forall(y: nat, x < i && y < nums.length && x < y && nums[y] - nums[x] !== y - x ==> {i: x, j: y} in result)) + //@ invariant forall(y: nat, i < y && y < j && nums[y] - nums[i] !== y - i ==> {i: i, j: y} in result) + //@ decreases nums.length - j + if (nums[j] - nums[i] !== j - i) { + result.add({ i: i, j: j }); + } + j = j + 1; + } + i = i + 1; + } + return result; +} + +export function countBadPairsNaive(nums: number[]): number { + //@ requires nums.length > 0 + //@ ensures \result >= 0 + //@ ensures \result === badPairsImpl(nums).size + //@ type i nat + //@ type k nat + let count = 0; + const n = nums.length; + + let i = 0; + while (i < n - 1) { + //@ invariant 0 <= i && i <= n - 1 + //@ invariant count >= 0 + //@ decreases (n - 1) - i + const iprime = nums[i] - i; + let k = i + 1; + while (k < n) { + //@ invariant i + 1 <= k && k <= n + //@ invariant count >= 0 + //@ decreases n - k + if (nums[k] - k === iprime) { + count = count + 1; + } + k = k + 1; + } + i = i + 1; + } + + const pairs = (n - 1) * n / 2 - count; + return pairs; +} + +export function countBadPairs(nums: number[]): number { + //@ requires nums.length > 0 + //@ ensures \result >= 0 + //@ ensures \result === badPairsImpl(nums).size + //@ type i nat + let goodCount = 0; + const n = nums.length; + const diffMap: Map = new Map(); + let i = 0; + while (i < n) { + //@ invariant 0 <= i && i <= n + //@ invariant goodCount >= 0 + //@ decreases n - i + const diff = nums[i] - i; + const count = diffMap.get(diff) ?? 0; + goodCount = goodCount + count; + diffMap.set(diff, count + 1); + i = i + 1; + } + return ((n - 1) * n / 2) - goodCount; +} From 96aa15487ac6fb1bdae2400ad651e402cf1d6a73 Mon Sep 17 00:00:00 2001 From: Nada Amin Date: Wed, 24 Jun 2026 01:02:49 -0400 Subject: [PATCH 2/2] working version --- examples/countBadPairs.dfy | 51 ++++++++++++++++++++++------------ examples/countBadPairs.dfy.gen | 15 ++++++++-- examples/countBadPairs.ts | 4 +-- 3 files changed, 49 insertions(+), 21 deletions(-) diff --git a/examples/countBadPairs.dfy b/examples/countBadPairs.dfy index fa66ff3..2cee796 100644 --- a/examples/countBadPairs.dfy +++ b/examples/countBadPairs.dfy @@ -10,6 +10,19 @@ // Equivalence lemmas bridge `xImpl(nums) == X(nums)` since each impl's // spec body delegates to the matching hand-written predicate. +import opened Std.Arithmetic.Mul + +function JSFloorDiv(a: int, b: int): int + requires b != 0 +{ + if b > 0 then + if a >= 0 then a / b + else -((-a - 1) / b) - 1 + else + if a <= 0 then (-a) / (-b) + else -((a - 1) / (-b)) - 1 +} + datatype Pair = Pair(i: nat, j: nat) // ═════════════════════ Iterative impls (function by method) ═════════════════════ @@ -243,33 +256,37 @@ lemma SetSizes(s1: set, s2: set) } } -lemma {:vcs_split_on_every_assert} AllPairsSize(nums: seq) - ensures |AllPairs(nums)| == |nums| * (|nums| - 1) / 2 +// Doubled form: a degree-2 polynomial identity (no division), proven by induction; +// AllPairsSize divides once at the end. Avoids nonlinear division reasoning. +lemma AllPairsSizeTimes2(nums: seq) + ensures 2 * |AllPairs(nums)| == |nums| * (|nums| - 1) { if |nums| <= 1 { assert |AllPairs(nums)| == 0; } else { + var n := |nums|; assert nums == [nums[0]] + nums[1..]; - AllPairsSize(nums[1..]); + AllPairsSizeTimes2(nums[1..]); AllPairsEqual(nums); IncrementPairsSize(AllPairs(nums[1..])); ZeroPairsSize(|nums| - 1); assert ZeroPairs(|nums| - 1) !! IncrementPairs(AllPairs(nums[1..])); - calc { - |AllPairs(nums)|; - |ZeroPairs(|nums| - 1)| + |IncrementPairs(AllPairs(nums[1..]))|; - |nums| - 1 + |IncrementPairs(AllPairs(nums[1..]))|; - |nums| - 1 + (|nums| - 1) * (|nums| - 2) / 2; - |nums| - 1 + (|nums| * |nums| - 3 * |nums| + 2) / 2; - 2 * (|nums| - 1) / 2 + (|nums| * |nums| - 3 * |nums| + 2) / 2; - (2 * |nums| - 2) / 2 + (|nums| * |nums| - 3 * |nums| + 2) / 2; - (2 * |nums| - 2 + |nums| * |nums| - 3 * |nums| + 2) / 2; - (|nums| * |nums| - 1 * |nums|) / 2; - |nums| * (|nums| - 1) / 2; - } + assert |AllPairs(nums)| == (n - 1) + |AllPairs(nums[1..])|; + // 2(n-1) + (n-1)(n-2) == n(n-1) — distributivity/commutativity via Std (Z3's + // own nonlinear search times out on this degree-2 identity). + LemmaMulIsDistributiveAdd(n - 1, 2, n - 2); + LemmaMulIsCommutative(n - 1, 2); + LemmaMulIsCommutative(n - 1, n); + assert 2 + (n - 2) == n; } } +lemma AllPairsSize(nums: seq) + ensures |AllPairs(nums)| == |nums| * (|nums| - 1) / 2 +{ + AllPairsSizeTimes2(nums); +} + lemma GoodPairsIkNextPos(nums: seq, idx: nat, k: nat) requires idx < |nums| requires idx < k < |nums| @@ -540,7 +557,7 @@ method countBadPairsNaive(nums: seq) returns (res: int) GoodPairsSize(nums); AllPairsSize(nums); BadPairsImplEqBadPairs(nums); - var pairs := ((((n - 1) * n) / 2) - count); + var pairs := (JSFloorDiv(((n - 1) * n), 2) - count); return pairs; } @@ -603,5 +620,5 @@ method countBadPairs(nums: seq) returns (res: int) GoodPairsSize(nums); AllPairsSize(nums); BadPairsImplEqBadPairs(nums); - return ((((n - 1) * n) / 2) - goodCount); + return (JSFloorDiv(((n - 1) * n), 2) - goodCount); } diff --git a/examples/countBadPairs.dfy.gen b/examples/countBadPairs.dfy.gen index 515e6dc..92e5abd 100644 --- a/examples/countBadPairs.dfy.gen +++ b/examples/countBadPairs.dfy.gen @@ -1,5 +1,16 @@ // Generated by lsc from countBadPairs.ts +function JSFloorDiv(a: int, b: int): int + requires b != 0 +{ + if b > 0 then + if a >= 0 then a / b + else -((-a - 1) / b) - 1 + else + if a <= 0 then (-a) / (-b) + else -((a - 1) / (-b)) - 1 +} + datatype Pair = Pair(i: nat, j: nat) function allPairsImpl(nums: seq): set @@ -125,7 +136,7 @@ method countBadPairsNaive(nums: seq) returns (res: int) } i := (i + 1); } - var pairs := ((((n - 1) * n) / 2) - count); + var pairs := (JSFloorDiv(((n - 1) * n), 2) - count); return pairs; } @@ -150,5 +161,5 @@ method countBadPairs(nums: seq) returns (res: int) diffMap := diffMap[diff := (count + 1)]; i := (i + 1); } - return ((((n - 1) * n) / 2) - goodCount); + return (JSFloorDiv(((n - 1) * n), 2) - goodCount); } diff --git a/examples/countBadPairs.ts b/examples/countBadPairs.ts index 1c1ede4..eecc023 100644 --- a/examples/countBadPairs.ts +++ b/examples/countBadPairs.ts @@ -138,7 +138,7 @@ export function countBadPairsNaive(nums: number[]): number { i = i + 1; } - const pairs = (n - 1) * n / 2 - count; + const pairs = Math.floor((n - 1) * n / 2) - count; return pairs; } @@ -161,5 +161,5 @@ export function countBadPairs(nums: number[]): number { diffMap.set(diff, count + 1); i = i + 1; } - return ((n - 1) * n / 2) - goodCount; + return Math.floor((n - 1) * n / 2) - goodCount; }