Source code for ast_toolbox.algos.mcts

import numpy as np

import ast_toolbox.mcts.AdaptiveStressTesting as AST
import ast_toolbox.mcts.AST_MCTS as AST_MCTS
import ast_toolbox.mcts.ASTSim as ASTSim
import ast_toolbox.mcts.MCTSdpw as MCTSdpw
from ast_toolbox.mcts import tree_plot


[docs]class MCTS: """Monte Carlo Tress Search (MCTS) with double progressive widening (DPW) [1]_ using the env's action space as its action space. Parameters ---------- env : :py:class:`ast_toolbox.envs.go_explore_ast_env.GoExploreASTEnv`. The environment. max_path_length: int The maximum search depth. ec : float The exploration constant used in UCT equation. n_itr: int The iteration number, the total numeber of environment call is approximately n_itr*max_path_length*max_path_length. k : float The constraint parameter used in DPW: |N(s,a)|<=kN(s)^alpha. alpha : float The constraint parameter used in DPW: |N(s,a)|<=kN(s)^alpha. clear_nodes : bool Whether to clear redundant nodes in tree. Set it to True for saving memoray. Set it to False to better tree plotting. log_interval : int The log interval in terms of environment calls. top_paths : :py:class:`ast_toolbox.mcts.BoundedPriorityQueues`, optional The bounded priority queue to store top-rewarded trajectories. gamma : float, optional The discount factor. stress_test_mode : int, optional The mode of the tree search. 1 for single tree. 2 for multiple trees. log_tabular : bool, optional Whether to log the training statistics into a tabular file. plot_tree : bool, optional Whether to plot the resulting searching tree. plot_path : str, optional The storing path for the tree plot. plot_format : str, optional The storing format for the tree plot References ---------- .. [1] Lee, Ritchie, et al. "Adaptive stress testing of airborne collision avoidance systems." 2015 IEEE/AIAA 34th Digital Avionics Systems Conference (DASC). IEEE, 2015. """ def __init__( self, env, max_path_length, ec, n_itr, k, alpha, clear_nodes, log_interval, top_paths, log_dir, gamma=1.0, stress_test_mode=2, log_tabular=True, plot_tree=False, plot_path=None, plot_format='png' ): self.env = env self.stress_test_mode = stress_test_mode self.max_path_length = max_path_length self.macts_params = MCTSdpw.DPWParams(max_path_length, gamma, ec, 2 * max(n_itr * log_interval // max_path_length ** 2, 1), k, alpha, clear_nodes) self.log_interval = log_interval self.top_paths = top_paths self.log_tabular = log_tabular self.plot_tree = plot_tree self.plot_path = plot_path self.plot_format = plot_format self.policy = None self.log_dir = log_dir self.n_itr = n_itr
[docs] def init(self): """Initiate AST internal parameters """ ast_params = AST.ASTParams(self.max_path_length, self.log_interval, self.log_tabular, self.log_dir, self.n_itr) self.ast = AST.AdaptiveStressTest(p=ast_params, env=self.env, top_paths=self.top_paths)
[docs] def train(self, runner): """Start training. Parameters ---------- runner : :py:class:`garage.experiment.LocalRunner <garage:garage.experiment.LocalRunner>` ``LocalRunner`` is passed to give algorithm the access to ``runner.step_epochs()``, which provides services such as snapshotting and sampler control. """ self.init() if self.plot_tree: if self.stress_test_mode == 2: result, tree = AST_MCTS.stress_test2(self.ast, self.macts_params, self.top_paths, verbose=False, return_tree=True) else: result, tree = AST_MCTS.stress_test(self.ast, self.macts_params, self.top_paths, verbose=False, return_tree=True) else: if self.stress_test_mode == 2: result = AST_MCTS.stress_test2(self.ast, self.macts_params, self.top_paths, verbose=False, return_tree=False) else: result = AST_MCTS.stress_test(self.ast, self.macts_params, self.top_paths, verbose=False, return_tree=False) self.ast.params.log_tabular = False print("checking reward consistance") for (action_seq, reward_predict) in result: [a.get() for a in action_seq] reward, _ = ASTSim.play_sequence(self.ast, action_seq, sleeptime=0.0) assert np.isclose(reward_predict, reward) print("done") if self.plot_tree: tree_plot.plot_tree(tree, d=self.max_path_length, path=self.plot_path, format=self.plot_format)