import time
import numpy as np
class DPWParams:
    """Hyper-parameters of Monte Carlo tree search with double progressive widening (DPW).

    Parameters
    ----------
    d : int
        The maximum search depth of a single simulation.
    gamma : float
        The discount factor applied to returns inside the tree.
    ec : float
        The weight of the exploration bonus in the UCT score.
    n : int
        The maximum number of search iterations per call to ``selectAction``.
    k : float
        Widening coefficient: a new action is explored at a state while the number of
        explored actions there is below ``k * N(s) ** alpha``.
    alpha : float
        Widening exponent used in the same constraint ``k * N(s) ** alpha``.
    clear_nodes : bool
        Whether to prune, before each search, every node that is neither an ancestor nor
        a descendant of the current state. Set it to True to save memory; set it to False
        to keep the full tree (better for tree plotting).
    """

    def __init__(self, d, gamma, ec, n, k, alpha, clear_nodes):
        # Search budget: depth per simulation and number of simulations.
        self.d, self.n = d, n
        # Value estimation and exploration weighting.
        self.gamma, self.ec = gamma, ec
        # Progressive-widening constraint |A(s)| < k * N(s) ** alpha.
        self.k, self.alpha = k, alpha
        # Memory management of the search tree.
        self.clear_nodes = clear_nodes
class DPWModel:
    """Bundle of the simulator model and the two action-selection policies used by the search.

    Parameters
    ----------
    model : :py:class:`ast_toolbox.mcts.MDP.TransitionModel`
        The transition model; the search calls its ``goToState``, ``getNextState``
        and ``isEndState`` methods.
    getAction : function
        ``getAction(s, s_tree)`` returns the action used during rollouts. Within the
        search it is called with the tree's state dictionary as second argument.
    getNextAction : function
        ``getNextAction(s, s_tree)`` returns a candidate action when the tree is
        widened at ``s`` (exploration).
    """

    def __init__(self, model, getAction, getNextAction):
        self.model = model
        # Rollout policy first, then the exploration (widening) policy.
        self.getAction, self.getNextAction = getAction, getNextAction
class StateActionStateNode:
    """Statistics of one observed transition (s, a) -> s'.
    """

    def __init__(self):
        # Visit count of the transition, and the reward stored when it was first observed.
        self.n, self.r = 0, 0.0
class StateActionNode:
    """Statistics of a state-action pair (s, a).
    """

    def __init__(self):
        # Successor states observed after taking the action: {next_state: StateActionStateNode}.
        self.s = dict()
        self.n = 0  # number of simulations backed up through (s, a)
        self.q = 0.0  # running mean of the returns observed after (s, a)
class StateNode:
    """Statistics of a state s stored in the search tree.
    """

    def __init__(self):
        self.a = dict()  # explored actions: {action: StateActionNode}
        self.n = 0  # simulations that passed through s after it was added to the tree
class DPWTree:
    """The search tree together with the parameters and model that drive it.
    """

    def __init__(self, p, f):
        self.p, self.f = p, f  # DPWParams, DPWModel
        self.s_tree = {}  # {state: StateNode}
def saveBackwardState(old_s_tree, new_s_tree, s_current):
    """Copy ``s_current`` and its chain of ancestors from ``old_s_tree`` into ``new_s_tree``.

    Ancestors are found by following each state's ``parent`` attribute until it is None.

    Parameters
    ----------
    old_s_tree : dict
        The old tree, mapping states to their nodes.
    new_s_tree : dict
        The tree being built; it is updated in place.
    s_current : :py:class:`ast_toolbox.mcts.AdaptiveStressTesting.ASTState`
        The current state.

    Returns
    -------
    new_s_tree : dict
        The same dictionary that was passed in. It is left unchanged if ``s_current``
        is not in ``old_s_tree``.

    Raises
    ------
    KeyError
        If an ancestor of ``s_current`` is missing from ``old_s_tree``.
    """
    # Start the walk only when the current state is part of the old tree.
    cursor = s_current if s_current in old_s_tree else None
    while cursor is not None:
        new_s_tree[cursor] = old_s_tree[cursor]
        cursor = cursor.parent
    return new_s_tree
def saveForwardState(old_s_tree, new_s_tree, s):
    """Copy ``s`` and all its successors from ``old_s_tree`` into ``new_s_tree``.

    Successors are the next states recorded under every explored action
    (``old_s_tree[x].a[action].s``). Successor states that are not themselves keys of
    ``old_s_tree`` are skipped, together with everything below them.

    The walk is an iterative depth-first pre-order traversal with a visited set. It
    inserts keys in the same order as a recursive traversal would, but it does not hit
    the interpreter recursion limit on deep trees. It also terminates if the state graph
    contains cycles or shared successors.

    Parameters
    ----------
    old_s_tree : dict
        The old tree, mapping states to their nodes.
    new_s_tree : dict
        The tree being built; it is updated in place.
    s : :py:class:`ast_toolbox.mcts.AdaptiveStressTesting.ASTState`
        The state whose subtree is kept.

    Returns
    -------
    new_s_tree : dict
        The same dictionary that was passed in. It is left unchanged if ``s`` is not
        in ``old_s_tree``.
    """
    # Visited is tracked locally rather than via new_s_tree membership: callers may have
    # already inserted ``s`` (e.g. during a backward pass) and its subtree must still be copied.
    visited = set()
    stack = [s]
    while stack:
        state = stack.pop()
        if state in visited or state not in old_s_tree:
            continue
        visited.add(state)
        node = old_s_tree[state]
        new_s_tree[state] = node
        children = [s1 for sa in node.a.values() for s1 in sa.s]
        # Push in reverse so children are processed in their original order (pre-order).
        stack.extend(reversed(children))
    return new_s_tree
def saveState(old_s_tree, s):
    """Build a pruned tree that holds ``s``, its ancestors and its successors.

    Parameters
    ----------
    old_s_tree : dict
        The old tree, mapping states to their nodes.
    s : :py:class:`ast_toolbox.mcts.AdaptiveStressTesting.ASTState`
        The current state.

    Returns
    -------
    new_s_tree : dict
        A new dictionary; nodes are shared with (not copied from) ``old_s_tree``.
    """
    # Ancestors first (starting at s), then the subtree rooted at s, into the same dict.
    kept = saveBackwardState(old_s_tree, {}, s)
    return saveForwardState(old_s_tree, kept, s)
def selectAction(tree, s, verbose=False):
    """Run MCTS from state ``s`` and return the explored action with the highest Q value.

    The search runs ``tree.p.n`` simulations of depth ``tree.p.d``. The simulator is
    reset to ``s`` before each simulation and once more before returning.

    Parameters
    ----------
    tree : :py:class:`ast_toolbox.mcts.MCTSdpw.DPWTree`
        The search tree.
    s : :py:class:`ast_toolbox.mcts.AdaptiveStressTesting.ASTState`
        The current state.
    verbose : bool, optional
        Whether to log search information; forwarded to ``simulate``.

    Returns
    -------
    action : :py:class:`ast_toolbox.mcts.AdaptiveStressTesting.ASTAction`
        The explored action of ``s`` with maximal Q value (first one on ties).

    Raises
    ------
    KeyError
        If ``s`` never entered the tree, e.g. when ``tree.p.n`` is 0 on a fresh tree
        or ``s`` is an end state.
    AssertionError
        If no action has been explored from ``s``.
    """
    if tree.p.clear_nodes:
        # Keep only the ancestors and descendants of s. The old dict is cleared in place
        # so that any other holder of it also releases the pruned nodes.
        new_dict = saveState(tree.s_tree, s)
        tree.s_tree.clear()
        tree.s_tree = new_dict
    depth = tree.p.d
    for _ in range(tree.p.n):
        # Reset the simulator to s before every simulation.
        tree.f.model.goToState(s)
        simulate(tree, s, depth, verbose=verbose)
    # Leave the simulator positioned at s for the caller.
    tree.f.model.goToState(s)
    state_node = tree.s_tree[s]
    explored_actions = list(state_node.a.keys())
    assert len(explored_actions) != 0
    q_values = np.array([state_node.a[action].q for action in explored_actions], dtype=float)
    # np.argmax returns the first index on ties, matching insertion order of actions.
    return explored_actions[int(np.argmax(q_values))]
def _select_uct_action(state_node, ec):
    """Return the explored action of ``state_node`` with maximal UCT score (first on ties)."""
    explored_actions = list(state_node.a.keys())
    n_s = state_node.n
    assert n_s > 0
    uct = np.zeros(len(explored_actions))
    for i, action in enumerate(explored_actions):
        sa_node = state_node.a[action]
        assert sa_node.n > 0
        # Exploitation term plus ec-weighted exploration bonus sqrt(log N(s) / N(s, a)).
        uct[i] = sa_node.q + ec * np.sqrt(np.log(n_s) / float(sa_node.n))
    return explored_actions[int(np.argmax(uct))]


def simulate(tree, s, depth, verbose=False):
    """Run a single forward MCTS simulation from ``s`` and back up the result.

    Behaviour per visited state:

    - First visit: the state is added to the tree, and a rollout estimates its value.
    - Later visits: an action is chosen by progressive widening or by UCT. The
      transition and action statistics are then updated with the discounted return.

    Parameters
    ----------
    tree : :py:class:`ast_toolbox.mcts.MCTSdpw.DPWTree`
        The search tree.
    s : :py:class:`ast_toolbox.mcts.AdaptiveStressTesting.ASTState`
        The current state.
    depth : int
        The remaining search depth.
    verbose : bool, optional
        Whether to log search information; forwarded to the recursive calls.

    Returns
    -------
    q : float
        The estimated return from ``s``.
    """
    # Short-circuit: the end-state check is skipped once the depth budget is exhausted.
    if depth == 0 or tree.f.model.isEndState(s):
        return 0.0
    if s not in tree.s_tree:
        # First visit: add the state and estimate its value with the rollout policy.
        tree.s_tree[s] = StateNode()
        return rollout(tree, s, depth)
    state_node = tree.s_tree[s]
    state_node.n += 1
    if len(state_node.a) < tree.p.k * state_node.n ** tree.p.alpha:
        # Progressive widening: ask the exploration policy for a (possibly new) action.
        a = tree.f.getNextAction(s, tree.s_tree)
        if a not in state_node.a:
            state_node.a[a] = StateActionNode()
    else:
        # Widening limit reached: pick among the explored actions by UCT.
        a = _select_uct_action(state_node, tree.p.ec)
    state_action_node = state_node.a[a]
    sp, r = tree.f.model.getNextState(s, a)
    if sp not in state_action_node.s:
        transition = StateActionStateNode()
        transition.r = r
        transition.n = 1
        state_action_node.s[sp] = transition
    else:
        state_action_node.s[sp].n += 1
    q = r + tree.p.gamma * simulate(tree, sp, depth - 1, verbose=verbose)
    # Incremental running-mean update of Q(s, a).
    state_action_node.n += 1
    state_action_node.q += (q - state_action_node.q) / float(state_action_node.n)
    return q
def rollout(tree, s, depth):
    """Estimate the return from ``s`` by following the rollout policy ``tree.f.getAction``.

    The walk stops after ``depth`` steps or when the model reports an end state. The
    rewards collected along the way are summed.

    Parameters
    ----------
    tree : :py:class:`ast_toolbox.mcts.MCTSdpw.DPWTree`
        The search tree.
    s : :py:class:`ast_toolbox.mcts.AdaptiveStressTesting.ASTState`
        The current state.
    depth : int
        The maximum number of rollout steps.

    Returns
    -------
    q : float
        The estimated return.
    """
    # NOTE(review): unlike simulate, rewards here are not discounted by tree.p.gamma
    # (equivalent to gamma == 1 during rollouts) — confirm this is intended.
    rewards = []
    state, remaining = s, depth
    while not ((remaining == 0) | tree.f.model.isEndState(state)):
        action = tree.f.getAction(state, tree.s_tree)
        state, reward = tree.f.model.getNextState(state, action)
        rewards.append(reward)
        remaining -= 1
    # Sum from the last step backwards, r_0 + (r_1 + (... + 0.0)), to keep the exact
    # float rounding of a right-nested accumulation.
    total = 0.0
    for reward in reversed(rewards):
        total = reward + total
    return total