Skip to content

Commit 9fbbd6a

Browse files
authored
refactor: use Qq in free_union (#433)
This is faster, because it performs almost all the elaboration at compile-type, rather than when the elaborator is run. I also took the liberty of removing the `id` applications that are presumably not intended.
1 parent ea86c67 commit 9fbbd6a

File tree

2 files changed

+20
-28
lines changed

2 files changed

+20
-28
lines changed

Cslib/Foundations/Data/HasFresh.lean

Lines changed: 17 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ module -- shake: keep-downstream
88

99
public import Cslib.Init
1010
public import Mathlib.Analysis.Normed.Field.Lemmas
11+
import Qq
1112

1213
@[expose] public section
1314

@@ -58,13 +59,13 @@ declare_config_elab elabFreeUnionConfig FreeUnionConfig
5859
def f (_ : String) : Finset ℕ := {1, 2, 3}
5960
def g (_ : String) : Finset ℕ := {4, 5, 6}
6061
61-
-- info: ∅ ∪ {x} ∪ id xs : Finset ℕ
62+
-- info: ∅ ∪ {x} ∪ xs : Finset ℕ
6263
#check free_union ℕ
6364
64-
-- info: ∅ ∪ {x} ∪ id xs ∪ f var ∪ g var : Finset ℕ
65+
-- info: ∅ ∪ {x} ∪ xs ∪ f var ∪ g var : Finset ℕ
6566
#check free_union [f, g] ℕ
6667
67-
info: ∅ ∪ id xs : Finset ℕ
68+
info: ∅ ∪ xs : Finset ℕ
6869
#check free_union (singleton := false) ℕ
6970
7071
-- info: ∅ ∪ {x} : Finset ℕ
@@ -76,6 +77,8 @@ declare_config_elab elabFreeUnionConfig FreeUnionConfig
7677
-/
7778
syntax (name := freeUnion) "free_union" optConfig (" [" (term,*) "]")? term : term
7879

80+
open Qq
81+
7982
set_option linter.style.emptyLine false in
8083
/-- Elaborator for `free_union`. -/
8184
@[term_elab freeUnion]
@@ -86,44 +89,33 @@ def HasFresh.freeUnion : TermElab := fun stx _ => do
8689

8790
-- the type of our variables
8891
let var ← elabType var
92+
let dl ← getDecLevel var
93+
have α : Q(Type dl) := var
8994

9095
-- maps to variables
9196
let maps := maps.map (·.getElems) |>.getD #[]
92-
let mut maps ← maps.mapM (flip elabTerm none)
93-
94-
-- construct ∅
95-
let dl ← getDecLevel var
96-
let FinsetType := mkApp (mkConst ``Finset [dl]) var
97-
let EmptyCollectionInst ← synthInstance (mkApp (mkConst ``EmptyCollection [dl]) FinsetType)
98-
let empty :=
99-
mkAppN (mkConst ``EmptyCollection.emptyCollection [dl]) #[FinsetType, EmptyCollectionInst]
97+
let mut maps ← maps.mapM (elabTerm · none)
10098

10199
-- singleton variables
102100
if cfg.singleton then
103-
let SingletonInst ← synthInstance <| mkAppN (mkConst ``Singleton [dl, dl]) #[var, FinsetType]
104-
let singleton_map :=
105-
mkAppN (mkConst ``Singleton.singleton [dl, dl]) #[var, FinsetType, SingletonInst]
106-
maps := maps.push singleton_map
101+
maps := maps.push q(Singleton.singleton : $α → Finset $α)
107102

108103
-- any finite sets
109104
if cfg.finset then
110-
let id_map := mkApp (mkConst ``id [← getLevel var]) FinsetType
111-
maps := maps.push id_map
105+
maps := maps.push q((·) : Finset $α → Finset $α)
112106

113-
let mut finsets := #[]
107+
let mut finsets : Array Q(Finset $α) := #[]
114108

115-
for ldecl in (← getLCtx) do
109+
for ldecl in ← getLCtx do
116110
if !ldecl.isImplementationDetail then
117111
let local_type ← ldecl.toExpr |> inferType >=> whnf
118112
for map in maps do
119113
if let Expr.forallE _ dom _ _ := ← inferType map then
120-
if (←isDefEq local_type dom) then
121-
finsets := finsets.push (mkApp map ldecl.toExpr)
114+
if isDefEq local_type dom then
115+
finsets := finsets.push (map.betaRev #[ldecl.toExpr])
122116

123-
-- construct a union fold
124-
let UnionInst ← synthInstance (mkApp (mkConst ``Union [dl]) FinsetType)
125-
let UnionFinset := mkAppN (mkConst ``Union.union [dl]) #[FinsetType, UnionInst]
126-
let union := finsets.foldl (mkApp2 UnionFinset) empty
117+
let _dec : Q(DecidableEq $α) ← synthInstanceQ q(DecidableEq $α)
118+
let union := finsets.foldl (fun a b : Q(Finset $α) => q($a ∪ $b)) q(∅)
127119

128120
return union
129121
| _ => throwUnsupportedSyntax

CslibTests/HasFresh.lean

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,15 @@ variable (x : ℕ) (xs : Finset ℕ) (var : String)
3232
def f (_ : String) : Finset ℕ := {1, 2, 3}
3333
def g (_ : String) : Finset ℕ := {4, 5, 6}
3434

35-
/-- info: ∅ ∪ {x} ∪ id xs : Finset ℕ -/
35+
/-- info: ∅ ∪ {x} ∪ xs : Finset ℕ -/
3636
#guard_msgs in
3737
#check free_union ℕ
3838

39-
/-- info: ∅ ∪ {x} ∪ id xs ∪ f var ∪ g var : Finset ℕ -/
39+
/-- info: ∅ ∪ {x} ∪ xs ∪ f var ∪ g var : Finset ℕ -/
4040
#guard_msgs in
4141
#check free_union [f, g] ℕ
4242

43-
/-- info: ∅ ∪ id xs : Finset ℕ -/
43+
/-- info: ∅ ∪ xs : Finset ℕ -/
4444
#guard_msgs in
4545
#check free_union (singleton := false) ℕ
4646

0 commit comments

Comments
 (0)