Skip to content

Commit 2e22d70

Browse files
committed
Refactor Stmt.choose and mkChoose to use Prob type for probabilities, update elaboration logic for choose syntax, and add validation for probability constraints.
1 parent 9a1392f commit 2e22d70

1 file changed

Lines changed: 56 additions & 11 deletions

File tree

ProgramAnalysis/ProbabilisticPrograms/PWhile.lean

Lines changed: 56 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,14 @@ public import Lean
33

44
namespace ProgramAnalysis.ProbabilisticPrograms
55

6+
public abbrev Prob := { p : Rat // 0 ≤ p ∧ p ≤ 1 }
7+
8+
public instance : Ord Rat where
9+
compare r1 r2 :=
10+
if r1 < r2 then .lt
11+
else if r1 = r2 then .eq
12+
else .gt
13+
614
public abbrev Var := String
715

816
public abbrev Label := Nat
@@ -89,7 +97,7 @@ public inductive Stmt
8997
| assign : Var → ArithAtom → Label → Stmt
9098
| assign? : Var → List ArithAtom → Label → Stmt
9199
| seq : Stmt → Stmt → Stmt
92-
| choose : Label → Nat → Stmt → Nat → Stmt → Stmt
100+
| choose : Label → Prob → Stmt → Prob → Stmt → Stmt
93101
| sif : BoolAtom → Label → Stmt → Stmt → Stmt
94102
| swhile : BoolAtom → Label → Stmt → Stmt
95103
deriving Repr, Ord, DecidableEq
@@ -126,7 +134,7 @@ public def Stmt.mkAssign? (x : Var) (as : List ArithAtom) : StateM Label Stmt :=
126134
public def Stmt.mkSeq (s1 s2 : StateM Label Stmt) : StateM Label Stmt := do
127135
return Stmt.seq (← s1) (← s2)
128136

129-
public def Stmt.mkChoose (p1 : Nat) (s1 : StateM Label Stmt) (p2 : Nat) (s2 : StateM Label Stmt) : StateM Label Stmt := do
137+
public def Stmt.mkChoose (p1 : Prob) (s1 : StateM Label Stmt) (p2 : Prob) (s2 : StateM Label Stmt) : StateM Label Stmt := do
130138
return Stmt.choose (← freshLabel) p1 (← s1) p2 (← s2)
131139

132140
public def Stmt.mkIf (b : BoolAtom) (thn els : StateM Label Stmt) : StateM Label Stmt := do
@@ -179,7 +187,7 @@ scoped syntax "skip" : pwhile_stmt
179187
scoped syntax ident ":=" pwhile_arith_atom : pwhile_stmt
180188
scoped syntax ident "?=" "{" pwhile_arith_atom,+ "}" : pwhile_stmt
181189
scoped syntax pwhile_stmt ";" pwhile_stmt : pwhile_stmt
182-
scoped syntax "choose" num ":" pwhile_stmt "or" num ":" pwhile_stmt "ro": pwhile_stmt
190+
scoped syntax "choose" scientific ":" pwhile_stmt "or" scientific ":" pwhile_stmt "ro": pwhile_stmt
183191
scoped syntax "if" pwhile_bool_atom "then" pwhile_stmt "else" pwhile_stmt "fi" : pwhile_stmt
184192
scoped syntax "while" pwhile_bool_atom "do" pwhile_stmt "od" : pwhile_stmt
185193
scoped syntax "(" pwhile_stmt ")" : pwhile_stmt
@@ -205,7 +213,7 @@ meta def elabOpr : Syntax → MetaM Expr
205213
| `(pwhile_op_r| >=) => return .const ``Op_r.ge []
206214
| _ => throwUnsupportedSyntax
207215

208-
meta partial def elabArithAtom : Syntax → MetaM Expr
216+
meta partial def elabArithAtom : Syntax → TermElabM Expr
209217
| `(pwhile_arith_atom| $x:ident) => mkAppM ``ArithAtom.var #[mkStrLit x.getId.toString]
210218
| `(pwhile_arith_atom| $n:num) => mkAppM ``ArithAtom.const #[mkIntLit n.getNat]
211219
| `(pwhile_arith_atom| -$n:num) => mkAppM ``ArithAtom.const #[mkIntLit (n.getNat * -1)]
@@ -217,7 +225,7 @@ meta partial def elabArithAtom : Syntax → MetaM Expr
217225
| `(pwhile_arith_atom| ($a:pwhile_arith_atom)) => elabArithAtom a
218226
| _ => throwUnsupportedSyntax
219227

220-
meta partial def elabBoolAtom : Syntax → MetaM Expr
228+
meta partial def elabBoolAtom : Syntax → TermElabM Expr
221229
| `(pwhile_bool_atom| true) => return .const ``BoolAtom.btrue []
222230
| `(pwhile_bool_atom| false) => return .const ``BoolAtom.bfalse []
223231
| `(pwhile_bool_atom| not $b:pwhile_bool_atom) => do
@@ -236,7 +244,7 @@ meta partial def elabBoolAtom : Syntax → MetaM Expr
236244
| `(pwhile_bool_atom| ($b:pwhile_bool_atom)) => elabBoolAtom b
237245
| _ => throwUnsupportedSyntax
238246

239-
meta partial def elabStmt : Syntax → MetaM Expr
247+
meta partial def elabStmt : Syntax → TermElabM Expr
240248
| `(pwhile_stmt| stop) => do
241249
mkAppM ``Stmt.mkSkip #[]
242250
| `(pwhile_stmt| skip) => do
@@ -252,10 +260,22 @@ meta partial def elabStmt : Syntax → MetaM Expr
252260
let s1Expr ← elabStmt s1
253261
let s2Expr ← elabStmt s2
254262
mkAppM ``Stmt.mkSeq #[s1Expr, s2Expr]
255-
| `(pwhile_stmt| choose $p1:num : $s1:pwhile_stmt or $p2:num : $s2:pwhile_stmt ro) => do
263+
| stx@`(pwhile_stmt| choose $p1:scientific : $s1:pwhile_stmt or $p2:scientific : $s2:pwhile_stmt ro) => do
256264
let s1 ← elabStmt s1
257265
let s2 ← elabStmt s2
258-
mkAppM ``Stmt.mkChoose #[mkNatLit p1.getNat, s1, mkNatLit p2.getNat, s2]
266+
let (m, s, e) := p1.getScientific
267+
let r1 := Rat.ofScientific m s e
268+
let (m, s, e) := p2.getScientific
269+
let r2 := Rat.ofScientific m s e
270+
if r1 < 0 || r1 > 1 then
271+
throwErrorAt p1 s!"Probability must be between 0 and 1, got {r1}"
272+
if r2 < 0 || r2 > 1 then
273+
throwErrorAt p2 s!"Probability must be between 0 and 1, got {r2}"
274+
if r1 + r2 != 1 then
275+
throwErrorAt stx s!"Probabilities must sum to 1, but got {r1} + {r2} = {r1 + r2}"
276+
let p1Expr ← Term.elabTerm (← `(⟨($p1:scientific : Rat), by grind⟩)) (mkConst ``Prob)
277+
let p2Expr ← Term.elabTerm (← `(⟨($p2:scientific : Rat), by grind⟩)) (mkConst ``Prob)
278+
mkAppM ``Stmt.mkChoose #[p1Expr, s1, p2Expr, s2]
259279
| `(pwhile_stmt| if $b:pwhile_bool_atom then $s1:pwhile_stmt else $s2:pwhile_stmt fi) => do
260280
let bExpr ← elabBoolAtom b
261281
let s1Expr ← elabStmt s1
@@ -268,7 +288,7 @@ meta partial def elabStmt : Syntax → MetaM Expr
268288
| `(pwhile_stmt| ($s:pwhile_stmt)) => elabStmt s
269289
| _ => throwUnsupportedSyntax
270290

271-
meta def elabPWhile (stx : Syntax) : MetaM Expr := do
291+
meta def elabPWhile (stx : Syntax) : TermElabM Expr := do
272292
let expr ← elabStmt stx
273293
let expr ← mkAppM ``Stmt.build #[expr]
274294
return expr
@@ -284,23 +304,48 @@ def example1 : Stmt := [pWhile|
284304
fi;
285305
while x < 4 do
286306
y ?= {0, 1, 2};
287-
choose 1 : skip or 1 : stop ro;
307+
choose 0.33 : skip or 0.67 : stop ro;
288308
x := (x + y)
289309
od;
290310
stop
291311
]
292312

313+
#check Rat
314+
293315
/--
294316
info: [x := 0]¹;
295317
if [true]² then [skip]³ else [skip]⁴ fi;
296318
while [(x < 4)]⁵ do
297319
[y ?= [0, 1, 2]]⁶;
298-
[choose]1:[skip]⁸ or 1:[skip]⁹ ro;
320+
[choose]33/100:[skip]⁸ or 67/100:[skip]⁹ ro;
299321
[x := (x + y)]¹⁰
300322
od;
301323
[skip]¹¹
302324
-/
303325
#guard_msgs in
304326
#eval IO.println example1.toString
305327

328+
def Stmt.init : Stmt → Label
329+
| .skip l => l
330+
| .stop l => l
331+
| .assign _ _ l => l
332+
| .assign? _ _ l => l
333+
| .seq s1 _ => s1.init
334+
| .choose l _ _ _ _ => l
335+
| .sif _ l _ _ => l
336+
| .swhile _ l _ => l
337+
338+
def Stmt.final : Stmt → List Label
339+
| .skip l => [l]
340+
| .stop l => [l]
341+
| .assign _ _ l => [l]
342+
| .assign? _ _ l => [l]
343+
| .seq _ s2 => s2.final
344+
| .choose _ _ s1 _ s2 => s1.final ++ s2.final
345+
| .sif _ _ s1 s2 => s1.final ++ s2.final
346+
| .swhile _ l _ => [l]
347+
348+
def Stmt.flow : Stmt → List (Label × Prob × Label) := sorry
349+
350+
306351
end ProgramAnalysis.ProbabilisticPrograms

0 commit comments

Comments
 (0)