Skip to content

Commit ac86eee

Browse files
committed
Refactor imports in Example.lean and PWhile.lean for consistency and clarity
1 parent b23a546 commit ac86eee

2 files changed

Lines changed: 58 additions & 4 deletions

File tree

ProgramAnalysis/ProbabilisticPrograms/Example.lean

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import Mathlib.Data.Matrix.Basic
2-
import Mathlib.Data.Matrix.Basis
32
import Mathlib.LinearAlgebra.Matrix.Kronecker
43

54
def matrix1 : Matrix (Fin 2) (Fin 3) Nat := ![![1, 2, 3], ![4, 5, 6]]

ProgramAnalysis/ProbabilisticPrograms/PWhile.lean

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
module
22
public import Lean
3-
public import Mathlib.Data.Matrix.Basic
43
public import Mathlib.Data.Finset.Basic
5-
public import Mathlib.Data.Finset.Order
6-
4+
public import Mathlib.Data.Matrix.Basic
5+
public import Mathlib.LinearAlgebra.Matrix.Kronecker
76
namespace ProgramAnalysis.ProbabilisticPrograms
87

98
public abbrev Prob := { p : Rat // 0 ≤ p ∧ p ≤ 1 }
@@ -386,4 +385,60 @@ def P (s : Finset Nat) (f : Fin s.card → Bool) : Matrix (Fin s.card) (Fin s.ca
386385
def I (s : Finset Nat) : Matrix (Fin s.card) (Fin s.card) (Fin 2) :=
387386
fun i j => if i = j then 1 else 0
388387

388+
open Kronecker in
389+
public def tensorProduct {m n k j : Nat}
390+
(A : Matrix (Fin m) (Fin n) Nat)
391+
(B : Matrix (Fin k) (Fin j) Nat) :
392+
Matrix (Fin (m * k)) (Fin (n * j)) Nat :=
393+
Matrix.reindex finProdFinEquiv finProdFinEquiv (A ⊗ₖ B)
394+
395+
infixl:70 " ⊗ " => tensorProduct
396+
397+
def padLeft (width : Nat) (s : String) : String :=
398+
String.ofList (List.replicate (width - s.length) ' ') ++ s
399+
400+
public def Matrix.prettyPrint [ToString α] {m n : Nat} (M : Matrix (Fin m) (Fin n) α) : String :=
401+
let cells : Array (Array String) :=
402+
Array.ofFn fun i : Fin m => Array.ofFn fun j : Fin n => toString (M i j)
403+
let colWidths : Array Nat :=
404+
Array.ofFn fun j : Fin n =>
405+
cells.foldl (fun w row => max w row[j.val]!.length) 0
406+
let rows : Array String :=
407+
cells.map fun row =>
408+
"| " ++ String.intercalate " " (Array.zipWith (fun s w => padLeft w s) row colWidths).toList ++ " |"
409+
String.intercalate "\n" rows.toList
410+
411+
public instance [ToString α] {m n : Nat} : ToString (Matrix (Fin m) (Fin n) α) :=
412+
⟨Matrix.prettyPrint⟩
413+
414+
public def ArithAtom.vars : ArithAtom → List Var
415+
| .var x => [x]
416+
| .const _ => []
417+
| .op _ a1 a2 => a1.vars ++ a2.vars
418+
419+
public def BoolAtom.vars : BoolAtom → List Var
420+
| .btrue | .bfalse => []
421+
| .not b => b.vars
422+
| .op _ b1 b2 => b1.vars ++ b2.vars
423+
| .rel _ a1 a2 => a1.vars ++ a2.vars
424+
425+
public def Stmt.vars : Stmt → List Var
426+
| .stop _ | .skip _ => []
427+
| .assign x a _ => x :: a.vars
428+
| .assign? x as _ => x :: as.flatMap ArithAtom.vars
429+
| .seq s1 s2 => s1.vars ++ s2.vars
430+
| .choose _ _ s1 _ s2 => s1.vars ++ s2.vars
431+
| .sif b _ s1 s2 => b.vars ++ s1.vars ++ s2.vars
432+
| .swhile b _ s => b.vars ++ s.vars
433+
434+
public structure Prog where
435+
stmt : Stmt
436+
domains : List (Var × Finset Int)
437+
allVarsHaveDomain : ∀ x ∈ stmt.vars, (domains.lookup x).isSome
438+
439+
public def Prog.domain (p : Prog) (x : Var) : Option (Finset Int) :=
440+
p.domains.lookup x
441+
442+
443+
389444
end ProgramAnalysis.ProbabilisticPrograms

0 commit comments

Comments
 (0)