Module modules.algo.dp
Expand source code
from typing import Generator
import numpy as np
class PolicyEvaluation:
"""
Implement the policy evaluation algorithm using q-tables for epsilon-greedy policies in simple gridworlds.
"""
def __init__(self, env, policy, discount_factor, truncate_pe=False, pe_tol=1e-3):
"""
Args:
env: any environment available in `classic_rl.env`
policy: An instance of DeterministicPolicy or EpsilonGreedyPolicy from `classic_rl.policy`. If the policy is
an instance of DeterministicPolicy, then you must make sure that the terminal state can be reached from
any state; otherwise policy evaluation won't converge.
discount_factor: a float in the interval [0, 1]
truncate_pe: whether to truncate policy evaluation (run a single sweep per call); if True, pe_tol must be None
pe_tol: precision tolerance for policy evaluation; must be None when truncate_pe is True
"""
self.env = env
self.policy = policy
self.truncate_pe = truncate_pe
self.discount_factor = discount_factor
if truncate_pe:
assert pe_tol is None
else:
assert pe_tol is not None
self.pe_tol = pe_tol
self.init_q()
def init_q(self) -> None:
"""
Helper method to `self.__init__`.
Initialize a q-table of all zeros.
"""
self.q = np.zeros(self.env.action_space_shape)
def backup_v(self, s) -> float:
"""
Helper method to `self.backup_q`.
$$V_{\\pi}(s) = \\sum_{a} \\pi(a \\mid s) Q_{\\pi}(s, a)$$
"""
return np.sum([self.policy.calc_b_a_given_s(a, s) * self.q[s][a] for a in self.env.get_actions(s)])
def backup_q(self, s, a) -> float:
"""
Helper method to `self.do_policy_evaluation`.
The iterative update rule for `self.q` using the Bellman equation for action-value functions.
$$Q_{\\pi}(s, a) \\leftarrow \\sum_{s^{\\prime}, r} p(s^{\\prime}, r \\mid s, a) \\left[ r + \\gamma V_{\\pi}(s^{\\prime}) \\right]$$
"""
self.env.reset()
self.env.current_coord = s
s_prime, r = self.env.step(a)
return r + self.discount_factor * self.backup_v(s_prime)
def loop_state_action_qval(self) -> Generator:
"""
Helper method to `self.do_policy_evaluation`. Iterate over the q-table with `np.ndenumerate`, yielding `(s, a, old_q)` for every actionable state.
"""
for sa, old_q in np.ndenumerate(self.q):
s, a = (sa[0], sa[1]), sa[-1]
if self.env.is_actionable(s):
yield s, a, old_q
def do_policy_evaluation(self) -> None:
"""
Helper method to self.run.
## Math behind the scene
See the docstrings for `self.backup_q` and `self.backup_v` for more details.
"""
# ========== evaluate policy ==========
while True:
error = 0
for s, a, old_q in self.loop_state_action_qval():
self.q[s][a] = self.backup_q(s, a)
error = np.max([error, np.abs(old_q - self.q[s][a])])
if self.truncate_pe:
return
if error < self.pe_tol:
return
def run(self) -> None:
"""
Run the policy evaluation algorithm.
"""
print(f"==========")
print(f"Running DP policy evaluation ...")
self.do_policy_evaluation()
print("Result: Convergence reached.")
print(f"==========")
class PolicyIteration(PolicyEvaluation):
"""
Implement the policy iteration algorithm using q-tables for epsilon-greedy policies in simple gridworlds.
"""
def __init__(self, conv_tol, **kwargs):
"""
Args:
conv_tol: precision tolerance for convergence
"""
super().__init__(**kwargs)
self.conv_tol = conv_tol
def init_q(self) -> None:
"""
Helper method to `self.__init__`.
Initialize the q-table by copying self.policy's q-table (a true copy, not a reference to the same array).
"""
self.q = self.policy.q.copy()
def check_q_convergence(self) -> bool:
"""
Helper method to self.run.
Evaluate whether the q-table has converged to the optimal q-table using the Bellman optimality equations for action-value functions (BOEs).
## Math behind the scene
For epsilon-greedy policies, the BOEs are intuitively defined as (definition 1 and 2):
$$V_{\\ast}(s) \\triangleq (1 - \\epsilon) \\max_a Q_{\\ast}(s, a) + \\frac{\\epsilon}{\\mid A \\mid} \\sum_a Q_{\\ast}(s, a)$$
$$Q_{\\ast}(s, a) \\triangleq \\sum_{s^{\\prime}, r} p(s^{\\prime}, r \\mid s, a) \\left[ r + \\gamma V_{\\ast}(s^{\\prime}) \\right]$$
For more information on optimal value functions, read section 3.6 of Sutton & Barto 2018.
In this implementation, we evaluate whether the LHS and the RHS of definition 2 are equal. If the maximum difference between
the LHS and the RHS over all state-action pairs is smaller than `self.conv_tol`, then this function returns `True`; otherwise,
this function returns `False`.
For a state-action pair \\((s, a)\\):
* The LHS can be directly obtained using `self.q[s][a]`.
* The RHS can be calculated using `self.backup_q(s, a)`, which uses `self.backup_v(s)` as a helper. This is because
this implementation relies on a q-table only, and the RHS can only be calculated by substituting definition 1 into definition 2.
Note that, before the convergence check, we must temporarily set the policy to be greedy with respect to `self.q` due to the
\\(\\max\\) operation in definition 1. After the convergence check, we must reset the policy. See code for more details.
Returns:
bool: whether convergence is reached
"""
# step 1: update policy (temporarily)
policy_old_q = self.policy.q.copy()
self.policy.q = self.q.copy() # self.policy knows how to act greedily with respect to any q-table
# step 2: calculate the optimality error (the difference between the LHS and the RHS of definition 2)
optimality_error = 0
for s, a, old_q in self.loop_state_action_qval():
lhs = self.q[s][a]
rhs = self.backup_q(s, a)
optimality_error = np.max([optimality_error, np.abs(lhs - rhs)])
# step 3: reset policy
self.policy.q = policy_old_q.copy()
# step 4: check whether the optimality error falls below a tolerance level
if optimality_error < self.conv_tol:
return True
else:
return False
def do_policy_improvement(self) -> None:
"""
Helper method to self.run. Improve the policy by copying `self.q` into `self.policy.q`.
"""
self.policy.q = self.q.copy()
def run(self, max_iterations, which_tqdm) -> None:
"""
Run the policy iteration algorithm.
Args:
max_iterations: the maximum number of iterations before the algorithm is halted
which_tqdm: "terminal" or "notebook", depending on whether you are running code in a terminal or a Jupyter notebook
"""
assert max_iterations >= 1
assert which_tqdm in ['terminal', 'notebook']
if which_tqdm == 'terminal':
from tqdm import tqdm
elif which_tqdm == 'notebook':
from tqdm.notebook import tqdm
print(f"Running DP policy iteration for at most {max_iterations} iterations ...")
for i in tqdm(range(1, max_iterations+1)):
self.do_policy_evaluation()
converged = self.check_q_convergence()
self.do_policy_improvement()
if converged:
print(f'Result: Convergence reached at iteration {i}')
return
print(f"Result: Convergence not reached after {i} iterations.")
Classes
class PolicyEvaluation (env, policy, discount_factor, truncate_pe=False, pe_tol=0.001)
-
Implement the policy evaluation algorithm using q-tables for epsilon-greedy policies in simple gridworlds.
Args
env
- any environment available in `classic_rl.env`
policy
- an instance of DeterministicPolicy or EpsilonGreedyPolicy from `classic_rl.policy`. If the policy is an instance of DeterministicPolicy, you must make sure that the terminal state can be reached from any state; otherwise policy evaluation won't converge.
discount_factor
- a float in the interval [0, 1]
truncate_pe
- whether to truncate policy evaluation (run a single sweep per call); if True, pe_tol must be None
pe_tol
- precision tolerance for policy evaluation; must be None when truncate_pe is True
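A minimal usage sketch. The environment and policy constructors in `classic_rl` are not shown in this module, so the class names and arguments below are assumptions, not the package's actual API:

from classic_rl.env import GridWorld                 # placeholder name; use any environment in classic_rl.env
from classic_rl.policy import EpsilonGreedyPolicy    # assumed policy class from the docstring above

env = GridWorld()                                    # assumed no-argument constructor
policy = EpsilonGreedyPolicy(epsilon=0.1)            # assumed constructor signature

pe = PolicyEvaluation(
    env=env,
    policy=policy,
    discount_factor=0.9,    # gamma in [0, 1]
    truncate_pe=False,      # sweep until the largest q-table change drops below pe_tol
    pe_tol=1e-3,
)
pe.run()                    # prints progress; the evaluated q-table is left in pe.q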
Subclasses
PolicyIteration
Methods
def backup_q(self, s, a) ‑> float
-
Helper method to `self.do_policy_evaluation`. The iterative update rule for `self.q` using the Bellman equation for action-value functions:
$$Q_{\pi}(s, a) \leftarrow \sum_{s^{\prime}, r} p(s^{\prime}, r \mid s, a) \left[ r + \gamma V_{\pi}(s^{\prime}) \right]$$
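Because the gridworlds targeted here are deterministic, a single call to env.step(a) from state s realizes the only possible (s', r) pair, so the expectation in the update collapses to one term (a reading inferred from the code, not a statement about general stochastic MDPs):
$$Q_{\pi}(s, a) \leftarrow r(s, a) + \gamma \, V_{\pi}\bigl(s^{\prime}(s, a)\bigr)$$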
def backup_v(self, s) ‑> float
-
Helper method to `self.backup_q`.
$$V_{\pi}(s) = \sum_{a} \pi(a \mid s) Q_{\pi}(s, a)$$
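If `policy.calc_b_a_given_s(a, s)` implements the standard epsilon-greedy probabilities (an assumption about `classic_rl.policy`, which is not shown here), the weights in the sum above are:
$$\pi(a \mid s) = \frac{\epsilon}{\mid A(s) \mid} + (1 - \epsilon)\,\mathbb{1}\left[ a = \arg\max_{a^{\prime}} Q(s, a^{\prime}) \right]$$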
def do_policy_evaluation(self) ‑> NoneType
-
Helper method to self.run.
Math behind the scene
See the docstrings for `self.backup_q` and `self.backup_v` for more details.
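Unless truncate_pe is set (in which case a single sweep is performed per call), sweeps repeat until the largest change of any q-value within a sweep falls below pe_tol:
$$\Delta = \max_{s, a} \left| Q_{k+1}(s, a) - Q_{k}(s, a) \right| < \text{pe\_tol}$$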
def init_q(self) ‑> NoneType
-
Helper method to `self.__init__`. Initialize a q-table of all zeros.
def loop_state_action_qval(self) ‑> Generator
-
Helper method to `self.do_policy_evaluation`. Iterate over the q-table with np.ndenumerate, yielding (s, a, old_q) for every actionable state.
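A small indexing sketch: the q-table is a 3-D array indexed as (row, col, action) (inferred from how s and a are unpacked), so np.ndenumerate yields ((row, col, action), value). The 4x4 grid with 4 actions below is only an illustration:

import numpy as np

q = np.zeros((4, 4, 4))               # illustrative shape only: a 4x4 grid with 4 actions
for sa, old_q in np.ndenumerate(q):
    s, a = (sa[0], sa[1]), sa[-1]     # s is the grid coordinate, a is the action index
    print(s, a, old_q)                # first pair printed: (0, 0) 0 0.0
    break                             # the real generator also filters with env.is_actionable(s)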
def run(self) ‑> NoneType
-
Run the policy evaluation algorithm.
class PolicyIteration (conv_tol, **kwargs)
-
Implement the policy iteration algorithm using q-tables for epsilon-greedy policies in simple gridworlds.
Args
conv_tol
- precision tolerance for convergence
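A construction sketch, reusing the (assumed) env and policy objects from the PolicyEvaluation example above; everything other than conv_tol is forwarded to PolicyEvaluation.__init__. See the run example after the run method below for how to start the iteration:

pi = PolicyIteration(
    conv_tol=1e-3,          # Bellman-optimality tolerance used by check_q_convergence
    env=env,
    policy=policy,
    discount_factor=0.9,
    truncate_pe=True,       # truncated (single-sweep) evaluation per iteration
    pe_tol=None,            # must be None when truncate_pe is True
)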
Ancestors
PolicyEvaluation
Methods
def check_q_convergence(self) ‑> bool
-
Helper method to self.run. Evaluate whether the q-table has converged to the optimal q-table using the Bellman optimality equations for action-value functions (BOEs).
Math behind the scene
For epsilon-greedy policies, the BOEs are intuitively defined as (definitions 1 and 2):
$$V_{\ast}(s) \triangleq (1 - \epsilon) \max_a Q_{\ast}(s, a) + \frac{\epsilon}{\mid A \mid} \sum_a Q_{\ast}(s, a)$$
$$Q_{\ast}(s, a) \triangleq \sum_{s^{\prime}, r} p(s^{\prime}, r \mid s, a) \left[ r + \gamma V_{\ast}(s^{\prime}) \right]$$
For more information on optimal value functions, read section 3.6 of Sutton & Barto 2018.
In this implementation, we evaluate whether the LHS and the RHS of definition 2 are equal. If the maximum difference between the LHS and the RHS over all state-action pairs is smaller than `self.conv_tol`, this function returns `True`; otherwise it returns `False`.
For a state-action pair \((s, a)\):
- The LHS can be obtained directly from `self.q[s][a]`.
- The RHS can be calculated using `self.backup_q(s, a)`, which uses `self.backup_v(s)` as a helper. This is because this implementation relies on a q-table only, and the RHS can only be calculated by substituting definition 1 into definition 2. Note that, before the convergence check, we must temporarily set the policy to be greedy with respect to `self.q` due to the \(\max\) operation in definition 1; after the check, we must reset the policy. See code for more details.
Returns
bool
- whether convergence is reached
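Substituting definition 1 into definition 2, and using the deterministic-transition simplification noted for backup_q, the quantity compared against self.q[s][a] is:
$$Q_{\ast}(s, a) = r(s, a) + \gamma \left[ (1 - \epsilon) \max_{a^{\prime}} Q_{\ast}(s^{\prime}, a^{\prime}) + \frac{\epsilon}{\mid A \mid} \sum_{a^{\prime}} Q_{\ast}(s^{\prime}, a^{\prime}) \right]$$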
def do_policy_improvement(self) ‑> NoneType
-
Helper method to self.run. Improve the policy by copying `self.q` into `self.policy.q`.
def init_q(self) ‑> NoneType
-
Helper method to `self.__init__`. Initialize the q-table by copying self.policy's q-table (a true copy, not a reference to the same array).
def run(self, max_iterations, which_tqdm) ‑> NoneType
-
Run the policy iteration algorithm.
Args
max_iterations
- the maximum number of iterations before the algorithm is halted
which_tqdm
- "terminal" or "notebook", depending on whether you are running code in a terminal or a jupyter notebook
Inherited members
PolicyEvaluation: backup_q, backup_v, do_policy_evaluation, loop_state_action_qval