ast_toolbox.algos.go_explore module
Implementation of the Go-Explore algorithm.
-
class ast_toolbox.algos.go_explore.Cell(use_score_weight=True)
Bases: object
A representation of a state visited during exploration.
Parameters: use_score_weight (bool) – Whether or not to scale the cell’s fitness by a function of the cell’s score.
-
reset_cached_property(cached_property)
Removes a cached property so that it will be recalculated on the next access.
Parameters: cached_property (str) – The name of the cached property to remove from the instance’s __dict__.
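A minimal sketch of the caching pattern this method resets, assuming Python 3.8+ functools.cached_property, which stores the computed value in the instance's __dict__ under the property's name. ToyCell and its formula are hypothetical; the real Cell may use a different cached-property implementation.

```python
from functools import cached_property


class ToyCell:
    """Hypothetical stand-in for Cell that only shows the caching behaviour."""

    def __init__(self, times_chosen):
        self.times_chosen = times_chosen

    @cached_property
    def times_chosen_subscore(self):
        # Computed on first access, then stored in self.__dict__ under this name.
        return 1.0 / (self.times_chosen + 1) ** 0.5

    def reset_cached_property(self, name):
        # Removing the stored value forces recomputation on the next access.
        self.__dict__.pop(name, None)


cell = ToyCell(times_chosen=3)
print(cell.times_chosen_subscore)  # 0.5
cell.times_chosen = 8
print(cell.times_chosen_subscore)  # still 0.5: the cached value is reused
cell.reset_cached_property("times_chosen_subscore")
print(cell.times_chosen_subscore)  # recomputed from times_chosen=8
```

Without the reset, a counter update would leave every dependent property stale, which is why the counters and their cached subscores must be invalidated together.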
-
count_subscores
A function of times_chosen_subscore, times_chosen_since_improved_subscore, and times_visited_subscore that is used in calculating the cell’s fitness score.
Returns: float – The count subscore of the cell.
-
fitness
The fitness score of the cell. Cells are sampled with probability proportional to their fitness score.
Returns: float – The fitness score of the cell.
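Sampling with probability proportional to fitness can be sketched with the standard library's random.choices, which accepts unnormalised weights. The function sample_cell and its list-of-keys input are hypothetical; this is not the CellPool's actual selection code.

```python
import random


def sample_cell(cells, fitness, rng=random):
    """Pick one cell key with probability proportional to its fitness.

    cells: sequence of cell keys (hypothetical); fitness: matching
    non-negative weights, e.g. each cell's fitness property.
    """
    if len(cells) != len(fitness):
        raise ValueError("cells and fitness must have the same length")
    if sum(fitness) <= 0:
        raise ValueError("at least one cell needs positive fitness")
    # random.choices normalises the weights internally, so they need not sum to 1.
    return rng.choices(cells, weights=fitness, k=1)[0]


rng = random.Random(0)
picks = [sample_cell(["a", "b", "c"], [1.0, 1.0, 2.0], rng) for _ in range(10_000)]
print(picks.count("c") / len(picks))  # close to 0.5
```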
-
is_goal
Whether or not the current cell is a goal state.
Returns: bool – True if the current cell is a goal state.
-
is_root
Checks whether the cell is the root of the tree (trajectory length is 0).
Returns: bool – Whether the cell is the root.
-
is_terminal
Whether or not the current cell is a terminal state.
Returns: bool – True if the current cell is a terminal state.
-
reward
The reward obtained in the current cell.
Returns: float – The reward.
-
score
The score obtained in the current cell.
Returns: float – The score.
-
score_weight
A heuristic weight, computed from the cell’s score and other values, that biases rollouts toward high-scoring areas.
Returns: float – The cell’s score_weight.
-
step
How many steps led to the current cell.
Returns: int – Length of the trajectory.
-
times_chosen
How many times the current cell has been chosen to start a rollout.
Returns: int – Number of times chosen.
-
times_chosen_since_improved
How many times the current cell has been chosen to start a rollout since the last time the cell was updated with an improved score or trajectory.
Returns: int – Number of times chosen since last improved.
-
times_chosen_since_improved_subscore
A function of times_chosen_since_improved that is used in calculating the cell’s count subscore (count_subscores).
Returns: float – The times_chosen_since_improved_subscore.
-
times_chosen_subscore
A function of times_chosen that is used in calculating the cell’s count subscore (count_subscores).
Returns: float – The times_chosen_subscore.
-
times_visited
How many times the current cell has been visited during all rollouts.
Returns: int – Number of times visited.
-
times_visited_subscore
A function of times_visited that is used in calculating the cell’s count subscore (count_subscores).
Returns: float – The times_visited_subscore.
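The docstrings above describe each subscore only as "a function of" its counter. The sketch below assumes a count-based form in the spirit of the Go-Explore paper [1], where each term is a weight divided by the square root of the count plus one, and assumes fitness is score_weight times count_subscores. The weight constants, CellCounts, and count_term are all hypothetical.

```python
import math
from dataclasses import dataclass

# Hypothetical weights; the real Cell may use different constants and formulas.
W_CHOSEN = 0.1
W_CHOSEN_SINCE_IMPROVED = 0.1
W_VISITED = 0.3


def count_term(weight, count):
    # Decays as the count grows, so rarely chosen or visited cells score higher.
    return weight / math.sqrt(count + 1)


@dataclass
class CellCounts:
    times_chosen: int = 0
    times_chosen_since_improved: int = 0
    times_visited: int = 0
    score_weight: float = 1.0
    use_score_weight: bool = True

    @property
    def count_subscores(self):
        return (count_term(W_CHOSEN, self.times_chosen)
                + count_term(W_CHOSEN_SINCE_IMPROVED, self.times_chosen_since_improved)
                + count_term(W_VISITED, self.times_visited))

    @property
    def fitness(self):
        # Assumed combination: optionally scale the count subscore by score_weight.
        weight = self.score_weight if self.use_score_weight else 1.0
        return weight * self.count_subscores


fresh = CellCounts()
stale = CellCounts(times_chosen=99, times_visited=99)
print(fresh.fitness > stale.fitness)  # True
```

The decaying terms are what push exploration toward the frontier: a cell that has been chosen or visited often contributes less and less to the sampling weights.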
-
value_approx
The approximate value of the current cell, based on back-propagation of previous rollouts.
Returns: float – The value approximation.
-
-
class ast_toolbox.algos.go_explore.CellPool(filename='database', discount=0.99, use_score_weight=True)
Bases: object
A hash-tree data structure that stores and updates all of the cells seen during rollouts.
Parameters: - filename (str, optional) – The base name for the database files. The CellPool saves a [filename]_pool.dat and a [filename]_meta.dat.
- discount (float, optional) – Discount factor used in calculating a cell’s value approximation.
- use_score_weight (bool) – Whether or not to scale a cell’s fitness by a function of the cell’s score.
-
close_pool(cell_pool_shelf)
Close the database that the CellPool uses to store cells.
Parameters: cell_pool_shelf (shelve.Shelf) – A shelve.Shelf wrapping a bsddb3 database.
-
d_update(cell_pool_shelf, observation, action, trajectory, score, state, parent=None, is_terminal=False, is_goal=False, reward=-inf, chosen=0)
Runs the update algorithm for the CellPool. The process is:
1. Create a cell from the given data.
2. Check whether the cell already exists in the CellPool.
3. If the cell already exists and our version is better (higher fitness or shorter trajectory), update the existing cell.
4. If the cell already exists and our version is not better, stop.
5. If the cell does not already exist, add the new cell to the CellPool.
Parameters: - cell_pool_shelf (shelve.Shelf) – A shelve.Shelf wrapping a bsddb3 database.
- observation (array_like) – The observation seen in the current cell.
- action (array_like) – The action taken in the current cell.
- trajectory (array_like) – The trajectory leading to the current cell.
- score (float) – The score at the current cell.
- state (array_like) – The cloned simulation state at the current cell, used for resetting if chosen to start a rollout.
- parent (int, optional) – The hash key of the cell immediately preceding the current cell in the trajectory.
- is_terminal (bool, optional) – Whether the current cell is a terminal state.
- is_goal (bool, optional) – Whether the current cell is a goal state.
- reward (float, optional) – The reward obtained at the current cell.
- chosen (int, optional) – Whether the current cell was chosen to start the rollout.
Returns: bool – True if a new cell was added to the CellPool, False otherwise.
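The five update steps can be sketched against a plain dict standing in for the shelf. SimpleCell, is_better, and d_update_sketch are hypothetical; keying by hash of the observation tuple is an assumption, and so is the comparison rule (higher score wins, ties go to the shorter trajectory), since the docstring only says "higher fitness or shorter trajectory".

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional, Sequence


@dataclass
class SimpleCell:
    # Hypothetical minimal record; the real Cell tracks many more fields.
    observation: tuple
    trajectory: list
    score: float
    state: Any = None
    parent: Optional[int] = None
    times_visited: int = 0


def is_better(new: SimpleCell, old: SimpleCell) -> bool:
    # Assumed rule: higher score wins; on a tie, the shorter trajectory wins.
    if new.score != old.score:
        return new.score > old.score
    return len(new.trajectory) < len(old.trajectory)


def d_update_sketch(pool: Dict[int, SimpleCell], observation: Sequence, trajectory: Sequence,
                    score: float, state: Any = None, parent: Optional[int] = None) -> bool:
    """Return True if a new cell was added to ``pool``, False otherwise."""
    # 1. Create a cell from the given data.
    new = SimpleCell(tuple(observation), list(trajectory), score, state, parent)
    key = hash(new.observation)
    # 2. Check whether the cell already exists.
    old = pool.get(key)
    if old is None:
        # 5. New cell: add it.
        new.times_visited = 1
        pool[key] = new
        return True
    old.times_visited += 1
    if is_better(new, old):
        # 3. Existing cell, but our version is better: update it.
        old.trajectory, old.score = new.trajectory, new.score
        old.state, old.parent = new.state, new.parent
    # Write back explicitly: a shelve.Shelf opened without writeback=True
    # would not persist in-place mutation of a fetched object.
    pool[key] = old
    # 4. (or after 3) Nothing new was added.
    return False
```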
-
load(cell_pool_shelf)
Load a CellPool from disk.
Parameters: cell_pool_shelf (shelve.Shelf) – A shelve.Shelf wrapping a bsddb3 database.
-
open_pool(dbname=None, dbtype=<bsddb3 constant>, flags=<bsddb3 constant>, protocol=4, overwrite=False)
Open the database that the CellPool uses to store cells.
Parameters: - dbname (str, optional)
- dbtype (int, optional) – Specifies the type of database to open. Use enumerations provided by bsddb3.
- flags (int, optional) – Specifies the configuration of the database to open. Use enumerations provided by bsddb3.
- protocol (int, optional) – Specifies the data stream format used by pickle.
- overwrite (bool, optional) – Indicates if an existing database should be overwritten if found.
Returns: cell_pool_shelf (shelve.Shelf) – A shelve.Shelf wrapping a bsddb3 database.
-
sync_and_close_pool(cell_pool_shelf)
Sync and then close the database that the CellPool uses to store cells.
Parameters: cell_pool_shelf (shelve.Shelf) – A shelve.Shelf wrapping a bsddb3 database.
-
sync_pool(cell_pool_shelf)
Syncs the pool, ensuring that the database on disk is up-to-date.
Parameters: cell_pool_shelf (shelve.Shelf) – A shelve.Shelf wrapping a bsddb3 database.
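The open, sync, and close lifecycle of the pool can be sketched with the standard-library shelve module. This swaps bsddb3 for shelve's default dbm backend, so dbtype and flags have no equivalent here; the helper roundtrip and the file name are hypothetical. Note that shelve requires string keys.

```python
import os
import shelve
import tempfile


def roundtrip(record, key="12345"):
    """Write ``record`` to a fresh shelf, sync and close it, then read it back."""
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "database_pool")
        # flag="n" always creates a new, empty database (cf. overwrite=True).
        pool = shelve.open(path, flag="n", protocol=4)
        pool[key] = record   # values are pickled; keys must be str
        pool.sync()          # flush pending writes to disk (cf. sync_pool)
        pool.close()         # cf. close_pool
        # Reopen read-only (cf. load).
        with shelve.open(path, flag="r") as reopened:
            return reopened[key]


print(roundtrip({"score": 1.5, "trajectory": [0, 1]}))
```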
-
value_approx_update(value, obs_hash, cell_pool_shelf)
Recursively calculate a value approximation through back-propagation.
Parameters: - value – Value approximation of the previous cell.
- obs_hash – Hash key of the current cell.
- cell_pool_shelf (shelve.Shelf) – A shelve.Shelf wrapping a bsddb3 database.
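A minimal sketch of recursive value back-propagation over parent links, using a dict in place of the shelf. The update rule (reward plus discounted child value, kept only when it improves the stored estimate) is an assumption, as are ValueCell and its fields; discount defaults to the CellPool's documented 0.99.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ValueCell:
    reward: float
    parent: Optional[int] = None          # hash key of the preceding cell; None at the root
    value_approx: float = float("-inf")


def value_approx_update(value, obs_hash, pool, discount=0.99):
    """Back up ``value`` into the cell at ``obs_hash`` and then into its ancestors.

    Assumed rule: V = reward + discount * value; recursion stops at the root
    or as soon as a cell's stored estimate does not improve.
    """
    if obs_hash is None:
        return
    cell = pool[obs_hash]
    candidate = cell.reward + discount * value
    if candidate > cell.value_approx:
        cell.value_approx = candidate
        pool[obs_hash] = cell  # write back (needed for a non-writeback shelf)
        value_approx_update(candidate, cell.parent, pool, discount)


# Chain root(0) -> 1 -> 2, with the only reward at the leaf.
pool = {0: ValueCell(0.0), 1: ValueCell(0.0, parent=0), 2: ValueCell(10.0, parent=1)}
value_approx_update(0.0, 2, pool, discount=0.5)
print([pool[k].value_approx for k in (0, 1, 2)])  # [2.5, 5.0, 10.0]
```

The recursive form mirrors the docstring; on very long trajectories an explicit loop would avoid Python's default recursion limit.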
-
meta_filename
The CellPool metadata filename.
Returns: str – The CellPool metadata filename.
-
pool_filename
The CellPool database filename.
Returns: str – The CellPool database filename.
-
class ast_toolbox.algos.go_explore.GoExplore(db_filename, max_db_size, env, env_spec, policy, baseline, save_paths_gap=0, save_paths_path=None, overwrite_db=True, use_score_weight=True, **kwargs)
Bases: garage.tf.algos.batch_polopt.BatchPolopt
Implementation of the Go-Explore [1] algorithm that is compatible with AST [2].
Parameters: - db_filename (str) – The base path and name for the database files. The CellPool saves a [filename]_pool.dat and a [filename]_meta.dat.
- max_db_size (int) – Maximum allowable size (in GB) of the CellPool database. The algorithm will stop and exit immediately if this size is exceeded.
- env (ast_toolbox.envs.go_explore_ast_env.GoExploreASTEnv) – The environment.
- env_spec (garage.envs.EnvSpec) – Environment specification.
- policy (garage.tf.policies.Policy) – The policy.
- baseline (garage.np.baselines.Baseline) – The baseline.
- save_paths_gap (int, optional) – How many epochs to skip between saving out full paths. Set to 1 to save every epoch. Set to 0 to disable saving.
- save_paths_path (str, optional) – Path to the directory where paths should be saved. Set to None to disable saving.
- overwrite_db (bool, optional) – Indicates if an existing database should be overwritten if found.
- use_score_weight (bool) – Whether or not to scale the cell’s fitness by a function of the cell’s score.
- kwargs – Keyword arguments passed to garage.tf.algos.BatchPolopt.
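The max_db_size limit can be sketched as a size check over the two database files. The helpers db_size_gb and exceeds_max_db_size are hypothetical, and so is the unit: this sketch counts 1 GB as 1024**3 bytes, which the documentation does not specify.

```python
import os
import tempfile


def db_size_gb(*paths):
    """Total on-disk size of the given files in GB (missing files count as 0)."""
    total = sum(os.path.getsize(p) for p in paths if os.path.exists(p))
    return total / 1024 ** 3  # assumption: 1 GB = 1024**3 bytes


def exceeds_max_db_size(max_db_size, db_filename):
    # File names follow the [filename]_pool.dat / [filename]_meta.dat pattern above.
    size = db_size_gb(db_filename + "_pool.dat", db_filename + "_meta.dat")
    return size > max_db_size


# Usage with a throwaway file:
with tempfile.TemporaryDirectory() as tmp:
    base = os.path.join(tmp, "database")
    with open(base + "_pool.dat", "wb") as f:
        f.write(b"\0" * 1024)
    print(exceeds_max_db_size(1, base))  # False: 1 KiB is far below 1 GB
```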
References
[1] Ecoffet, Adrien, et al. “Go-explore: a new approach for hard-exploration problems.” arXiv preprint arXiv:1901.10995 (2019). https://arxiv.org/abs/1901.10995
[2] Koren, Mark, and Mykel J. Kochenderfer. “Adaptive Stress Testing without Domain Heuristics using Go-Explore.” arXiv preprint arXiv:2004.04292 (2020). https://arxiv.org/abs/2004.04292
-
downsample(obs, step=None)
Create a downsampled approximation of the observed simulation state.
Parameters: - obs (array_like) – The observed simulation state.
- step (int, optional) – The current iteration number.
Returns: array_like – The downsampled approximation of the observed simulation state.
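Downsampling is what lets many nearby observations share one cell. The Go-Explore paper [1] downsamples Atari frames to coarse grayscale images; for a vector observation, a simple stand-in is uniform binning, sketched below. The resolution parameter and the binning rule are assumptions, and the actual downsample may do something different.

```python
import math
from typing import Optional, Sequence, Tuple


def downsample(obs: Sequence[float], step: Optional[int] = None,
               resolution: float = 1.0) -> Tuple[int, ...]:
    """Hypothetical cell-key downsampling: map each value to a grid bin index."""
    # ``step`` mirrors the documented signature but is unused in this sketch.
    return tuple(math.floor(x / resolution) for x in obs)


print(downsample([0.2, 1.7, -0.4]))  # (0, 1, -1)
print(downsample([0.3, 1.9, -0.1]))  # (0, 1, -1): same cell as above
```

Returning a tuple keeps the result hashable, so it can be hashed directly into a cell key.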
-
get_itr_snapshot(itr)
Returns all the data that should be saved in the snapshot for this iteration.
Parameters: itr (int) – The current epoch number.
Returns: dict – A dict containing the current iteration number, the current policy, and the current baseline.
-
init_opt()
Initialize the optimization procedure. If using TensorFlow, this may include declaring all the variables and compiling functions.
-
optimize_policy(itr, samples_data)
Optimize the policy using the samples.
Parameters: - itr (int) – The current epoch number.
- samples_data (dict) – The data from the sampled rollouts.
-
train(runner)
Obtain samplers and start actual training for each epoch.
Parameters: runner (garage.experiment.LocalRunner) – The LocalRunner is passed to give the algorithm access to runner.step_epochs(), which provides services such as snapshotting and sampler control.
Returns: last_return (ast_toolbox.algos.go_explore.Cell) – The highest scoring cell found so far.
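The overall select, return, and explore cycle of Go-Explore [1] can be illustrated with a self-contained toy. It uses no garage runner, policy, or CellPool: the 1-D chain environment, the 1/sqrt(times_chosen + 1) selection weight, random exploration, and the dict-based archive are all hypothetical simplifications, whereas GoExplore explores via sampled rollouts.

```python
import math
import random


def toy_go_explore(goal=20, iterations=300, explore_steps=5, seed=0):
    """Toy Go-Explore on a 1-D chain: start at 0, actions are -1/+1, score = position."""
    rng = random.Random(seed)
    # Archive: cell key -> record; here the "downsampled" key is just the position.
    archive = {0: {"state": 0, "trajectory": [], "score": 0, "times_chosen": 0}}
    for _ in range(iterations):
        # Select: favour cells that have been chosen less often.
        keys = list(archive)
        weights = [1.0 / math.sqrt(archive[k]["times_chosen"] + 1) for k in keys]
        start = archive[rng.choices(keys, weights=weights, k=1)[0]]
        start["times_chosen"] += 1
        # Return: restore the saved state instead of replaying from scratch.
        state, trajectory = start["state"], list(start["trajectory"])
        # Explore: take random actions, recording new or shorter-path cells.
        for _ in range(explore_steps):
            action = rng.choice((-1, 1))
            state = max(0, min(goal, state + action))
            trajectory.append(action)
            cell = archive.get(state)
            if cell is None or len(trajectory) < len(cell["trajectory"]):
                archive[state] = {"state": state, "trajectory": list(trajectory),
                                  "score": state,
                                  "times_chosen": cell["times_chosen"] if cell else 0}
            if state == goal:
                break
    # Like train(), report the highest-scoring cell found.
    return max(archive.values(), key=lambda c: c["score"])


best = toy_go_explore()
print(best["score"], len(best["trajectory"]))
```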
-
train_once(itr, paths)
Perform one step of policy optimization given one batch of samples.
Parameters: - itr (int) – Iteration number.
- paths (list[dict]) – A list of collected paths.
Returns: best_cell (ast_toolbox.algos.go_explore.Cell) – The highest scoring cell found so far.