Skip to content

Commit

Permalink
feat: getElem lemmas for Vector operations (#6324)
Browse files Browse the repository at this point in the history
This PR adds `GetElem` lemmas for the basic `Vector` operations.

The `Vector` API is still very sparse, but I'm hoping to infill rapidly.
  • Loading branch information
kim-em authored Dec 6, 2024
1 parent 019f8e1 commit 6e60d13
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 8 deletions.
36 changes: 28 additions & 8 deletions src/Init/Data/Array/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -785,11 +785,26 @@ theorem getElem_set (a : Array α) (i : Nat) (h' : i < a.size) (v : α) (j : Nat
else
simp [setIfInBounds, h]

theorem getElem_setIfInBounds (a : Array α) (i : Nat) (v : α) (j : Nat)
(hj : j < (setIfInBounds a i v).size) :
(setIfInBounds a i v)[j]'hj = if i = j then v else a[j]'(by simpa using hj) := by
simp only [setIfInBounds]
split
· simp [getElem_set]
· simp only [size_setIfInBounds] at hj
rw [if_neg]
omega

@[simp] theorem getElem_setIfInBounds_eq (a : Array α) {i : Nat} (v : α) (h : _) :
(setIfInBounds a i v)[i]'h = v := by
simp at h
simp only [setIfInBounds, h, ↓reduceDIte, getElem_set_eq]

@[simp] theorem getElem_setIfInBounds_ne (a : Array α) {i : Nat} (v : α) {j : Nat}
(hj : j < (setIfInBounds a i v).size) (h : i ≠ j) :
(setIfInBounds a i v)[j]'hj = a[j]'(by simpa using hj) := by
simp [getElem_setIfInBounds, h]

@[simp]
theorem getElem?_setIfInBounds_eq (a : Array α) {i : Nat} (p : i < a.size) (v : α) :
(a.setIfInBounds i v)[i]? = some v := by
Expand Down Expand Up @@ -991,11 +1006,6 @@ theorem get_set (a : Array α) (i : Nat) (hi : i < a.size) (j : Nat) (hj : j < a
(h : i ≠ j) : (a.set i v)[j]'(by simp [*]) = a[j] := by
simp only [set, getElem_eq_getElem_toList, List.getElem_set_ne h]

theorem getElem_setIfInBounds (a : Array α) (i : Nat) (v : α) (h : i < (setIfInBounds a i v).size) :
(setIfInBounds a i v)[i] = v := by
simp at h
simp only [setIfInBounds, h, ↓reduceDIte, getElem_set_eq]

theorem set_set (a : Array α) (i : Nat) (h) (v v' : α) :
(a.set i v h).set i v' (by simp [h]) = a.set i v' := by simp [set, List.set_set]

Expand Down Expand Up @@ -1861,8 +1871,6 @@ instance [DecidableEq α] (a : α) (as : Array α) : Decidable (a ∈ as) :=

/-! ### swap -/

open Fin

@[simp] theorem getElem_swap_right (a : Array α) {i j : Nat} {hi hj} :
(a.swap i j hi hj)[j]'(by simpa using hj) = a[i] := by
simp [swap_def, getElem_set]
Expand All @@ -1881,7 +1889,7 @@ theorem getElem_swap' (a : Array α) (i j : Nat) {hi hj} (k : Nat) (hk : k < a.s
· simp_all only [getElem_swap_left]
· split <;> simp_all

theorem getElem_swap (a : Array α) (i j : Nat) {hi hj}(k : Nat) (hk : k < (a.swap i j).size) :
theorem getElem_swap (a : Array α) (i j : Nat) {hi hj} (k : Nat) (hk : k < (a.swap i j).size) :
(a.swap i j hi hj)[k] = if k = i then a[j] else if k = j then a[i] else a[k]'(by simp_all) := by
apply getElem_swap'

Expand Down Expand Up @@ -1944,6 +1952,13 @@ theorem eraseIdx_eq_eraseIdxIfInBounds {a : Array α} {i : Nat} (h : i < a.size)
(as.zip bs).size = min as.size bs.size :=
as.size_zipWith bs Prod.mk

@[simp] theorem getElem_zipWith (as : Array α) (bs : Array β) (f : α → β → γ) (i : Nat)
(hi : i < (as.zipWith bs f).size) :
(as.zipWith bs f)[i] = f (as[i]'(by simp at hi; omega)) (bs[i]'(by simp at hi; omega)) := by
cases as
cases bs
simp

/-! ### findSomeM?, findM?, findSome?, find? -/

@[simp] theorem findSomeM?_toList [Monad m] [LawfulMonad m] (p : α → m (Option β)) (as : Array α) :
Expand Down Expand Up @@ -2244,6 +2259,11 @@ theorem foldr_map' (g : α → β) (f : α → α → α) (f' : β → β → β
cases as
simp

@[simp] theorem getElem_reverse (as : Array α) (i : Nat) (hi : i < as.reverse.size) :
(as.reverse)[i] = as[as.size - 1 - i]'(by simp at hi; omega) := by
cases as
simp [Array.getElem_reverse]

/-! ### findSomeRevM?, findRevM?, findSomeRev?, findRev? -/

@[simp] theorem findSomeRevM?_eq_findSomeM?_reverse
Expand Down
132 changes: 132 additions & 0 deletions src/Init/Data/Vector/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ theorem toArray_mk (a : Array α) (h : a.size = n) : (Vector.mk a h).toArray = a
(Vector.mk a h).eraseIdx! i = Vector.mk (a.eraseIdx i) (by simp [h, hi]) := by
simp [Vector.eraseIdx!, hi]

@[simp] theorem cast_mk (a : Array α) (h : a.size = n) (h' : n = m) :
(Vector.mk a h).cast h' = Vector.mk a (by simp [h, h']) := rfl

@[simp] theorem extract_mk (a : Array α) (h : a.size = n) (start stop) :
(Vector.mk a h).extract start stop = Vector.mk (a.extract start stop) (by simp [h]) := rfl

Expand Down Expand Up @@ -194,6 +197,9 @@ theorem toArray_mk (a : Array α) (h : a.size = n) : (Vector.mk a h).toArray = a
(a.eraseIdx! i).toArray = a.toArray.eraseIdx! i := by
cases a; simp_all [Array.eraseIdx!]

@[simp] theorem toArray_cast (a : Vector α n) (h : n = m) :
(a.cast h).toArray = a.toArray := rfl

@[simp] theorem toArray_extract (a : Vector α n) (start stop) :
(a.extract start stop).toArray = a.toArray.extract start stop := rfl

Expand Down Expand Up @@ -253,6 +259,132 @@ theorem toList_inj {a b : Vector α n} (h : a.toList = b.toList) : a = b := by
rcases b with ⟨⟨b⟩, hb⟩
simpa using h

/-! ### set -/

theorem getElem_set (a : Vector α n) (i : Nat) (x : α) (hi : i < n) (j : Nat) (hj : j < n) :
(a.set i x hi)[j] = if i = j then x else a[j] := by
cases a
split <;> simp_all [Array.getElem_set]

@[simp] theorem getElem_set_eq (a : Vector α n) (i : Nat) (x : α) (hi : i < n) :
(a.set i x hi)[i] = x := by simp [getElem_set]

@[simp] theorem getElem_set_ne (a : Vector α n) (i : Nat) (x : α) (hi : i < n) (j : Nat)
(hj : j < n) (h : i ≠ j) : (a.set i x hi)[j] = a[j] := by simp [getElem_set, h]

/-! ### setIfInBounds -/

theorem getElem_setIfInBounds (a : Vector α n) (i : Nat) (x : α) (j : Nat)
(hj : j < n) : (a.setIfInBounds i x)[j] = if i = j then x else a[j] := by
cases a
split <;> simp_all [Array.getElem_setIfInBounds]

@[simp] theorem getElem_setIfInBounds_eq (a : Vector α n) (i : Nat) (x : α) (hj : i < n) :
(a.setIfInBounds i x)[i] = x := by simp [getElem_setIfInBounds]

@[simp] theorem getElem_setIfInBounds_ne (a : Vector α n) (i : Nat) (x : α) (j : Nat)
(hj : j < n) (h : i ≠ j) : (a.setIfInBounds i x)[j] = a[j] := by simp [getElem_setIfInBounds, h]

/-! ### append -/

theorem getElem_append (a : Vector α n) (b : Vector α m) (i : Nat) (hi : i < n + m) :
(a ++ b)[i] = if h : i < n then a[i] else b[i - n] := by
rcases a with ⟨a, rfl⟩
rcases b with ⟨b, rfl⟩
simp [Array.getElem_append, hi]

theorem getElem_append_left {a : Vector α n} {b : Vector α m} {i : Nat} (hi : i < n) :
(a ++ b)[i] = a[i] := by simp [getElem_append, hi]

theorem getElem_append_right {a : Vector α n} {b : Vector α m} {i : Nat} (h : i < n + m) (hi : n ≤ i) :
(a ++ b)[i] = b[i - n] := by
rw [getElem_append, dif_neg (by omega)]

/-! ### cast -/

@[simp] theorem getElem_cast (a : Vector α n) (h : n = m) (i : Nat) (hi : i < m) :
(a.cast h)[i] = a[i] := by
cases a
simp

/-! ### extract -/

@[simp] theorem getElem_extract (a : Vector α n) (start stop) (i : Nat) (hi : i < min stop n - start) :
(a.extract start stop)[i] = a[start + i] := by
cases a
simp

/-! ### map -/

@[simp] theorem getElem_map (f : α → β) (a : Vector α n) (i : Nat) (hi : i < n) :
(a.map f)[i] = f a[i] := by
cases a
simp

/-! ### zipWith -/

@[simp] theorem getElem_zipWith (f : α → β → γ) (a : Vector α n) (b : Vector β n) (i : Nat)
(hi : i < n) : (zipWith a b f)[i] = f a[i] b[i] := by
cases a
cases b
simp

/-! ### swap -/

theorem getElem_swap (a : Vector α n) (i j : Nat) {hi hj} (k : Nat) (hk : k < n) :
(a.swap i j hi hj)[k] = if k = i then a[j] else if k = j then a[i] else a[k] := by
cases a
simp_all [Array.getElem_swap]

@[simp] theorem getElem_swap_right (a : Vector α n) {i j : Nat} {hi hj} :
(a.swap i j hi hj)[j]'(by simpa using hj) = a[i] := by
simp +contextual [getElem_swap]

@[simp] theorem getElem_swap_left (a : Vector α n) {i j : Nat} {hi hj} :
(a.swap i j hi hj)[i]'(by simpa using hi) = a[j] := by
simp [getElem_swap]

@[simp] theorem getElem_swap_of_ne (a : Vector α n) {i j : Nat} {hi hj} (hp : p < n)
(hi' : p ≠ i) (hj' : p ≠ j) : (a.swap i j hi hj)[p] = a[p] := by
simp_all [getElem_swap]

@[simp] theorem swap_swap (a : Vector α n) {i j : Nat} {hi hj} :
(a.swap i j hi hj).swap i j hi hj = a := by
cases a
simp_all [Array.swap_swap]

theorem swap_comm (a : Vector α n) {i j : Nat} {hi hj} :
a.swap i j hi hj = a.swap j i hj hi := by
cases a
simp only [swap_mk, mk.injEq]
rw [Array.swap_comm]

/-! ### range -/

@[simp] theorem getElem_range (i : Nat) (hi : i < n) : (Vector.range n)[i] = i := by
simp [Vector.range]

/-! ### take -/

@[simp] theorem getElem_take (a : Vector α n) (m : Nat) (hi : i < min n m) :
(a.take m)[i] = a[i] := by
cases a
simp

/-! ### drop -/

@[simp] theorem getElem_drop (a : Vector α n) (m : Nat) (hi : i < n - m) :
(a.drop m)[i] = a[m + i] := by
cases a
simp

/-! ### reverse -/

@[simp] theorem getElem_reverse (a : Vector α n) (i : Nat) (hi : i < n) :
(a.reverse)[i] = a[n - 1 - i] := by
rcases a with ⟨a, rfl⟩
simp

/-! ### Decidable quantifiers. -/

theorem forall_zero_iff {P : Vector α 0Prop} :
Expand Down

0 comments on commit 6e60d13

Please sign in to comment.