"""ast_toolbox.mcts.AST_MCTS: MCTS (DPW) stress-testing entry points."""

import ast_toolbox.mcts.MCTSdpw as MCTSdpw
import ast_toolbox.mcts.MDP as MDP


def rollout_getAction(ast):
    """Build the rollout (default) policy for the DPW model.

    The returned callable disregards its arguments and asks ``ast`` for a
    fresh random action each time it is invoked.

    Parameters
    ----------
    ast : :py:class:`ast_toolbox.mcts.AdaptiveStressTest.AdaptiveStressTesting`
        The AST object that supplies random actions.

    Returns
    -------
    callable
        ``rollout_policy(s, tree)``, which returns ``ast.random_action()``.
    """
    def rollout_policy(s, tree):
        # ``s`` and ``tree`` belong to the policy signature but are not used.
        action = ast.random_action()
        return action

    return rollout_policy
def explore_getAction(ast):
    """Build the exploration policy for the DPW model.

    The returned callable forwards the current state and search tree to
    ``ast.explore_action`` and returns whatever it produces.

    Parameters
    ----------
    ast : :py:class:`ast_toolbox.mcts.AdaptiveStressTest.AdaptiveStressTesting`
        The AST object that supplies exploration actions.

    Returns
    -------
    callable
        ``explore_policy(s, tree)``, which returns ``ast.explore_action(s, tree)``.
    """
    def explore_policy(s, tree):
        chosen = ast.explore_action(s, tree)
        return chosen

    return explore_policy
def stress_test(ast, mcts_params, top_paths, verbose=True, return_tree=False):
    """Run stress test with mode 1 (search with single tree).

    Builds a DPW model from ``ast`` and runs ``MDP.simulate`` with
    ``MCTSdpw.selectAction`` as the action-selection routine.

    Parameters
    ----------
    ast : :py:class:`ast_toolbox.mcts.AdaptiveStressTest.AdaptiveStressTesting`
        The AST object. It supplies ``transition_model``, ``random_action``,
        ``explore_action`` and the ``top_paths`` queue that is returned.
    mcts_params : :py:class:`ast_toolbox.mcts.MCTSdpw.DPWParams`
        The MCTS parameters.
    top_paths : :py:class:`ast_toolbox.mcts.BoundedPriorityQueues`
        Not used by this function and kept only for interface compatibility.
        The returned queue is ``ast.top_paths``, not this argument.
    verbose : bool, optional
        Whether to log test information. Forwarded to ``MDP.simulate``.
    return_tree : bool, optional
        Whether to also return the search tree.

    Returns
    -------
    results : :py:class:`ast_toolbox.mcts.BoundedPriorityQueues`
        The bounded priority queue ``ast.top_paths`` storing the
        top-rewarded trajectories.
    tree : dict
        The ``s_tree`` attribute of the DPW search tree. It is returned only
        when ``return_tree`` is true.
    """
    dpw_model = MCTSdpw.DPWModel(
        ast.transition_model, rollout_getAction(ast), explore_getAction(ast)
    )
    tree = MCTSdpw.DPWTree(mcts_params, dpw_model)
    # The (reward, action sequence) pair returned by simulate is not needed.
    # The result is read from ``ast.top_paths`` afterwards (presumably
    # populated during the search).
    MDP.simulate(tree.f.model, tree, MCTSdpw.selectAction, verbose=verbose)
    results = ast.top_paths
    if return_tree:
        return results, tree.s_tree
    return results
def stress_test2(ast, mcts_params, top_paths, verbose=True, return_tree=False):
    """Run stress test with mode 2 (search with multiple trees).

    Runs a single ``MCTSdpw.selectAction`` search from the model's initial
    state. The search uses a copy of ``mcts_params`` with node clearing
    disabled and the budget ``n`` multiplied by ``ast.params.max_steps``.
    The caller's ``mcts_params`` object is left unmodified.

    Parameters
    ----------
    ast : :py:class:`ast_toolbox.mcts.AdaptiveStressTest.AdaptiveStressTesting`
        The AST object. It supplies ``transition_model``, ``random_action``,
        ``explore_action``, ``params.max_steps`` and the ``top_paths`` queue
        that is returned.
    mcts_params : :py:class:`ast_toolbox.mcts.MCTSdpw.DPWParams`
        The MCTS parameters. A shallow copy is adjusted; this object is not
        mutated.
    top_paths : :py:class:`ast_toolbox.mcts.BoundedPriorityQueues`
        Not used by this function and kept only for interface compatibility.
        The returned queue is ``ast.top_paths``, not this argument.
    verbose : bool, optional
        Whether to log test information. Forwarded to ``MCTSdpw.selectAction``.
    return_tree : bool, optional
        Whether to also return the search tree.

    Returns
    -------
    results : :py:class:`ast_toolbox.mcts.BoundedPriorityQueues`
        The bounded priority queue ``ast.top_paths`` storing the
        top-rewarded trajectories.
    tree : dict
        The ``s_tree`` attribute of the DPW search tree. It is returned only
        when ``return_tree`` is true.
    """
    import copy

    # Adjust a shallow copy rather than the caller's object. Scaling ``n`` in
    # place made repeated calls with the same params compound the budget.
    params = copy.copy(mcts_params)
    # Disable node clearing and scale the simulation budget by
    # ``ast.params.max_steps``. Presumably this lets one search cover the
    # whole episode; confirm against MCTSdpw.
    params.clear_nodes = False
    params.n *= ast.params.max_steps

    dpw_model = MCTSdpw.DPWModel(
        ast.transition_model, rollout_getAction(ast), explore_getAction(ast)
    )
    tree = MCTSdpw.DPWTree(params, dpw_model)
    s = tree.f.model.getInitialState()
    MCTSdpw.selectAction(tree, s, verbose=verbose)
    results = ast.top_paths
    if return_tree:
        return results, tree.s_tree
    return results