@@ -8,6 +8,7 @@ module -- shake: keep-downstream
88
99public import Cslib.Init
1010public import Mathlib.Analysis.Normed.Field.Lemmas
11+ import Qq
1112
1213@[expose] public section
1314
@@ -58,13 +59,13 @@ declare_config_elab elabFreeUnionConfig FreeUnionConfig
5859 def f (_ : String) : Finset ℕ := {1, 2, 3}
5960 def g (_ : String) : Finset ℕ := {4, 5, 6}
6061
61- -- info: ∅ ∪ {x} ∪ id xs : Finset ℕ
62+ -- info: ∅ ∪ {x} ∪ xs : Finset ℕ
6263 #check free_union ℕ
6364
64- -- info: ∅ ∪ {x} ∪ id xs ∪ f var ∪ g var : Finset ℕ
65+ -- info: ∅ ∪ {x} ∪ xs ∪ f var ∪ g var : Finset ℕ
6566 #check free_union [f, g] ℕ
6667
67- info: ∅ ∪ id xs : Finset ℕ
68+ info: ∅ ∪ xs : Finset ℕ
6869 #check free_union (singleton := false) ℕ
6970
7071 -- info: ∅ ∪ {x} : Finset ℕ
@@ -76,6 +77,8 @@ declare_config_elab elabFreeUnionConfig FreeUnionConfig
7677-/
7778syntax (name := freeUnion) "free_union" optConfig (" [" (term,*) "]" )? term : term
7879
80+ open Qq
81+
7982set_option linter.style.emptyLine false in
8083/-- Elaborator for `free_union`. -/
8184@ [term_elab freeUnion]
@@ -86,44 +89,33 @@ def HasFresh.freeUnion : TermElab := fun stx _ => do
8689
8790 -- the type of our variables
8891 let var ← elabType var
92+ let dl ← getDecLevel var
93+ have α : Q(Type dl) := var
8994
9095 -- maps to variables
9196 let maps := maps.map (·.getElems) |>.getD #[]
92- let mut maps ← maps.mapM (flip elabTerm none)
93-
94- -- construct ∅
95- let dl ← getDecLevel var
96- let FinsetType := mkApp (mkConst ``Finset [dl]) var
97- let EmptyCollectionInst ← synthInstance (mkApp (mkConst ``EmptyCollection [dl]) FinsetType)
98- let empty :=
99- mkAppN (mkConst ``EmptyCollection.emptyCollection [dl]) #[FinsetType, EmptyCollectionInst]
97+ let mut maps ← maps.mapM (elabTerm · none)
10098
10199 -- singleton variables
102100 if cfg.singleton then
103- let SingletonInst ← synthInstance <| mkAppN (mkConst ``Singleton [dl, dl]) #[var, FinsetType]
104- let singleton_map :=
105- mkAppN (mkConst ``Singleton.singleton [dl, dl]) #[var, FinsetType, SingletonInst]
106- maps := maps.push singleton_map
101+ maps := maps.push q(Singleton.singleton : $α → Finset $α)
107102
108103 -- any finite sets
109104 if cfg.finset then
110- let id_map := mkApp (mkConst ``id [← getLevel var]) FinsetType
111- maps := maps.push id_map
105+ maps := maps.push q((·) : Finset $α → Finset $α)
112106
113- let mut finsets := #[]
107+ let mut finsets : Array Q(Finset $α) : = #[]
114108
115- for ldecl in ( ← getLCtx) do
109+ for ldecl in ← getLCtx do
116110 if !ldecl.isImplementationDetail then
117111 let local_type ← ldecl.toExpr |> inferType >=> whnf
118112 for map in maps do
119113 if let Expr.forallE _ dom _ _ := ← inferType map then
120- if (← isDefEq local_type dom) then
121- finsets := finsets.push (mkApp map ldecl.toExpr)
114+ if ← isDefEq local_type dom then
115+ finsets := finsets.push (map.betaRev #[ ldecl.toExpr] )
122116
123- -- construct a union fold
124- let UnionInst ← synthInstance (mkApp (mkConst ``Union [dl]) FinsetType)
125- let UnionFinset := mkAppN (mkConst ``Union.union [dl]) #[FinsetType, UnionInst]
126- let union := finsets.foldl (mkApp2 UnionFinset) empty
117+ let _dec : Q(DecidableEq $α) ← synthInstanceQ q(DecidableEq $α)
118+ let union := finsets.foldl (fun a b : Q(Finset $α) => q($a ∪ $b)) q(∅)
127119
128120 return union
129121 | _ => throwUnsupportedSyntax
0 commit comments