Source code for ast_toolbox.mcts.AdaptiveStressTestingRandomSeed

import copy

import gym

import ast_toolbox.mcts.RNGWrapper as RNG
from ast_toolbox.mcts.AdaptiveStressTesting import AdaptiveStressTest


class AdaptiveStressTestRS(AdaptiveStressTest):
    """AST wrapper for MCTS in which each action is a random seed.

    Every sampled action is an independent snapshot of the internal random
    seed generator (``self.rsg``), wrapped in an
    :py:class:`ast_toolbox.mcts.AdaptiveStressTestingRandomSeed.ASTRSAction`.

    Parameters
    ----------
    kwargs :
        Keyword arguments passed to
        `ast_toolbox.mcts.AdaptiveStressTesting.AdaptiveStressTest`.
    """

    def __init__(self, **kwargs):
        super(AdaptiveStressTestRS, self).__init__(**kwargs)
        # The generator is configured from the parameters stored on
        # ``self.params`` (presumably populated by the parent constructor).
        seed_generator = RNG.RSG(self.params.rsg_length, self.params.init_seed)
        self.rsg = seed_generator
        # Keep a pristine copy so that reset_rsg() can restore the start state.
        self.initial_rsg = copy.deepcopy(seed_generator)

    def reset_rsg(self):
        """Restore the random seed generator to its initial state.

        The generator is replaced by a fresh deep copy of the snapshot taken
        at construction time, so the snapshot itself is never mutated.
        """
        self.rsg = copy.deepcopy(self.initial_rsg)

    def _sample_action(self):
        """Advance the seed generator and wrap a copy of it as an action.

        Returns
        ----------
        action : :py:class:`ast_toolbox.mcts.AdaptiveStressTestingRandomSeed.ASTRSAction`
            The sampled action.
        """
        self.rsg.next()
        # Deep-copy so that later calls to next() do not alter actions that
        # have already been handed out.
        seed_snapshot = copy.deepcopy(self.rsg)
        return ASTRSAction(action=seed_snapshot, env=self.env)

    def random_action(self):
        """Randomly sample an action for the rollout.

        Returns
        ----------
        action : :py:class:`ast_toolbox.mcts.AdaptiveStressTestingRandomSeed.ASTRSAction`
            The sampled action.
        """
        return self._sample_action()

    def explore_action(self, s, tree):
        """Randomly sample an action for the exploration.

        Parameters
        ----------
        s :
            Not used by this implementation.
        tree :
            Not used by this implementation.

        Returns
        ----------
        action : :py:class:`ast_toolbox.mcts.AdaptiveStressTestingRandomSeed.ASTRSAction`
            The sampled action.
        """
        return self._sample_action()
class ASTRSAction:
    """The AST action containing the random seed.

    Instances hash and compare by their ``action`` attribute only; ``env`` is
    not part of the identity.

    Parameters
    ----------
    action :
        The random seed.
    env : :py:class:`ast_toolbox.envs.go_explore_ast_env.GoExploreASTEnv`
        The environment.
    """

    def __init__(self, action, env):
        self.env = env
        self.action = action

    def __hash__(self):
        """The redefined hashing method.

        Returns
        ----------
        hash : int
            The hash of the wrapped random seed.
        """
        return hash(self.action)

    def __eq__(self, other):
        """The redefined equal method.

        Parameters
        ----------
        other :
            The object to compare against.

        Returns
        ----------
        is_equal : bool
            Whether the two actions wrap equal random seeds. ``NotImplemented``
            is returned for objects that are not `ASTRSAction` instances, so
            such comparisons evaluate to ``False`` instead of raising
            ``AttributeError``.
        """
        if not isinstance(other, ASTRSAction):
            return NotImplemented
        return self.action == other.action

    def get(self):
        """Get the true action.

        Returns
        ----------
        action :
            The true actions used in the env: a sample drawn from the
            environment's action space after seeding it, or the integer seed
            itself when the action space is not a `gym.spaces.Space`.
        """
        rng_state = self.action.state
        # TODO: a better approach to make use of random seed of length > 1
        action_seed = int(rng_state[0])
        # Read env.action_space exactly once: per the original author's note,
        # every access may create a new space, so the object that is seeded
        # must be the same object that is sampled.
        action_space = self.env.action_space
        if isinstance(action_space, gym.spaces.Space):
            action_space.seed(action_seed)
            return action_space.sample()
        return action_seed