ast_toolbox.algos.mcts module
class ast_toolbox.algos.mcts.MCTS(env, max_path_length, ec, n_itr, k, alpha, clear_nodes, log_interval, top_paths, log_dir, gamma=1.0, stress_test_mode=2, log_tabular=True, plot_tree=False, plot_path=None, plot_format='png')
Bases: object
Monte Carlo Tree Search (MCTS) with double progressive widening (DPW) [1], using the env’s action space as its action space.
Parameters:
- env (ast_toolbox.envs.go_explore_ast_env.GoExploreASTEnv) – The environment.
- max_path_length (int) – The maximum search depth.
- ec (float) – The exploration constant used in the UCT equation.
- n_itr (int) – The number of iterations. The total number of environment calls is approximately n_itr*max_path_length*max_path_length.
- k (float) – The coefficient k in the DPW constraint |N(s,a)| <= k*N(s)^alpha.
- alpha (float) – The exponent alpha in the DPW constraint |N(s,a)| <= k*N(s)^alpha.
- clear_nodes (bool) – Whether to clear redundant nodes in the tree. Set it to True to save memory. Set it to False for better tree plotting.
- log_interval (int) – The log interval in terms of environment calls.
- top_paths (ast_toolbox.mcts.BoundedPriorityQueues, optional) – The bounded priority queue used to store the top-rewarded trajectories.
- gamma (float, optional) – The discount factor.
- stress_test_mode (int, optional) – The mode of the tree search. 1 for single tree. 2 for multiple trees.
- log_tabular (bool, optional) – Whether to log the training statistics into a tabular file.
- plot_tree (bool, optional) – Whether to plot the resulting search tree.
- plot_path (str, optional) – The path where the tree plot is stored.
- plot_format (str, optional) – The file format of the tree plot.
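Example
The sketch below illustrates how ec, k, alpha, n_itr and max_path_length enter the search, under stated assumptions. It is not the toolbox's implementation: the function names (satisfies_dpw, may_widen, uct_score, approx_env_calls) are hypothetical, the UCT score uses the standard UCB1 form, and the widening test simply checks the constraint |N(s,a)| <= k*N(s)^alpha as written above. The exact expressions used inside MCTS may differ.

```python
import math


def satisfies_dpw(num_actions, num_visits, k, alpha):
    """Check the DPW constraint |N(s,a)| <= k * N(s)**alpha for one node.

    num_actions: number of distinct actions already expanded at the node.
    num_visits:  number of times the node has been visited, N(s).
    """
    return num_actions <= k * num_visits ** alpha


def may_widen(num_actions, num_visits, k, alpha):
    """Assumed widening rule: add a new action only if the constraint
    would still hold after adding it."""
    return satisfies_dpw(num_actions + 1, num_visits, k, alpha)


def uct_score(q_value, parent_visits, child_visits, ec):
    """Standard UCB1-style score: value estimate plus an exploration
    bonus scaled by the exploration constant ec (assumed form)."""
    if child_visits == 0:
        # An unvisited action is always preferred over visited ones.
        return math.inf
    return q_value + ec * math.sqrt(math.log(parent_visits) / child_visits)


def approx_env_calls(n_itr, max_path_length):
    """Approximate environment-call budget quoted in the n_itr entry."""
    return n_itr * max_path_length * max_path_length


if __name__ == "__main__":
    # With k=1 and alpha=0.5, a node visited 4 times may hold at most 2 actions.
    print(may_widen(1, 4, k=1.0, alpha=0.5))
    print(may_widen(2, 4, k=1.0, alpha=0.5))
    print(approx_env_calls(n_itr=10, max_path_length=5))
```

A larger ec weights the exploration bonus more heavily against the value estimate, and larger k or alpha lets each node branch into more actions for the same visit count.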
References
[1] Lee, Ritchie, et al. “Adaptive stress testing of airborne collision avoidance systems.” 2015 IEEE/AIAA 34th Digital Avionics Systems Conference (DASC). IEEE, 2015.
train(runner)
Start training.
Parameters: runner (garage.experiment.LocalRunner) – LocalRunner is passed to give the algorithm access to runner.step_epochs(), which provides services such as snapshotting and sampler control.
- env (