|
1 | 1 | import numpy as np |
2 | 2 |
|
def policy_evaluation(S, A, P, R, pi, gamma=1.0, theta=1e-10):
    """
    Iterative (synchronous) policy evaluation for a tabular MDP with
    general rewards R(s, a, s').

    Parameters
    ----------
    S : sequence of states (length |S|; used only for shape validation)
    A : sequence of actions (length |A|; used only for shape validation)
    P : ndarray [|S|, |A|, |S|] -- transition probabilities P[s' | s, a]
    R : ndarray [|S|, |A|, |S|] -- rewards R(s, a, s')
    pi : ndarray [|S|, |A|] -- row-stochastic policy (one-hot for deterministic)
    gamma : discount factor
    theta : convergence threshold on max |V_new - V|

    Returns
    -------
    V : ndarray of shape [|S|] -- state values of the policy pi

    Raises
    ------
    ValueError
        If the array shapes are mutually inconsistent.
    """
    n_states, n_actions, n_next = P.shape
    # Explicit raises instead of `assert`: asserts are stripped under
    # `python -O`, which would silently skip input validation.
    if n_states != len(S) or n_actions != len(A) or n_next != n_states:
        raise ValueError(
            f"P has shape {P.shape}, expected ({len(S)}, {len(A)}, {len(S)})"
        )
    if pi.shape != (n_states, n_actions):
        raise ValueError(
            f"pi has shape {pi.shape}, expected ({n_states}, {n_actions})"
        )

    V = np.zeros(n_states, dtype=float)
    while True:
        V_new = np.zeros_like(V)
        for s in range(n_states):
            val = 0.0
            for a in range(n_actions):
                p_sa = pi[s, a]
                if p_sa == 0.0:
                    continue  # skip actions the policy never takes
                # Expected one-step return:
                #   sum_s' P(s'|s,a) * (R(s,a,s') + gamma * V(s'))
                val += p_sa * np.sum(P[s, a, :] * (R[s, a, :] + gamma * V))
            V_new[s] = val
        # Sup-norm of the Bellman update; converged when below theta.
        delta = float(np.max(np.abs(V_new - V)))
        V = V_new
        if delta < theta:
            break
    return V
16 | 38 |
|
def q_from_v(S, A, P, R, V, gamma=1.0):
    """
    Build the action-value table Q from a state-value vector V.

    Q(s, a) = sum_s' P(s'|s, a) * (R(s, a, s') + gamma * V(s'))

    Parameters
    ----------
    S, A : state / action sequences (unused here; kept for a uniform API)
    P : ndarray [|S|, |A|, |S|] -- transition probabilities
    R : ndarray [|S|, |A|, |S|] -- rewards R(s, a, s')
    V : ndarray [|S|] -- state values
    gamma : discount factor

    Returns
    -------
    Q : ndarray of shape [|S|, |A|]
    """
    n_states, n_actions, _ = P.shape
    # One Bellman backup per (state, action) pair, collected row by row.
    return np.array(
        [
            [np.sum(P[s, a, :] * (R[s, a, :] + gamma * V)) for a in range(n_actions)]
            for s in range(n_states)
        ],
        dtype=float,
    )
0 commit comments