ast_toolbox.algos.mcts module

class ast_toolbox.algos.mcts.MCTS(env, max_path_length, ec, n_itr, k, alpha, clear_nodes, log_interval, top_paths, log_dir, gamma=1.0, stress_test_mode=2, log_tabular=True, plot_tree=False, plot_path=None, plot_format='png')

Bases: object

Monte Carlo Tree Search (MCTS) with double progressive widening (DPW) [1], using the env’s action space as its action space.

Parameters:
  • env (ast_toolbox.envs.go_explore_ast_env.GoExploreASTEnv) – The environment.
  • max_path_length (int) – The maximum search depth.
  • ec (float) – The exploration constant used in the UCT equation.
  • n_itr (int) – The number of iterations. The total number of environment calls is approximately n_itr*max_path_length*max_path_length.
  • k (float) – The coefficient k in the DPW constraint |N(s,a)|<=kN(s)^alpha.
  • alpha (float) – The exponent alpha in the DPW constraint |N(s,a)|<=kN(s)^alpha.
  • clear_nodes (bool) – Whether to clear redundant nodes in the tree. Set it to True to save memory. Set it to False for better tree plotting.
  • log_interval (int) – The log interval in terms of environment calls.
  • top_paths (ast_toolbox.mcts.BoundedPriorityQueues, optional) – The bounded priority queue to store top-rewarded trajectories.
  • gamma (float, optional) – The discount factor.
  • stress_test_mode (int, optional) – The mode of the tree search. 1 for single tree. 2 for multiple trees.
  • log_tabular (bool, optional) – Whether to log the training statistics into a tabular file.
  • plot_tree (bool, optional) – Whether to plot the resulting search tree.
  • plot_path (str, optional) – The path where the tree plot is saved.
  • plot_format (str, optional) – The file format of the saved tree plot.
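
The roles of k, alpha, ec, n_itr and max_path_length can be illustrated with a minimal, standalone sketch. It does not reproduce this class's internals. The function names are hypothetical. The widening test is one common reading of the DPW constraint, in which |N(s,a)| is taken to be the number of actions already expanded at a state. The UCT score uses the textbook form, which is an assumption about the equation that ec refers to.

```python
import math


def can_add_action(num_actions, state_visits, k, alpha):
    """Return True if progressive widening allows expanding a new action.

    Assumed reading of |N(s,a)| <= k * N(s)^alpha: num_actions is the number
    of actions already expanded at the state, state_visits is N(s).
    """
    return num_actions < k * state_visits ** alpha


def uct_score(q_value, state_visits, action_visits, ec):
    """Textbook UCT score: value estimate plus an exploration bonus scaled by ec."""
    if action_visits == 0:
        return math.inf  # untried actions are selected first
    return q_value + ec * math.sqrt(math.log(state_visits) / action_visits)


def approx_env_calls(n_itr, max_path_length):
    """Rough environment-call budget quoted in the n_itr description."""
    return n_itr * max_path_length * max_path_length


if __name__ == "__main__":
    k, alpha = 0.5, 0.5
    # With two actions already expanded, a third is allowed only once
    # k * N(s)^alpha exceeds 2.
    for visits in (1, 4, 16, 64):
        print(visits, k * visits ** alpha, can_add_action(2, visits, k, alpha))
    print(approx_env_calls(n_itr=100, max_path_length=50))
```

Larger k or alpha lets the tree branch sooner at each state, while a larger ec shifts selection toward rarely visited actions. Under the quoted approximation, n_itr=100 with max_path_length=50 corresponds to roughly 250,000 environment calls.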

References

[1] Lee, Ritchie, et al. “Adaptive stress testing of airborne collision avoidance systems.” 2015 IEEE/AIAA 34th Digital Avionics Systems Conference (DASC). IEEE, 2015.

init()

Initialize the AST internal parameters.

train(runner)

Start training.

Parameters: runner (garage.experiment.LocalRunner) – The LocalRunner, passed to give the algorithm access to runner.step_epochs(), which provides services such as snapshotting and sampler control.