@@ -3,6 +3,14 @@ public import Lean
33
44namespace ProgramAnalysis.ProbabilisticPrograms
55
6+ public abbrev Prob := { p : Rat // 0 ≤ p ∧ p ≤ 1 }
7+
8+ public instance : Ord Rat where
9+ compare r1 r2 :=
10+ if r1 < r2 then .lt
11+ else if r1 = r2 then .eq
12+ else .gt
13+
614public abbrev Var := String
715
816public abbrev Label := Nat
@@ -89,7 +97,7 @@ public inductive Stmt
8997 | assign : Var → ArithAtom → Label → Stmt
9098 | assign? : Var → List ArithAtom → Label → Stmt
9199 | seq : Stmt → Stmt → Stmt
92- | choose : Label → Nat → Stmt → Nat → Stmt → Stmt
100+ | choose : Label → Prob → Stmt → Prob → Stmt → Stmt
93101 | sif : BoolAtom → Label → Stmt → Stmt → Stmt
94102 | swhile : BoolAtom → Label → Stmt → Stmt
95103deriving Repr, Ord, DecidableEq
@@ -126,7 +134,7 @@ public def Stmt.mkAssign? (x : Var) (as : List ArithAtom) : StateM Label Stmt :=
126134public def Stmt.mkSeq (s1 s2 : StateM Label Stmt) : StateM Label Stmt := do
127135 return Stmt.seq (← s1) (← s2)
128136
129- public def Stmt.mkChoose (p1 : Nat ) (s1 : StateM Label Stmt) (p2 : Nat ) (s2 : StateM Label Stmt) : StateM Label Stmt := do
137+ public def Stmt.mkChoose (p1 : Prob ) (s1 : StateM Label Stmt) (p2 : Prob ) (s2 : StateM Label Stmt) : StateM Label Stmt := do
130138 return Stmt.choose (← freshLabel) p1 (← s1) p2 (← s2)
131139
132140public def Stmt.mkIf (b : BoolAtom) (thn els : StateM Label Stmt) : StateM Label Stmt := do
@@ -179,7 +187,7 @@ scoped syntax "skip" : pwhile_stmt
179187scoped syntax ident ":=" pwhile_arith_atom : pwhile_stmt
180188scoped syntax ident "?=" "{" pwhile_arith_atom,+ "}" : pwhile_stmt
181189scoped syntax pwhile_stmt ";" pwhile_stmt : pwhile_stmt
182- scoped syntax "choose" num ":" pwhile_stmt "or" num ":" pwhile_stmt "ro" : pwhile_stmt
190+ scoped syntax "choose" scientific ":" pwhile_stmt "or" scientific ":" pwhile_stmt "ro" : pwhile_stmt
183191scoped syntax "if" pwhile_bool_atom "then" pwhile_stmt "else" pwhile_stmt "fi" : pwhile_stmt
184192scoped syntax "while" pwhile_bool_atom "do" pwhile_stmt "od" : pwhile_stmt
185193scoped syntax "(" pwhile_stmt ")" : pwhile_stmt
@@ -205,7 +213,7 @@ meta def elabOpr : Syntax → MetaM Expr
205213 | `(pwhile_op_r| >=) => return .const ``Op_r.ge []
206214 | _ => throwUnsupportedSyntax
207215
208- meta partial def elabArithAtom : Syntax → MetaM Expr
216+ meta partial def elabArithAtom : Syntax → TermElabM Expr
209217 | `(pwhile_arith_atom| $x:ident) => mkAppM ``ArithAtom.var #[mkStrLit x.getId.toString]
210218 | `(pwhile_arith_atom| $n:num) => mkAppM ``ArithAtom.const #[mkIntLit n.getNat]
211219 | `(pwhile_arith_atom| -$n:num) => mkAppM ``ArithAtom.const #[mkIntLit (n.getNat * -1 )]
@@ -217,7 +225,7 @@ meta partial def elabArithAtom : Syntax → MetaM Expr
217225 | `(pwhile_arith_atom| ($a:pwhile_arith_atom)) => elabArithAtom a
218226 | _ => throwUnsupportedSyntax
219227
220- meta partial def elabBoolAtom : Syntax → MetaM Expr
228+ meta partial def elabBoolAtom : Syntax → TermElabM Expr
221229 | `(pwhile_bool_atom| true ) => return .const ``BoolAtom.btrue []
222230 | `(pwhile_bool_atom| false ) => return .const ``BoolAtom.bfalse []
223231 | `(pwhile_bool_atom| not $b:pwhile_bool_atom) => do
@@ -236,7 +244,7 @@ meta partial def elabBoolAtom : Syntax → MetaM Expr
236244 | `(pwhile_bool_atom| ($b:pwhile_bool_atom)) => elabBoolAtom b
237245 | _ => throwUnsupportedSyntax
238246
239- meta partial def elabStmt : Syntax → MetaM Expr
247+ meta partial def elabStmt : Syntax → TermElabM Expr
240248 | `(pwhile_stmt| stop) => do
241249 mkAppM ``Stmt.mkSkip #[]
242250 | `(pwhile_stmt| skip) => do
@@ -252,10 +260,22 @@ meta partial def elabStmt : Syntax → MetaM Expr
252260 let s1Expr ← elabStmt s1
253261 let s2Expr ← elabStmt s2
254262 mkAppM ``Stmt.mkSeq #[s1Expr, s2Expr]
255- | `(pwhile_stmt| choose $p1:num : $s1:pwhile_stmt or $p2:num : $s2:pwhile_stmt ro) => do
263+ | stx@ `(pwhile_stmt| choose $p1:scientific : $s1:pwhile_stmt or $p2:scientific : $s2:pwhile_stmt ro) => do
256264 let s1 ← elabStmt s1
257265 let s2 ← elabStmt s2
258- mkAppM ``Stmt.mkChoose #[mkNatLit p1.getNat, s1, mkNatLit p2.getNat, s2]
266+ let (m, s, e) := p1.getScientific
267+ let r1 := Rat.ofScientific m s e
268+ let (m, s, e) := p2.getScientific
269+ let r2 := Rat.ofScientific m s e
270+ if r1 < 0 || r1 > 1 then
271+ throwErrorAt p1 s! "Probability must be between 0 and 1, got { r1} "
272+ if r2 < 0 || r2 > 1 then
273+ throwErrorAt p2 s! "Probability must be between 0 and 1, got { r2} "
274+ if r1 + r2 != 1 then
275+ throwErrorAt stx s! "Probabilities must sum to 1, but got { r1} + { r2} = { r1 + r2} "
276+ let p1Expr ← Term.elabTerm (← `(⟨($p1:scientific : Rat), by grind⟩)) (mkConst ``Prob)
277+ let p2Expr ← Term.elabTerm (← `(⟨($p2:scientific : Rat), by grind⟩)) (mkConst ``Prob)
278+ mkAppM ``Stmt.mkChoose #[p1Expr, s1, p2Expr, s2]
259279 | `(pwhile_stmt| if $b:pwhile_bool_atom then $s1:pwhile_stmt else $s2:pwhile_stmt fi) => do
260280 let bExpr ← elabBoolAtom b
261281 let s1Expr ← elabStmt s1
@@ -268,7 +288,7 @@ meta partial def elabStmt : Syntax → MetaM Expr
268288 | `(pwhile_stmt| ($s:pwhile_stmt)) => elabStmt s
269289 | _ => throwUnsupportedSyntax
270290
271- meta def elabPWhile (stx : Syntax) : MetaM Expr := do
291+ meta def elabPWhile (stx : Syntax) : TermElabM Expr := do
272292 let expr ← elabStmt stx
273293 let expr ← mkAppM ``Stmt.build #[expr]
274294 return expr
@@ -284,23 +304,48 @@ def example1 : Stmt := [pWhile|
284304 fi;
285305 while x < 4 do
286306 y ?= {0 , 1 , 2 };
287- choose 1 : skip or 1 : stop ro;
307+ choose 0 . 33 : skip or 0 . 67 : stop ro;
288308 x := (x + y)
289309 od;
290310 stop
291311]
292312
313+ #check Rat
314+
293315/--
294316info: [x := 0]¹;
295317if [ true ] ² then [ skip ] ³ else [ skip ] ⁴ fi;
296318while [(x < 4)]⁵ do
297319[y ?= [0, 1, 2]]⁶;
298- [ choose ] ⁷ 1 :[ skip ] ⁸ or 1 :[ skip ] ⁹ ro;
320+ [ choose ] ⁷ 33/100 :[ skip ] ⁸ or 67/100 :[ skip ] ⁹ ro;
299321[x := (x + y)]¹⁰
300322od;
301323[ skip ] ¹¹
302324-/
303325#guard_msgs in
304326#eval IO.println example1.toString
305327
328+ def Stmt.init : Stmt → Label
329+ | .skip l => l
330+ | .stop l => l
331+ | .assign _ _ l => l
332+ | .assign? _ _ l => l
333+ | .seq s1 _ => s1.init
334+ | .choose l _ _ _ _ => l
335+ | .sif _ l _ _ => l
336+ | .swhile _ l _ => l
337+
338+ def Stmt.final : Stmt → List Label
339+ | .skip l => [l]
340+ | .stop l => [l]
341+ | .assign _ _ l => [l]
342+ | .assign? _ _ l => [l]
343+ | .seq _ s2 => s2.final
344+ | .choose _ _ s1 _ s2 => s1.final ++ s2.final
345+ | .sif _ _ s1 s2 => s1.final ++ s2.final
346+ | .swhile _ l _ => [l]
347+
348+ def Stmt.flow : Stmt → List (Label × Prob × Label) := sorry
349+
350+
306351end ProgramAnalysis.ProbabilisticPrograms
0 commit comments