Skip to content

Commit

Permalink
feat: List.mapFinIdx, lemmas, relate to Array version (#5941)
Browse files Browse the repository at this point in the history
  • Loading branch information
kim-em authored Nov 4, 2024
1 parent fc17468 commit c779f3a
Show file tree
Hide file tree
Showing 3 changed files with 227 additions and 44 deletions.
20 changes: 20 additions & 0 deletions src/Init/Data/Array/MapIdx.lean
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ theorem mapFinIdx_spec (as : Array α) (f : Fin as.size → α → β)
simp only [getElem?_def, size_mapFinIdx, getElem_mapFinIdx]
split <;> simp_all

@[simp] theorem toList_mapFinIdx (a : Array α) (f : Fin a.size → α → β) :
(a.mapFinIdx f).toList = a.toList.mapFinIdx (fun i a => f ⟨i, by simp⟩ a) := by
apply List.ext_getElem <;> simp

/-! ### mapIdx -/

theorem mapIdx_induction (as : Array α) (f : Nat → α → β)
Expand Down Expand Up @@ -89,4 +93,20 @@ theorem mapIdx_spec (as : Array α) (f : Nat → α → β)
a[i]?.map (f i) := by
simp [getElem?_def, size_mapIdx, getElem_mapIdx]

@[simp] theorem toList_mapIdx (a : Array α) (f : Nat → α → β) :
(a.mapIdx f).toList = a.toList.mapIdx (fun i a => f i a) := by
apply List.ext_getElem <;> simp

end Array

namespace List

@[simp] theorem mapFinIdx_toArray (l : List α) (f : Fin l.length → α → β) :
l.toArray.mapFinIdx f = (l.mapFinIdx f).toArray := by
ext <;> simp

@[simp] theorem mapIdx_toArray (l : List α) (f : Nat → α → β) :
l.toArray.mapIdx f = (l.mapIdx f).toArray := by
ext <;> simp

end List
246 changes: 203 additions & 43 deletions src/Init/Data/List/MapIdx.lean
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,31 @@ Authors: Kim Morrison, Mario Carneiro
prelude
import Init.Data.Array.Lemmas
import Init.Data.List.Nat.Range
import Init.Data.List.OfFn
import Init.Data.Fin.Lemmas
import Init.Data.Option.Attach

namespace List

/-! ## Operations using indexes -/

/-! ### mapIdx -/


/--
Given a list `as = [a₀, a₁, ...]` function `f : Fin as.length → α → β`, returns the list
`[f 0 a₀, f 1 a₁, ...]`.
-/
@[inline] def mapFinIdx (as : List α) (f : Fin as.length → α → β) : List β := go as #[] (by simp) where
/-- Auxiliary for `mapFinIdx`:
`mapFinIdx.go [a₀, a₁, ...] acc = acc.toList ++ [f 0 a₀, f 1 a₁, ...]` -/
@[specialize] go : (bs : List α) → (acc : Array β) → bs.length + acc.size = as.length → List β
| [], acc, h => acc.toList
| a :: as, acc, h =>
go as (acc.push (f ⟨acc.size, by simp at h; omega⟩ a)) (by simp at h ⊢; omega)

/--
Given a function `f : Nat → α → β` and `as : list α`, `as = [a₀, a₁, ...]`, returns the list
Given a function `f : Nat → α → β` and `as : List α`, `as = [a₀, a₁, ...]`, returns the list
`[f 0 a₀, f 1 a₁, ...]`.
-/
@[inline] def mapIdx (f : Nat → α → β) (as : List α) : List β := go as #[] where
Expand All @@ -25,33 +41,176 @@ Given a function `f : Nat → α → β` and `as : list α`, `as = [a₀, a₁,
| [], acc => acc.toList
| a :: as, acc => go as (acc.push (f acc.size a))

/-! ### mapFinIdx -/

@[simp]
theorem mapIdx_nil {f : Nat → α → β} : mapIdx f [] = [] :=
theorem mapFinIdx_nil {f : Fin 0 → α → β} : mapFinIdx [] f = [] :=
rfl

theorem mapIdx_go_append {l₁ l₂ : List α} {arr : Array β} :
mapIdx.go f (l₁ ++ l₂) arr = mapIdx.go f l₂ (List.toArray (mapIdx.go f l₁ arr)) := by
generalize h : (l₁ ++ l₂).length = len
induction len generalizing l₁ arr with
| zero =>
have l₁_nil : l₁ = [] := by
cases l₁
· rfl
· contradiction
have l₂_nil : l₂ = [] := by
cases l₂
· rfl
· rw [List.length_append] at h; contradiction
rw [l₁_nil, l₂_nil]; simp only [mapIdx.go, List.toArray_toList]
| succ len ih =>
cases l₁ with
| nil =>
simp only [mapIdx.go, nil_append, List.toArray_toList]
| cons head tail =>
simp only [mapIdx.go, List.append_eq]
rw [ih]
· simp only [cons_append, length_cons, length_append, Nat.succ.injEq] at h
simp only [length_append, h]
@[simp] theorem length_mapFinIdx_go :
(mapFinIdx.go as f bs acc h).length = as.length := by
induction bs generalizing acc with
| nil => simpa using h
| cons _ _ ih => simp [mapFinIdx.go, ih]

@[simp] theorem length_mapFinIdx {as : List α} {f : Fin as.length → α → β} :
(as.mapFinIdx f).length = as.length := by
simp [mapFinIdx, length_mapFinIdx_go]

theorem getElem_mapFinIdx_go {as : List α} {f : Fin as.length → α → β} {i : Nat} {h} {w} :
(mapFinIdx.go as f bs acc h)[i] =
if w' : i < acc.size then acc[i] else f ⟨i, by simp at w; omega⟩ (bs[i - acc.size]'(by simp at w; omega)) := by
induction bs generalizing acc with
| nil =>
simp only [length_mapFinIdx_go, length_nil, Nat.zero_add] at w h
simp only [mapFinIdx.go, Array.getElem_toList]
rw [dif_pos]
| cons _ _ ih =>
simp [mapFinIdx.go]
rw [ih]
simp
split <;> rename_i h₁ <;> split <;> rename_i h₂
· rw [Array.getElem_push_lt]
· have h₃ : i = acc.size := by omega
subst h₃
simp
· omega
· have h₃ : i - acc.size = (i - (acc.size + 1)) + 1 := by omega
simp [h₃]

@[simp] theorem getElem_mapFinIdx {as : List α} {f : Fin as.length → α → β} {i : Nat} {h} :
(as.mapFinIdx f)[i] = f ⟨i, by simp at h; omega⟩ (as[i]'(by simp at h; omega)) := by
simp [mapFinIdx, getElem_mapFinIdx_go]

theorem mapFinIdx_eq_ofFn {as : List α} {f : Fin as.length → α → β} :
as.mapFinIdx f = List.ofFn fun i : Fin as.length => f i as[i] := by
apply ext_getElem <;> simp

@[simp] theorem getElem?_mapFinIdx {l : List α} {f : Fin l.length → α → β} {i : Nat} :
(l.mapFinIdx f)[i]? = l[i]?.pbind fun x m => f ⟨i, by simp [getElem?_eq_some] at m; exact m.1⟩ x := by
simp only [getElem?_eq, length_mapFinIdx, getElem_mapFinIdx]
split <;> simp

@[simp]
theorem mapFinIdx_cons {l : List α} {a : α} {f : Fin (l.length + 1) → α → β} :
mapFinIdx (a :: l) f = f 0 a :: mapFinIdx l (fun i => f i.succ) := by
apply ext_getElem
· simp
· rintro (_|i) h₁ h₂ <;> simp

theorem mapFinIdx_append {K L : List α} {f : Fin (K ++ L).length → α → β} :
(K ++ L).mapFinIdx f =
K.mapFinIdx (fun i => f (i.castLE (by simp))) ++ L.mapFinIdx (fun i => f ((i.natAdd K.length).cast (by simp))) := by
apply ext_getElem
· simp
· intro i h₁ h₂
rw [getElem_append]
simp only [getElem_mapFinIdx, length_mapFinIdx]
split <;> rename_i h
· rw [getElem_append_left]
congr
· simp only [Nat.not_lt] at h
rw [getElem_append_right h]
congr
simp
omega

@[simp] theorem mapFinIdx_concat {l : List α} {e : α} {f : Fin (l ++ [e]).length → α → β}:
(l ++ [e]).mapFinIdx f = l.mapFinIdx (fun i => f (i.castLE (by simp))) ++ [f ⟨l.length, by simp⟩ e] := by
simp [mapFinIdx_append]
congr

theorem mapFinIdx_singleton {a : α} {f : Fin 1 → α → β} :
[a].mapFinIdx f = [f ⟨0, by simp⟩ a] := by
simp

theorem mapFinIdx_eq_enum_map {l : List α} {f : Fin l.length → α → β} :
l.mapFinIdx f = l.enum.attach.map
fun ⟨⟨i, x⟩, m⟩ => f ⟨i, by rw [mk_mem_enum_iff_getElem?, getElem?_eq_some] at m; exact m.1⟩ x := by
apply ext_getElem <;> simp

@[simp]
theorem mapFinIdx_eq_nil_iff {l : List α} {f : Fin l.length → α → β} :
l.mapFinIdx f = [] ↔ l = [] := by
rw [mapFinIdx_eq_enum_map, map_eq_nil_iff, attach_eq_nil_iff, enum_eq_nil_iff]

theorem mapFinIdx_ne_nil_iff {l : List α} {f : Fin l.length → α → β} :
l.mapFinIdx f ≠ [] ↔ l ≠ [] := by
simp

theorem exists_of_mem_mapFinIdx {b : β} {l : List α} {f : Fin l.length → α → β}
(h : b ∈ l.mapFinIdx f) : ∃ (i : Fin l.length), f i l[i] = b := by
rw [mapFinIdx_eq_enum_map] at h
replace h := exists_of_mem_map h
simp only [mem_attach, true_and, Subtype.exists, Prod.exists, mk_mem_enum_iff_getElem?] at h
obtain ⟨i, b, h, rfl⟩ := h
rw [getElem?_eq_some_iff] at h
obtain ⟨h', rfl⟩ := h
exact ⟨⟨i, h'⟩, rfl⟩

@[simp] theorem mem_mapFinIdx {b : β} {l : List α} {f : Fin l.length → α → β} :
b ∈ l.mapFinIdx f ↔ ∃ (i : Fin l.length), f i l[i] = b := by
constructor
· intro h
exact exists_of_mem_mapFinIdx h
· rintro ⟨i, h, rfl⟩
rw [mem_iff_getElem]
exact ⟨i, by simp⟩

theorem mapFinIdx_eq_cons_iff {l : List α} {b : β} {f : Fin l.length → α → β} :
l.mapFinIdx f = b :: l₂ ↔
∃ (a : α) (l₁ : List α) (h : l = a :: l₁),
f ⟨0, by simp [h]⟩ a = b ∧ l₁.mapFinIdx (fun i => f (i.succ.cast (by simp [h]))) = l₂ := by
cases l with
| nil => simp
| cons x l' =>
simp only [mapFinIdx_cons, cons.injEq, length_cons, Fin.zero_eta, Fin.cast_succ_eq,
exists_and_left]
constructor
· rintro ⟨rfl, rfl⟩
refine ⟨x, rfl, l', by simp⟩
· rintro ⟨a, ⟨rfl, h⟩, ⟨_, ⟨rfl, rfl⟩, h⟩⟩
exact ⟨rfl, h⟩

theorem mapFinIdx_eq_cons_iff' {l : List α} {b : β} {f : Fin l.length → α → β} :
l.mapFinIdx f = b :: l₂ ↔
l.head?.pbind (fun x m => (f ⟨0, by cases l <;> simp_all⟩ x)) = some b ∧
l.tail?.attach.map (fun ⟨t, m⟩ => t.mapFinIdx fun i => f (i.succ.cast (by cases l <;> simp_all))) = some l₂ := by
cases l <;> simp

theorem mapFinIdx_eq_iff {l : List α} {f : Fin l.length → α → β} :
l.mapFinIdx f = l' ↔ ∃ h : l'.length = l.length, ∀ (i : Nat) (h : i < l.length), l'[i] = f ⟨i, h⟩ l[i] := by
constructor
· rintro rfl
simp
· rintro ⟨h, w⟩
apply ext_getElem <;> simp_all

theorem mapFinIdx_eq_mapFinIdx_iff {l : List α} {f g : Fin l.length → α → β} :
l.mapFinIdx f = l.mapFinIdx g ↔ ∀ (i : Fin l.length), f i l[i] = g i l[i] := by
rw [eq_comm, mapFinIdx_eq_iff]
simp [Fin.forall_iff]

@[simp] theorem mapFinIdx_mapFinIdx {l : List α} {f : Fin l.length → α → β} {g : Fin _ → β → γ} :
(l.mapFinIdx f).mapFinIdx g = l.mapFinIdx (fun i => g (i.cast (by simp)) ∘ f i) := by
simp [mapFinIdx_eq_iff]

theorem mapFinIdx_eq_replicate_iff {l : List α} {f : Fin l.length → α → β} {b : β} :
l.mapFinIdx f = replicate l.length b ↔ ∀ (i : Fin l.length), f i l[i] = b := by
simp [eq_replicate_iff, length_mapFinIdx, mem_mapFinIdx, forall_exists_index, true_and]

@[simp] theorem mapFinIdx_reverse {l : List α} {f : Fin l.reverse.length → α → β} :
l.reverse.mapFinIdx f = (l.mapFinIdx (fun i => f ⟨l.length - 1 - i, by simp; omega⟩)).reverse := by
simp [mapFinIdx_eq_iff]
intro i h
congr
omega

/-! ### mapIdx -/

@[simp]
theorem mapIdx_nil {f : Nat → α → β} : mapIdx f [] = [] :=
rfl

theorem mapIdx_go_length {arr : Array β} :
length (mapIdx.go f l arr) = length l + arr.size := by
Expand All @@ -60,16 +219,6 @@ theorem mapIdx_go_length {arr : Array β} :
| cons _ _ ih =>
simp only [mapIdx.go, ih, Array.size_push, Nat.add_succ, length_cons, Nat.add_comm]

@[simp] theorem mapIdx_concat {l : List α} {e : α} :
mapIdx f (l ++ [e]) = mapIdx f l ++ [f l.length e] := by
unfold mapIdx
rw [mapIdx_go_append]
simp only [mapIdx.go, Array.size_toArray, mapIdx_go_length, length_nil, Nat.add_zero,
Array.push_toList]

@[simp] theorem mapIdx_singleton {a : α} : mapIdx f [a] = [f 0 a] := by
simpa using mapIdx_concat (l := [])

theorem length_mapIdx_go : ∀ {l : List α} {arr : Array β},
(mapIdx.go f l arr).length = l.length + arr.size
| [], _ => by simp [mapIdx.go]
Expand Down Expand Up @@ -112,6 +261,15 @@ theorem getElem?_mapIdx_go : ∀ {l : List α} {arr : Array β} {i : Nat},
rw [← getElem?_eq_getElem, getElem?_mapIdx, getElem?_eq_getElem (by simpa using h)]
simp

@[simp] theorem mapFinIdx_eq_mapIdx {l : List α} {f : Fin l.length → α → β} {g : Nat → α → β}
(h : ∀ (i : Fin l.length), f i l[i] = g i l[i]) :
l.mapFinIdx f = l.mapIdx g := by
simp_all [mapFinIdx_eq_iff]

theorem mapIdx_eq_mapFinIdx {l : List α} {f : Nat → α → β} :
l.mapIdx f = l.mapFinIdx (fun i => f i) := by
simp [mapFinIdx_eq_mapIdx]

theorem mapIdx_eq_enum_map {l : List α} :
l.mapIdx f = l.enum.map (Function.uncurry f) := by
ext1 i
Expand All @@ -130,23 +288,25 @@ theorem mapIdx_append {K L : List α} :
| nil => rfl
| cons _ _ ih => simp [ih (f := fun i => f (i + 1)), Nat.add_assoc]

@[simp] theorem mapIdx_concat {l : List α} {e : α} :
mapIdx f (l ++ [e]) = mapIdx f l ++ [f l.length e] := by
simp [mapIdx_append]

theorem mapIdx_singleton {a : α} : mapIdx f [a] = [f 0 a] := by
simp

@[simp]
theorem mapIdx_eq_nil_iff {l : List α} : List.mapIdx f l = [] ↔ l = [] := by
rw [List.mapIdx_eq_enum_map, List.map_eq_nil_iff, List.enum_eq_nil]
rw [List.mapIdx_eq_enum_map, List.map_eq_nil_iff, List.enum_eq_nil_iff]

theorem mapIdx_ne_nil_iff {l : List α} :
List.mapIdx f l ≠ [] ↔ l ≠ [] := by
simp

theorem exists_of_mem_mapIdx {b : β} {l : List α}
(h : b ∈ mapIdx f l) : ∃ (i : Nat) (h : i < l.length), f i l[i] = b := by
rw [mapIdx_eq_enum_map] at h
replace h := exists_of_mem_map h
simp only [Prod.exists, mk_mem_enum_iff_getElem?, Function.uncurry_apply_pair] at h
obtain ⟨i, b, h, rfl⟩ := h
rw [getElem?_eq_some_iff] at h
obtain ⟨h, rfl⟩ := h
exact ⟨i, h, rfl⟩
rw [mapIdx_eq_mapFinIdx] at h
simpa [Fin.exists_iff] using exists_of_mem_mapFinIdx h

@[simp] theorem mem_mapIdx {b : β} {l : List α} :
b ∈ mapIdx f l ↔ ∃ (i : Nat) (h : i < l.length), f i l[i] = b := by
Expand Down
5 changes: 4 additions & 1 deletion src/Init/Data/List/Nat/Range.lean
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,10 @@ theorem enumFrom_eq_append_iff {l : List α} {n : Nat} :
/-! ### enum -/

@[simp]
theorem enum_eq_nil {l : List α} : List.enum l = [] ↔ l = [] := enumFrom_eq_nil
theorem enum_eq_nil_iff {l : List α} : List.enum l = [] ↔ l = [] := enumFrom_eq_nil

@[deprecated enum_eq_nil_iff (since := "2024-11-04")]
theorem enum_eq_nil {l : List α} : List.enum l = [] ↔ l = [] := enum_eq_nil_iff

@[simp] theorem enum_singleton (x : α) : enum [x] = [(0, x)] := rfl

Expand Down

0 comments on commit c779f3a

Please sign in to comment.