ast_toolbox.algos.go_explore module

Implementation of the Go-Explore algorithm.

class ast_toolbox.algos.go_explore.Cell(use_score_weight=True)

Bases: object

A representation of a state visited during exploration.

Parameters:use_score_weight (bool) – Whether or not to scale the cell’s fitness by a function of the cell’s score.
reset_cached_property(cached_property)

Removes cached properties so they will be recalculated on next access.

Parameters:cached_property (str) – The cached_property key to remove from the class dict.
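Assuming `Cell` follows the common pattern in which a cached property stores its computed value in the instance `__dict__` (as the standard-library `functools.cached_property` does, Python 3.8+), removing that key is enough to force recomputation. The sketch below illustrates the idea; `ToyCell` and its 1/sqrt(n + 1) subscore are hypothetical, not the toolbox's code.

```python
import functools


class ToyCell:
    """Hypothetical stand-in for Cell with one cached, derived value."""

    def __init__(self):
        self.times_chosen = 0

    @functools.cached_property
    def times_chosen_subscore(self):
        # Computed on first access, then served from the instance __dict__.
        return 1.0 / (self.times_chosen + 1) ** 0.5

    def reset_cached_property(self, cached_property):
        # cached_property stores its result under the attribute name in the
        # instance dict; popping the key forces recalculation on next access.
        self.__dict__.pop(cached_property, None)


cell = ToyCell()
print(cell.times_chosen_subscore)  # 1.0
cell.times_chosen = 3
print(cell.times_chosen_subscore)  # 1.0 (stale cached value)
cell.reset_cached_property("times_chosen_subscore")
print(cell.times_chosen_subscore)  # 0.5
```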
count_subscores

A function of times_chosen_subscore, times_chosen_since_improved_subscore, and times_visited_subscore that is used in calculating the cell’s fitness score.

Returns:float – The count subscore of the cell.
fitness

The fitness score of the cell. Cells are sampled with probability proportional to their fitness score.

Returns:float – The fitness score of the cell.
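The relationship between the count subscores, the optional score weight, and fitness-proportional sampling can be sketched as follows. The exact formulas are not given here, so the 1/sqrt(n + 1) count decay (loosely modeled on the count-based scores in the Go-Explore paper) and the `1 + max(score, 0)` score weight are illustrative assumptions, and cells are plain dicts rather than `Cell` objects.

```python
import math
import random


def count_subscore(n):
    # Illustrative decay: the more often a cell is chosen/visited, the lower its weight.
    return 1.0 / math.sqrt(n + 1)


def fitness(cell, use_score_weight=True):
    # Sum of the three count subscores (assumed combination).
    count_subscores = (count_subscore(cell["times_chosen"])
                       + count_subscore(cell["times_chosen_since_improved"])
                       + count_subscore(cell["times_visited"]))
    if use_score_weight:
        # Hypothetical score weight that favors higher-scoring cells.
        return count_subscores * (1.0 + max(cell["score"], 0.0))
    return count_subscores


def sample_cell(cells, rng=random):
    # Choose one cell with probability proportional to its fitness.
    weights = [fitness(c) for c in cells]
    return rng.choices(cells, weights=weights, k=1)[0]


cells = [
    {"times_chosen": 0, "times_chosen_since_improved": 0,
     "times_visited": 0, "score": 0.0},
    {"times_chosen": 8, "times_chosen_since_improved": 8,
     "times_visited": 24, "score": 0.0},
]
print(fitness(cells[0]))  # 3.0
print(sample_cell(cells, random.Random(0)))
```

Rarely chosen, rarely visited cells receive higher fitness and are therefore picked more often as rollout starting points.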
is_goal

Whether or not the current cell is a goal state.

Returns:bool – Is the current cell a goal.
is_root

Checks if the cell is the root of the tree (trajectory length is 0).

Returns:bool – Whether the cell is root or not.
is_terminal

Whether or not the current cell is a terminal state.

Returns:bool – Is the current cell terminal.
reward

The reward obtained in the current cell.

Returns:float – The reward.
score

The score obtained in the current cell.

Returns:float – The score.
score_weight

A heuristic function of the cell’s score and other values, used to bias rollouts towards high-scoring areas.

Returns:float – The cell’s score_weight
step

How many steps led to the current cell.

Returns:int – Length of the trajectory.
times_chosen

How many times the current cell has been chosen to start a rollout.

Returns:int – Number of times chosen.
times_chosen_since_improved

How many times the current cell has been chosen to start a rollout since the last time the cell was updated with an improved score or trajectory.

Returns:int – Number of times chosen since last improved.
times_chosen_since_improved_subscore

A function of times_chosen_since_improved that is used in calculating the cell’s count_subscores score.

Returns:float – The times_chosen_since_improved_subscore
times_chosen_subscore

A function of times_chosen that is used in calculating the cell’s count_subscores score.

Returns:float – The times_chosen_subscore
times_visited

How many times the current cell has been visited during all rollouts.

Returns:int – Number of times visited.
times_visited_subscore

A function of times_visited that is used in calculating the cell’s count_subscores score.

Returns:float – The times_visited_subscore
value_approx

The approximate value of the current cell, based on back-propagation of previous rollouts.

Returns:float – The value approximation.
class ast_toolbox.algos.go_explore.CellPool(filename='database', discount=0.99, use_score_weight=True)

Bases: object

A hashtree data structure that stores and updates all of the cells seen during rollouts.

Parameters:
  • filename (str, optional) – The base name for the database files. The CellPool saves a [filename]_pool.dat and a [filename]_meta.dat.
  • discount (float, optional) – Discount factor used in calculating a cell’s value approximation.
  • use_score_weight (bool) – Whether or not to scale a cell’s fitness by a function of the cell’s score.
close_pool(cell_pool_shelf)

Close the database that the CellPool uses to store cells.

Parameters:cell_pool_shelf (shelve.Shelf) – A shelve.Shelf wrapping a bsddb3 database.
d_update(cell_pool_shelf, observation, action, trajectory, score, state, parent=None, is_terminal=False, is_goal=False, reward=-inf, chosen=0)

Runs the update algorithm for the CellPool. The process is:
  1. Create a cell from the given data.
  2. Check if the cell already exists in the CellPool.
  3. If the cell already exists and our version is better (higher fitness or shorter trajectory), update the existing cell.
  4. If the cell already exists and our version is not better, end.
  5. If the cell does not already exist, add the new cell to the CellPool.

Parameters:
  • cell_pool_shelf (shelve.Shelf) – A shelve.Shelf wrapping a bsddb3 database.
  • observation (array_like) – The observation seen in the current cell.
  • action (array_like) – The action taken in the current cell.
  • trajectory (array_like) – The trajectory leading to the current cell.
  • score (float) – The score at the current cell.
  • state (array_like) – The cloned simulation state at the current cell, used for resetting if chosen to start a rollout.
  • parent (int, optional) – The hash key of the cell immediately preceding the current cell in the trajectory.
  • is_terminal (bool, optional) – Whether the current cell is a terminal state.
  • is_goal (bool, optional) – Whether the current cell is a goal state.
  • reward (float, optional) – The reward obtained at the current cell.
  • chosen (int, optional) – Whether the current cell was chosen to start the rollout.
Returns:

bool – True if a new cell was added to the CellPool, False otherwise.
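The five update steps above can be sketched with a plain `dict` standing in for the `shelve.Shelf`. This is an illustration under stated assumptions, not the toolbox's implementation: the key is a hash of the observation tuple, and a cell counts as "better" when it has a higher score, or the same score reached by a shorter trajectory, a simplified stand-in for the fitness/trajectory test described above.

```python
def d_update(pool, observation, trajectory, score, parent=None,
             is_terminal=False, is_goal=False, reward=float("-inf")):
    """Sketch of the CellPool update rule on a plain dict."""
    # 1. Create a cell from the given data, keyed by a hash of the observation.
    key = hash(tuple(observation))
    new_cell = {
        "observation": tuple(observation),
        "trajectory": list(trajectory),
        "score": score,
        "parent": parent,
        "is_terminal": is_terminal,
        "is_goal": is_goal,
        "reward": reward,
    }
    # 2. Check whether the cell already exists in the pool.
    old = pool.get(key)
    if old is None:
        # 5. The cell is new: add it and report that it was added.
        pool[key] = new_cell
        return True
    # Assumed "better" test: higher score, or same score via a shorter trajectory.
    better = score > old["score"] or (
        score == old["score"] and len(trajectory) < len(old["trajectory"]))
    if better:
        # 3. Our version is better: replace the stored cell.
        pool[key] = new_cell
    # 4. Otherwise leave the existing cell untouched.
    return False


pool = {}
print(d_update(pool, [0, 1], trajectory=[0.1, 0.2], score=1.0))  # True
print(d_update(pool, [0, 1], trajectory=[0.3], score=1.0))       # False
print(pool[hash((0, 1))]["trajectory"])                           # [0.3]
```

The second call returns False because no new cell was added, even though the stored cell was replaced by the shorter trajectory.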

delete_pool()

Remove the CellPool files saved on disk.

load(cell_pool_shelf)

Load a CellPool from disk.

Parameters:cell_pool_shelf (shelve.Shelf) – A shelve.Shelf wrapping a bsddb3 database.
open_pool(dbname=None, dbtype=<bsddb3 enumeration>, flags=<bsddb3 enumeration>, protocol=4, overwrite=False)

Open the database that the CellPool uses to store cells.

Parameters:
  • dbname (str, optional) – The name of the database to open.
  • dbtype (int, optional) – Specifies the type of database to open. Use enumerations provided by bsddb3.
  • flags (int, optional) – Specifies the configuration of the database to open. Use enumerations provided by bsddb3.
  • protocol (int, optional) – Specifies the data stream format used by pickle.
  • overwrite (bool, optional) – Indicates if an existing database should be overwritten if found.
Returns:

cell_pool_shelf (shelve.Shelf) – A shelve.Shelf wrapping a bsddb3 database.

save()

Save the CellPool to disk.

sync_and_close_pool(cell_pool_shelf)

Sync and then close the database that the CellPool uses to store cells.

Parameters:cell_pool_shelf (shelve.Shelf) – A shelve.Shelf wrapping a bsddb3 database.
sync_pool(cell_pool_shelf)

Syncs the pool, ensuring that the database on disk is up-to-date.

Parameters:cell_pool_shelf (shelve.Shelf) – A shelve.Shelf wrapping a bsddb3 database.
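The open, sync, and close lifecycle can be illustrated with the standard-library `shelve` module. This sketch uses `shelve.open` (backed by the stdlib `dbm` module) instead of a `shelve.Shelf` wrapping a bsddb3 database, so the `dbtype` and `flags` arguments are omitted; the helper names mirror the methods above but are hypothetical. Note that `shelve` keys must be strings.

```python
import os
import shelve
import tempfile


def open_pool(dbname, protocol=4, overwrite=False):
    # Stand-in for CellPool.open_pool using the stdlib dbm backend.
    flag = "n" if overwrite else "c"  # 'n' always creates a new, empty database
    return shelve.open(dbname, flag=flag, protocol=protocol)


def sync_and_close_pool(cell_pool_shelf):
    cell_pool_shelf.sync()   # flush pending writes to disk
    cell_pool_shelf.close()  # release the underlying database


with tempfile.TemporaryDirectory() as tmp:
    dbname = os.path.join(tmp, "database_pool")
    pool = open_pool(dbname, overwrite=True)
    pool[str(hash((0, 1)))] = {"score": 1.0, "trajectory": [0.3]}
    sync_and_close_pool(pool)

    reopened = open_pool(dbname)  # flag 'c': reuse the existing database
    print(len(reopened))          # 1
    reopened.close()
```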
value_approx_update(value, obs_hash, cell_pool_shelf)

Recursively calculate a value approximation through back-propagation.

Parameters:
  • value (float) – Value approximation of the previous cell.
  • obs_hash (int) – Hash key of the current cell.
  • cell_pool_shelf (shelve.Shelf) – A shelve.Shelf wrapping a bsddb3 database.
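The exact update rule is not spelled out here. This sketch assumes a discounted max backup, value(cell) = max(value(cell), reward(cell) + discount * value(child)), applied recursively along parent hash keys until the root is passed; the dict-based `pool`, its fields, and the rule itself are assumptions.

```python
def value_approx_update(value, obs_hash, pool, discount=0.99):
    """Hypothetical recursive backup of a value estimate along parent links."""
    if obs_hash is None or obs_hash not in pool:
        return  # walked past the root
    cell = pool[obs_hash]
    # Assumed discounted backup of the successor's value.
    candidate = cell["reward"] + discount * value
    cell["value_approx"] = max(cell["value_approx"], candidate)
    # Recurse toward the root with this cell's updated estimate.
    value_approx_update(cell["value_approx"], cell["parent"], pool, discount)


def make_cell(parent):
    return {"reward": 0.0, "value_approx": float("-inf"), "parent": parent}


# A three-cell chain: 0 (root) <- 1 <- 2.
pool = {0: make_cell(None), 1: make_cell(0), 2: make_cell(1)}
value_approx_update(10.0, 2, pool, discount=0.5)
print([pool[k]["value_approx"] for k in (2, 1, 0)])  # [5.0, 2.5, 1.25]
```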
meta_filename

The CellPool metadata filename.

Returns:str – The CellPool metadata filename.
pool_filename

The CellPool database filename.

Returns:str – The CellPool database filename.
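The two filename properties, and the files that delete_pool() removes, follow from the [filename]_pool.dat and [filename]_meta.dat naming described in the CellPool parameters. A minimal sketch, assuming the names are formed by simple string concatenation; `PoolFiles` is a hypothetical helper, and a real database backend may add its own extensions on disk.

```python
import os


class PoolFiles:
    """Hypothetical helper mirroring pool_filename, meta_filename, delete_pool."""

    def __init__(self, filename="database"):
        self.filename = filename

    @property
    def pool_filename(self):
        return self.filename + "_pool.dat"

    @property
    def meta_filename(self):
        return self.filename + "_meta.dat"

    def delete_pool(self):
        # Remove whichever of the two files currently exist on disk.
        for path in (self.pool_filename, self.meta_filename):
            if os.path.exists(path):
                os.remove(path)


files = PoolFiles("runs/exp1")
print(files.pool_filename)  # runs/exp1_pool.dat
print(files.meta_filename)  # runs/exp1_meta.dat
```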
class ast_toolbox.algos.go_explore.GoExplore(db_filename, max_db_size, env, env_spec, policy, baseline, save_paths_gap=0, save_paths_path=None, overwrite_db=True, use_score_weight=True, **kwargs)

Bases: garage.tf.algos.batch_polopt.BatchPolopt

Implementation of the Go-Explore [1] algorithm that is compatible with AST [2].

Parameters:
  • db_filename (str) – The base path and name for the database files. The CellPool saves a [filename]_pool.dat and a [filename]_meta.dat.
  • max_db_size (int) – Maximum allowable size (in GB) of the CellPool database. The algorithm will immediately stop and exit if this size is exceeded.
  • env (ast_toolbox.envs.go_explore_ast_env.GoExploreASTEnv) – The environment.
  • env_spec (garage.envs.EnvSpec) – Environment specification.
  • policy (garage.tf.policies.Policy) – The policy.
  • baseline (garage.np.baselines.Baseline) – The baseline.
  • save_paths_gap (int, optional) – The interval, in epochs, between saves of full paths. Set to 1 to save every epoch. Set to 0 to disable saving.
  • save_paths_path (str, optional) – Path to the directory where paths should be saved. Set to None to disable saving.
  • overwrite_db (bool, optional) – Indicates if an existing database should be overwritten if found.
  • use_score_weight (bool) – Whether or not to scale the cell’s fitness by a function of the cell’s score.
  • kwargs – Keyword arguments passed to garage.tf.algos.BatchPolopt
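One way to read the save_paths_gap and save_paths_path semantics is as a modulo test on the epoch number. The helper below is hypothetical, and it assumes epoch 0 is among the saved epochs.

```python
def should_save_paths(itr, save_paths_gap, save_paths_path):
    """Hypothetical check: should full paths be saved at epoch `itr`?"""
    # Saving is disabled by a gap of 0 or by a missing output directory.
    if save_paths_gap == 0 or save_paths_path is None:
        return False
    # A gap of 1 saves every epoch; a gap of k saves every k-th epoch.
    return itr % save_paths_gap == 0


print([itr for itr in range(7) if should_save_paths(itr, 3, "paths/")])  # [0, 3, 6]
```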

References

[1] Ecoffet, Adrien, et al. “Go-explore: a new approach for hard-exploration problems.” arXiv preprint arXiv:1901.10995 (2019). https://arxiv.org/abs/1901.10995
[2] Koren, Mark, and Mykel J. Kochenderfer. “Adaptive Stress Testing without Domain Heuristics using Go-Explore.” arXiv preprint arXiv:2004.04292 (2020). https://arxiv.org/abs/2004.04292
downsample(obs, step=None)

Create a downsampled approximation of the observed simulation state.

Parameters:
  • obs (array_like) – The observed simulation state.
  • step (int, optional) – The current iteration number.
Returns:

array_like – The downsampled approximation of the observed simulation state.
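The downsampling scheme is not specified here. As an illustrative assumption, the sketch below snaps each entry of a numeric observation onto a fixed grid, so nearby states map to the same hashable cell key; the Go-Explore paper applies a similar idea to images by downscaling game frames. `bin_size` is hypothetical, and `step` is accepted only to mirror the documented signature.

```python
import math


def downsample(obs, step=None, bin_size=0.5):
    """Hypothetical downsampling: snap each entry of obs to a coarse grid.

    `step` is ignored in this sketch; `bin_size` is an assumed parameter.
    """
    return tuple(math.floor(x / bin_size) for x in obs)


a = downsample([0.10, 1.20, -0.30])
b = downsample([0.40, 1.45, -0.05])
print(a)       # (0, 2, -1)
print(a == b)  # True: nearby states share one cell
```

A coarser grid merges more states into each cell, which shrinks the CellPool but can hide distinctions that matter for finding failures.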

get_itr_snapshot(itr)

Returns all the data that should be saved in the snapshot for this iteration.

Parameters:itr (int) – The current epoch number.
Returns:dict – A dict containing the current iteration number, the current policy, and the current baseline.
init_opt()

Initialize the optimization procedure. If using TensorFlow, this may include declaring all the variables and compiling functions.

optimize_policy(itr, samples_data)

Optimize the policy using the samples.

Parameters:
  • itr (int) – The current epoch number.
  • samples_data (dict) – The data from the sampled rollouts.
train(runner)

Obtain samplers and start actual training for each epoch.

Parameters:runner (garage.experiment.LocalRunner) – LocalRunner is passed to give algorithm the access to runner.step_epochs(), which provides services such as snapshotting and sampler control.
Returns:last_return (ast_toolbox.algos.go_explore.Cell) – The highest-scoring cell found so far.
train_once(itr, paths)

Perform one step of policy optimization given one batch of samples.

Parameters:
  • itr (int) – Iteration number.
  • paths (list[dict]) – A list of collected paths.
Returns:

best_cell (ast_toolbox.algos.go_explore.Cell) – The highest-scoring cell found so far.
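The return values of train and train_once (the highest-scoring cell found so far) imply that the best cell is tracked across batches. A sketch under stated assumptions: each path is a dict carrying a hypothetical `cells` list, cells are dicts with a `score`, and a plain list of batches stands in for runner.step_epochs().

```python
class BestCellTracker:
    """Hypothetical stand-in for the part of GoExplore that tracks the best cell."""

    def __init__(self):
        self.best_cell = None

    def train_once(self, itr, paths):
        # Scan every cell visited in this batch and keep the highest-scoring one.
        for path in paths:
            for cell in path["cells"]:
                if self.best_cell is None or cell["score"] > self.best_cell["score"]:
                    self.best_cell = cell
        return self.best_cell


def train(tracker, batches):
    # A list of batches stands in for runner.step_epochs().
    last_return = None
    for itr, paths in enumerate(batches):
        last_return = tracker.train_once(itr, paths)
    return last_return


batches = [
    [{"cells": [{"score": 1.0}, {"score": 3.0}]}],
    [{"cells": [{"score": 2.0}]}],
]
print(train(BestCellTracker(), batches))  # {'score': 3.0}
```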