Skip to content

Commit

Permalink
fix: propagate Simp.Config when reducing terms and checking definit…
Browse files Browse the repository at this point in the history
…ional equality in `simp` (#6123)

This PR ensures that the configuration in `Simp.Config` is used when
reducing terms and checking definitional equality in `simp`.

closes #5455

---------

Co-authored-by: Kim Morrison <[email protected]>
  • Loading branch information
leodemoura and kim-em authored Dec 14, 2024
1 parent aa00725 commit 19eac5f
Show file tree
Hide file tree
Showing 30 changed files with 1,350 additions and 150 deletions.
2 changes: 1 addition & 1 deletion src/Init/Data/Array/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -821,7 +821,7 @@ decreasing_by simp_wf; exact Nat.sub_succ_lt_self _ _ h
induction a, i, h using Array.eraseIdx.induct with
| @case1 a i h h' a' ih =>
unfold eraseIdx
simp [h', a', ih]
simp +zetaDelta [h', a', ih]
| case2 a i h h' =>
unfold eraseIdx
simp [h']
Expand Down
2 changes: 1 addition & 1 deletion src/Init/Data/List/Impl.lean
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def enumFromTR (n : Nat) (l : List α) : List (Nat × α) :=
rw [← show _ + as.length = n + (a::as).length from Nat.succ_add .., foldr, go as]
simp [enumFrom, f]
rw [← Array.foldr_toList]
simp [go]
simp +zetaDelta [go]

/-! ## Other list operations -/

Expand Down
2 changes: 1 addition & 1 deletion src/Lean/Data/RArray.lean
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ where
go lb ub h1 h2 : (ofFn.go f lb ub h1 h2).size = ub - lb := by
induction lb, ub, h1, h2 using RArray.ofFn.go.induct (n := n)
case case1 => simp [ofFn.go, size]; omega
case case2 ih1 ih2 hiu => rw [ofFn.go]; simp [size, *]; omega
case case2 ih1 ih2 hiu => rw [ofFn.go]; simp +zetaDelta [size, *]; omega

section Meta
open Lean
Expand Down
2 changes: 1 addition & 1 deletion src/Lean/Elab/PatternVar.lean
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ private def processVar (idStx : Syntax) : M Syntax := do
private def samePatternsVariables (startingAt : Nat) (s₁ s₂ : State) : Bool := Id.run do
if h₁ : s₁.vars.size = s₂.vars.size then
for h₂ : i in [startingAt:s₁.vars.size] do
if s₁.vars[i] != s₂.vars[i]'(by obtain ⟨_, y⟩ := h₂; simp_all) then return false
if s₁.vars[i] != s₂.vars[i]'(by obtain ⟨_, y⟩ := h₂; simp_all +zetaDelta) then return false
true
else
false
Expand Down
139 changes: 78 additions & 61 deletions src/Lean/Elab/Tactic/Simp.lean
Original file line number Diff line number Diff line change
Expand Up @@ -173,68 +173,71 @@ def elabSimpArgs (stx : Syntax) (ctx : Simp.Context) (simprocs : Simp.SimprocsAr
syntax simpErase := "-" ident
-/
let go := withMainContext do
let mut thmsArray := ctx.simpTheorems
let mut thms := thmsArray[0]!
let mut simprocs := simprocs
let mut starArg := false
for arg in stx[1].getSepArgs do
try -- like withLogging, but compatible with do-notation
if arg.getKind == ``Lean.Parser.Tactic.simpErase then
let fvar? ← if eraseLocal || starArg then Term.isLocalIdent? arg[1] else pure none
if let some fvar := fvar? then
-- We use `eraseCore` because the simp theorem for the hypothesis was not added yet
thms := thms.eraseCore (.fvar fvar.fvarId!)
else
let id := arg[1]
if let .ok declName ← observing (realizeGlobalConstNoOverloadWithInfo id) then
if (← Simp.isSimproc declName) then
simprocs := simprocs.erase declName
else if ctx.config.autoUnfold then
thms := thms.eraseCore (.decl declName)
else
thms ← withRef id <| thms.erase (.decl declName)
let zetaDeltaSet ← toZetaDeltaSet stx ctx
withTrackingZetaDeltaSet zetaDeltaSet do
let mut thmsArray := ctx.simpTheorems
let mut thms := thmsArray[0]!
let mut simprocs := simprocs
let mut starArg := false
for arg in stx[1].getSepArgs do
try -- like withLogging, but compatible with do-notation
if arg.getKind == ``Lean.Parser.Tactic.simpErase then
let fvar? ← if eraseLocal || starArg then Term.isLocalIdent? arg[1] else pure none
if let some fvar := fvar? then
-- We use `eraseCore` because the simp theorem for the hypothesis was not added yet
thms := thms.eraseCore (.fvar fvar.fvarId!)
else
-- If `id` could not be resolved, we should check whether it is a builtin simproc.
-- before returning error.
let name := id.getId.eraseMacroScopes
if (← Simp.isBuiltinSimproc name) then
simprocs := simprocs.erase name
let id := arg[1]
if let .ok declName ← observing (realizeGlobalConstNoOverloadWithInfo id) then
if (← Simp.isSimproc declName) then
simprocs := simprocs.erase declName
else if ctx.config.autoUnfold then
thms := thms.eraseCore (.decl declName)
else
thms ← withRef id <| thms.erase (.decl declName)
else
withRef id <| throwUnknownConstant name
else if arg.getKind == ``Lean.Parser.Tactic.simpLemma then
let post :=
if arg[0].isNone then
true
else
arg[0][0].getKind == ``Parser.Tactic.simpPost
let inv := !arg[1].isNone
let term := arg[2]
match (← resolveSimpIdTheorem? term) with
| .expr e =>
let name ← mkFreshId
thms ← addDeclToUnfoldOrTheorem ctx.indexConfig thms (.stx name arg) e post inv kind
| .simproc declName =>
simprocs ← simprocs.add declName post
| .ext (some ext₁) (some ext₂) _ =>
thmsArray := thmsArray.push (← ext₁.getTheorems)
simprocs := simprocs.push (← ext₂.getSimprocs)
| .ext (some ext₁) none _ =>
thmsArray := thmsArray.push (← ext₁.getTheorems)
| .ext none (some ext₂) _ =>
simprocs := simprocs.push (← ext₂.getSimprocs)
| .none =>
let name ← mkFreshId
thms ← addSimpTheorem ctx.indexConfig thms (.stx name arg) term post inv
else if arg.getKind == ``Lean.Parser.Tactic.simpStar then
starArg := true
else
throwUnsupportedSyntax
catch ex =>
if (← read).recover then
logException ex
else
throw ex
return { ctx := ctx.setSimpTheorems (thmsArray.set! 0 thms), simprocs, starArg }
-- If `id` could not be resolved, we should check whether it is a builtin simproc.
-- before returning error.
let name := id.getId.eraseMacroScopes
if (← Simp.isBuiltinSimproc name) then
simprocs := simprocs.erase name
else
withRef id <| throwUnknownConstant name
else if arg.getKind == ``Lean.Parser.Tactic.simpLemma then
let post :=
if arg[0].isNone then
true
else
arg[0][0].getKind == ``Parser.Tactic.simpPost
let inv := !arg[1].isNone
let term := arg[2]
match (← resolveSimpIdTheorem? term) with
| .expr e =>
let name ← mkFreshId
thms ← addDeclToUnfoldOrTheorem ctx.indexConfig thms (.stx name arg) e post inv kind
| .simproc declName =>
simprocs ← simprocs.add declName post
| .ext (some ext₁) (some ext₂) _ =>
thmsArray := thmsArray.push (← ext₁.getTheorems)
simprocs := simprocs.push (← ext₂.getSimprocs)
| .ext (some ext₁) none _ =>
thmsArray := thmsArray.push (← ext₁.getTheorems)
| .ext none (some ext₂) _ =>
simprocs := simprocs.push (← ext₂.getSimprocs)
| .none =>
let name ← mkFreshId
thms ← addSimpTheorem ctx.indexConfig thms (.stx name arg) term post inv
else if arg.getKind == ``Lean.Parser.Tactic.simpStar then
starArg := true
else
throwUnsupportedSyntax
catch ex =>
if (← read).recover then
logException ex
else
throw ex
let ctx := ctx.setZetaDeltaSet zetaDeltaSet (← getZetaDeltaFVarIds)
return { ctx := ctx.setSimpTheorems (thmsArray.set! 0 thms), simprocs, starArg }
-- If recovery is disabled, then we want simp argument elaboration failures to be exceptions.
-- This affects `addSimpTheorem`.
if (← read).recover then
Expand Down Expand Up @@ -277,6 +280,20 @@ where
else
return .none

/-- If `zetaDelta := false`, create a `FVarId` set with all local let declarations in the `simp` argument list. -/
toZetaDeltaSet (stx : Syntax) (ctx : Simp.Context) : TacticM FVarIdSet := do
if ctx.config.zetaDelta then return {}
Term.withoutCheckDeprecated do -- We do not want to report deprecated constants in the first pass
let mut s : FVarIdSet := {}
for arg in stx[1].getSepArgs do
if arg.getKind == ``Lean.Parser.Tactic.simpLemma then
if arg[0].isNone && arg[1].isNone then
let term := arg[2]
let .expr (.fvar fvarId) ← resolveSimpIdTheorem? term | pure ()
if (← fvarId.getDecl).isLet then
s := s.insert fvarId
return s

@[inline] def simpOnlyBuiltins : List Name := [``eq_self, ``iff_self]

structure MkSimpContextResult where
Expand Down Expand Up @@ -323,7 +340,7 @@ def mkSimpContext (stx : Syntax) (eraseLocal : Bool) (kind := SimpKind.simp)
let simprocs := r.simprocs
let mut simpTheorems := ctx.simpTheorems
/-
When using `zeta := false`, we do not expand let-declarations when using `[*]`.
When using `zetaDelta := false`, we do not expand let-declarations when using `[*]`.
Users must explicitly include it in the list.
-/
let hs ← getPropHyps
Expand Down
20 changes: 15 additions & 5 deletions src/Lean/Elab/Term.lean
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,10 @@ structure Context where
`refine' (fun x => _)
-/
holesAsSyntheticOpaque : Bool := false
/--
If `checkDeprecated := true`, then `Linter.checkDeprecated` when creating constants.
-/
checkDeprecated : Bool := true

abbrev TermElabM := ReaderT Context $ StateRefT State MetaM
abbrev TermElab := Syntax → Option Expr → TermElabM Expr
Expand Down Expand Up @@ -2026,16 +2030,19 @@ def isLetRecAuxMVar (mvarId : MVarId) : TermElabM Bool := do
trace[Elab.letrec] "mvarId root: {mkMVar mvarId}"
return (← get).letRecsToLift.any (·.mvarId == mvarId)

private def checkDeprecatedCore (constName : Name) : TermElabM Unit := do
if (← read).checkDeprecated then
Linter.checkDeprecated constName

/--
Create an `Expr.const` using the given name and explicit levels.
Remark: fresh universe metavariables are created if the constant has more universe
parameters than `explicitLevels`.
If `checkDeprecated := true`, then `Linter.checkDeprecated` is invoked.
-/
def mkConst (constName : Name) (explicitLevels : List Level := []) (checkDeprecated := true) : TermElabM Expr := do
if checkDeprecated then
Linter.checkDeprecated constName
def mkConst (constName : Name) (explicitLevels : List Level := []) : TermElabM Expr := do
checkDeprecatedCore constName
let cinfo ← getConstInfo constName
if explicitLevels.length > cinfo.levelParams.length then
throwError "too many explicit universe levels for '{constName}'"
Expand All @@ -2046,7 +2053,10 @@ def mkConst (constName : Name) (explicitLevels : List Level := []) (checkDepreca

def checkDeprecated (ref : Syntax) (e : Expr) : TermElabM Unit := do
if let .const declName _ := e.getAppFn then
withRef ref do Linter.checkDeprecated declName
withRef ref do checkDeprecatedCore declName

@[inline] def withoutCheckDeprecated [MonadWithReaderOf Context m] : m α → m α :=
withTheReader Context (fun ctx => { ctx with checkDeprecated := false })

private def mkConsts (candidates : List (Name × List String)) (explicitLevels : List Level) : TermElabM (List (Expr × List String)) := do
candidates.foldlM (init := []) fun result (declName, projs) => do
Expand All @@ -2058,7 +2068,7 @@ private def mkConsts (candidates : List (Name × List String)) (explicitLevels :
At `elabAppFnId`, we perform the check when converting the list returned by `resolveName'` into a list of
`TermElabResult`s.
-/
let const ← mkConst declName explicitLevels (checkDeprecated := false)
let const ← withoutCheckDeprecated <| mkConst declName explicitLevels
return (const, projs) :: result

def resolveName (stx : Syntax) (n : Name) (preresolved : List Syntax.Preresolved) (explicitLevels : List Level) (expectedType? : Option Expr := none) : TermElabM (List (Expr × List String)) := do
Expand Down
70 changes: 54 additions & 16 deletions src/Lean/Meta/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -141,18 +141,6 @@ structure Config where
Controls which definitions and theorems can be unfolded by `isDefEq` and `whnf`.
-/
transparency : TransparencyMode := TransparencyMode.default
/--
When `trackZetaDelta = true`, we track all free variables that have been zetaDelta-expanded.
That is, suppose the local context contains
the declaration `x : t := v`, and we reduce `x` to `v`, then we insert `x` into `State.zetaDeltaFVarIds`.
We use `trackZetaDelta` to discover which let-declarations `let x := v; e` can be represented as `(fun x => e) v`.
When we find these declarations we set their `nonDep` flag with `true`.
To find these let-declarations in a given term `s`, we
1- Reset `State.zetaDeltaFVarIds`
2- Set `trackZetaDelta := true`
3- Type-check `s`.
-/
trackZetaDelta : Bool := false
/-- Eta for structures configuration mode. -/
etaStruct : EtaStructMode := .all
/--
Expand Down Expand Up @@ -187,7 +175,7 @@ structure Config where
Zeta-delta reduction: given a local context containing entry `x : t := e`, free variable `x` reduces to `e`.
-/
zetaDelta : Bool := true
deriving Inhabited
deriving Inhabited, Repr

/-- Convert `isDefEq` and `WHNF` relevant parts into a key for caching results -/
private def Config.toKey (c : Config) : UInt64 :=
Expand Down Expand Up @@ -419,7 +407,7 @@ structure Diagnostics where
structure State where
mctx : MetavarContext := {}
cache : Cache := {}
/-- When `trackZetaDelta == true`, then any let-decl free variable that is zetaDelta-expanded by `MetaM` is stored in `zetaDeltaFVarIds`. -/
/-- When `Context.trackZetaDelta == true`, then any let-decl free variable that is zetaDelta-expanded by `MetaM` is stored in `zetaDeltaFVarIds`. -/
zetaDeltaFVarIds : FVarIdSet := {}
/-- Array of postponed universe level constraints -/
postponed : PersistentArray PostponedEntry := {}
Expand All @@ -445,6 +433,28 @@ register_builtin_option maxSynthPendingDepth : Nat := {
structure Context where
private config : Config := {}
private configKey : UInt64 := config.toKey
/--
When `trackZetaDelta = true`, we track all free variables that have been zetaDelta-expanded.
That is, suppose the local context contains
the declaration `x : t := v`, and we reduce `x` to `v`, then we insert `x` into `State.zetaDeltaFVarIds`.
We use `trackZetaDelta` to discover which let-declarations `let x := v; e` can be represented as `(fun x => e) v`.
When we find these declarations we set their `nonDep` flag with `true`.
To find these let-declarations in a given term `s`, we
1- Reset `State.zetaDeltaFVarIds`
2- Set `trackZetaDelta := true`
3- Type-check `s`.
Note that, we do not include this field in the `Config` structure because this field is not
taken into account while caching results. See also field `zetaDeltaSet`.
-/
trackZetaDelta : Bool := false
/--
If `config.zetaDelta := false`, we may select specific local declarations to be unfolded using
the field `zetaDeltaSet`. Note that, we do not include this field in the `Config` structure
because this field is not taken into account while caching results.
Moreover, we reset all caches whenever setting it.
-/
zetaDeltaSet : FVarIdSet := {}
/-- Local context -/
lctx : LocalContext := {}
/-- Local instances in `lctx`. -/
Expand Down Expand Up @@ -1089,8 +1099,36 @@ def elimMVarDeps (xs : Array Expr) (e : Expr) (preserveOrder : Bool := false) :
/--
Executes `x` tracking zetaDelta reductions `Config.trackZetaDelta := true`
-/
@[inline] def withTrackingZetaDelta (x : n α) : n α :=
withConfig (fun cfg => { cfg with trackZetaDelta := true }) x
@[inline] def withTrackingZetaDelta : n α → n α :=
mapMetaM <| withReader (fun ctx => { ctx with trackZetaDelta := true })

def withZetaDeltaSetImp (s : FVarIdSet) (x : MetaM α) : MetaM α := do
if s.isEmpty then
x
else
let cacheSaved := (← get).cache
modify fun s => { s with cache := {} }
try
withReader (fun ctx => { ctx with zetaDeltaSet := s }) x
finally
modify fun s => { s with cache := cacheSaved }

/--
`withZetaDeltaSet s x` executes `x` with `zetaDeltaSet := s`.
The cache is reset while executing `x` if `s` is not empty.
-/
def withZetaDeltaSet (s : FVarIdSet) : n α → n α :=
mapMetaM <| withZetaDeltaSetImp s

/--
Similar to `withZetaDeltaSet`, but also enables `withTrackingZetaDelta` if
`s` is not empty.
-/
def withTrackingZetaDeltaSet (s : FVarIdSet) (x : n α) : n α := do
if s.isEmpty then
x
else
withZetaDeltaSet s <| withTrackingZetaDelta x

@[inline] def withoutProofIrrelevance (x : n α) : n α :=
withConfig (fun cfg => { cfg with proofIrrelevance := false }) x
Expand Down
Loading

0 comments on commit 19eac5f

Please sign in to comment.