Source code for ast_toolbox.mcts.AdaptiveStressTesting

import pickle

import numpy as np
# import garage.misc.logger as logger
from dowel import logger
from dowel import tabular

import ast_toolbox.mcts.MDP as MDP


class ASTParams:
    """Plain container for the settings used by :class:`AdaptiveStressTest`.

    Parameters
    ----------
    max_steps : int
        The maximum search depth; forwarded to ``MDP.TransitionModel`` when
        the transition model is built.
    log_interval : int
        Tabular logging fires whenever the step count is a multiple of this
        value.
    log_tabular : bool
        Master switch for tabular logging.
    log_dir : str, optional
        Directory in which ``best_actions.p`` is written while logging.
        ``None`` disables that file.
    n_itr : int, optional
        Logging stops once the internal logging counter exceeds this value.
    """

    def __init__(self, max_steps, log_interval, log_tabular, log_dir=None, n_itr=100):
        # Required settings.
        self.max_steps, self.log_interval, self.log_tabular = (
            max_steps,
            log_interval,
            log_tabular,
        )
        # Settings that fall back to defaults.
        self.log_dir, self.n_itr = log_dir, n_itr
class AdaptiveStressTest:
    """AST wrapper that drives an environment for MCTS, sampling actions
    from ``env.action_space``.

    Parameters
    ----------
    p : :py:class:`ast_toolbox.mcts.AdaptiveStressTesting.ASTParams`
        The AST parameters.
    env : :py:class:`ast_toolbox.envs.go_explore_ast_env.GoExploreASTEnv`
        The environment.
    top_paths : :py:class:`ast_toolbox.mcts.BoundedPriorityQueues`
        The bounded priority queue that stores top-rewarded trajectories.
    """

    def __init__(self, p, env, top_paths):
        self.params = p
        self.env = env
        self.sim_hash = hash(0)
        # From here on the instance attribute shadows the method of the same
        # name: ``self.transition_model`` is the built model object.
        self.transition_model = self.transition_model()
        self.step_count = 0
        self._isterminal = False
        self._reward = 0.0
        self.action_seq = []
        self.trajectory_reward = 0.0
        self.top_paths = top_paths
        self.iter = 0

    def reset_step_count(self):
        """Set the env step counter back to zero."""
        self.step_count = 0

    def initialize(self):
        """Clear the per-trajectory bookkeeping and reset the env.

        Returns
        ----------
        env_reset :
            Whatever ``env.reset()`` returns.
        """
        self._isterminal = False
        self._reward = 0.0
        self.action_seq = []
        self.trajectory_reward = 0.0
        reset_result = self.env.reset()
        return reset_result

    def update(self, action):
        """Step the environment once and refresh the associated bookkeeping.

        Parameters
        ----------
        action : :py:class:`ast_toolbox.mcts.AdaptiveStressTesting.ASTAction`
            The AST action.

        Returns
        ----------
        obs : :py:class:`numpy.ndarray`
            The observation from the env step.
        reward : float
            The reward from the env step.
        done : bool
            The terminal indicator from the env step.
        info : dict
            The env info from the env step.
        """
        self.step_count += 1
        obs, reward, done, info = self.env.step(action.get())
        self._isterminal, self._reward = done, reward
        self.action_seq.append(action)
        self.trajectory_reward += reward
        if done:
            # A finished trajectory competes for a slot among the top paths.
            self.top_paths.enqueue(self.action_seq, self.trajectory_reward, make_copy=True)
        self.logging()
        return obs, reward, done, info

    def logging(self):
        """Emit tabular training information every ``log_interval`` steps."""
        params = self.params
        if not (params.log_tabular and self.iter <= params.n_itr):
            return
        if self.step_count % params.log_interval != 0:
            return

        self.iter += 1
        logger.log(' ')
        tabular.record('StepNum', self.step_count)

        if params.log_dir is not None:
            if self.step_count == params.log_interval:
                # First time logging: start a fresh history.
                history = []
            else:
                with open(params.log_dir + '/best_actions.p', 'rb') as f:
                    history = pickle.load(f)
            history.append(np.array([x.get() for x in self.top_paths.pq[0][0]]))
            with open(params.log_dir + '/best_actions.p', 'wb') as f:
                pickle.dump(history, f)

        n_recorded = 0
        for rank, path in enumerate(self.top_paths):
            tabular.record('reward ' + str(rank), path[1])
            n_recorded += 1
        # Pad the remaining slots so every row has ``top_paths.N`` columns.
        for rank in range(n_recorded, self.top_paths.N):
            tabular.record('reward ' + str(rank), 0)

        logger.log(tabular)
        logger.dump_all(self.step_count)
        tabular.clear()

    def isterminal(self):
        """Check whether the current path is finished.

        Returns
        ----------
        isterminal : bool
            The ``done`` flag stored by the last :meth:`update`.
        """
        return self._isterminal

    def get_reward(self):
        """Get the current AST reward.

        Returns
        ----------
        reward : float
            The reward stored by the last :meth:`update`.
        """
        return self._reward

    def random_action(self):
        """Randomly sample an action for the rollout.

        Returns
        ----------
        action : :py:class:`ast_toolbox.mcts.AdaptiveStressTesting.ASTAction`
            The sampled action.
        """
        sampled = self.env.action_space.sample()
        return ASTAction(sampled)

    def explore_action(self, s, tree):
        """Randomly sample an action for the exploration.

        Parameters
        ----------
        s : :py:class:`ast_toolbox.mcts.AdaptiveStressTesting.ASTState`
            The current state (not consulted by this implementation).
        tree : dict
            The searching tree (not consulted by this implementation).

        Returns
        ----------
        action : :py:class:`ast_toolbox.mcts.AdaptiveStressTesting.ASTAction`
            The sampled action.
        """
        sampled = self.env.action_space.sample()
        return ASTAction(sampled)

    def transition_model(self):
        """Generate the transition model used in MCTS.

        Returns
        ----------
        transition_model : :py:class:`ast_toolbox.mcts.MDP.TransitionModel`
            The transition model.
        """

        def get_initial_state():
            # Restart the simulation and hand back the root state.
            self.t_index = 1
            self.initialize()
            root = ASTState(self.t_index, None, None)
            self.sim_hash = root.hash
            return root

        def get_next_state(s0, a0):
            # The simulator must currently sit at ``s0``.
            assert self.sim_hash == s0.hash
            self.t_index += 1
            self.update(a0)
            successor = ASTState(self.t_index, s0, a0)
            self.sim_hash = successor.hash
            return successor, self.get_reward()

        def isterminal(s):
            assert self.sim_hash == s.hash
            return self.isterminal()

        def go_to_state(target_state):
            # Replay from the root the actions that lead to ``target_state``.
            current = get_initial_state()
            actions = get_action_sequence(target_state)
            total = 0.0
            for act in actions:
                current, step_reward = get_next_state(current, act)
                total += step_reward
            assert current == target_state
            return total, actions

        return MDP.TransitionModel(
            get_initial_state,
            get_next_state,
            isterminal,
            self.params.max_steps,
            go_to_state,
        )
class ASTState:
    """The AST state.

    A state is identified by its timestep, the hash of its parent state and
    the action that led to it.

    Parameters
    ----------
    t_index : int
        The index of the timestep.
    parent : :py:class:`ast_toolbox.mcts.AdaptiveStressTesting.ASTState`
        The parent state, or ``None`` for a root state.
    action : :py:class:`ast_toolbox.mcts.AdaptiveStressTesting.ASTAction`
        The action leading to this state, or ``None`` for a root state.
    """

    def __init__(self, t_index, parent, action):
        self.t_index = t_index
        self.parent = parent
        self.action = action
        # Cached so that children can chain on ``parent.hash`` without
        # walking the whole ancestry.
        self.hash = hash(self)

    def __hash__(self):
        """The redefined hashing method.

        Returns
        ----------
        hash : int
            Hash of ``(t_index, parent hash or None, hash(action))``.
        """
        parent_hash = None if self.parent is None else self.parent.hash
        return hash((self.t_index, parent_hash, hash(self.action)))

    def __eq__(self, other):
        """The redefined equal method.

        Two states are equal when their hashes match and their timestep,
        parent hash and action agree. Checking the fields in addition to the
        hash keeps a hash collision from making distinct states compare
        equal.

        Returns
        ----------
        is_equal : bool
            Whether the two states are equal. ``NotImplemented`` is returned
            for operands that are not states, so Python falls back to its
            default comparison.
        """
        if not isinstance(other, ASTState):
            return NotImplemented
        if self is other:
            return True
        # Cheap rejection first; equal states always share a hash.
        if hash(self) != hash(other):
            return False
        if self.t_index != other.t_index:
            return False
        # Parents are compared through their cached hashes to stay O(1).
        own_parent = None if self.parent is None else self.parent.hash
        other_parent = None if other.parent is None else other.parent.hash
        if own_parent != other_parent:
            return False
        # Root states carry ``None`` as action; avoid calling an action's
        # ``__eq__`` with ``None``.
        if self.action is None or other.action is None:
            return self.action is other.action
        return bool(self.action == other.action)
class ASTAction:
    """The AST action.

    Parameters
    ----------
    action :
        The true action used in the env.
    """

    def __init__(self, action):
        self.action = action

    def __hash__(self):
        """The redefined hashing method.

        Returns
        ----------
        hash : int
            The hashing result.
        """
        try:
            # Flat sequences of hashable items (e.g. 1-D arrays) hash exactly
            # as before.
            return hash(tuple(self.action))
        except TypeError:
            # Scalars are not iterable and nested sequences / N-D arrays
            # yield unhashable items; hash the flattened values instead.
            flat = np.asarray(self.action).ravel().tolist()
            return hash(tuple(flat))

    def __eq__(self, other):
        """The redefined equal method.

        Returns
        ----------
        is_equal : bool
            Whether the two actions wrap equal values. ``NotImplemented`` is
            returned for operands that are not actions, so Python falls back
            to its default comparison.
        """
        if not isinstance(other, ASTAction):
            return NotImplemented
        return np.array_equal(self.action, other.action)

    def get(self):
        """Get the true action.

        Returns
        ----------
        action :
            The true action used in the env.
        """
        return self.action
def get_action_sequence(s):
    """Collect the actions on the path from the root down to a state.

    Parameters
    ----------
    s : :py:class:`ast_toolbox.mcts.AdaptiveStressTesting.ASTState`
        The target state.

    Returns
    ----------
    actions : list[:py:class:`ast_toolbox.mcts.AdaptiveStressTesting.ASTAction`]
        The actions leading to the target state, earliest first. Empty when
        ``s`` has no parent.
    """
    trail = []
    node = s
    # Walk up to the root (the node without a parent), newest action first.
    while node.parent is not None:
        trail.append(node.action)
        node = node.parent
    # Flip into chronological order.
    trail.reverse()
    return trail