ast_toolbox.algos.backward_algorithm module
Backward Algorithm from Salimans and Chen.
class ast_toolbox.algos.backward_algorithm.BackwardAlgorithm(env, policy, expert_trajectory, epochs_per_step=10, max_epochs=None, skip_until_step=0, max_path_length=500, **kwargs)[source]

    Bases: garage.tf.algos.ppo.PPO

    Backward Algorithm from Salimans and Chen [1].
    Parameters:
        env (ast_toolbox.envs.go_explore_ast_env.GoExploreASTEnv) – The environment.
        policy (garage.tf.policies.Policy) – The policy.
        expert_trajectory (array_like[dict]) – The expert trajectory: a 1-D array_like in chronological order, where each member represents one timestep. Each member is a dictionary with the following keys:
            - state: The simulator state at that timestep (pre-action).
            - reward: The reward at that timestep (post-action).
            - observation: The simulation observation at that timestep (post-action).
            - action: The action taken at that timestep.
        epochs_per_step (int, optional) – Maximum number of epochs to run per step of the trajectory.
        max_epochs (int, optional) – Maximum total number of epochs to run. If not set, defaults to epochs_per_step times the number of steps in expert_trajectory.
        skip_until_step (int, optional) – Number of steps to skip at the start of training, counted backwards from the end of the trajectory. For example, if this is set to 3 for an expert_trajectory of length 10, training will start from step 7.
        max_path_length (int, optional) – Maximum length of a single rollout.
        kwargs – Keyword arguments passed to garage.tf.algos.PPO.
References
[1] Salimans, Tim, and Richard Chen. “Learning Montezuma’s Revenge from a Single Demonstration.” arXiv preprint arXiv:1812.03381 (2018). https://arxiv.org/abs/1812.03381
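The expert_trajectory format and the max_epochs default described above can be illustrated with a minimal sketch. The dictionary keys follow the parameter documentation; the state, observation, and action values here are placeholders, not the output of any real simulator.

```python
# Hypothetical expert_trajectory: a 1-D, chronologically ordered list of
# per-timestep dictionaries with the keys the documentation requires.
expert_trajectory = [
    {'state': (0,), 'reward': 0.0, 'observation': (0.1,), 'action': 0},
    {'state': (1,), 'reward': 1.0, 'observation': (0.2,), 'action': 1},
    {'state': (2,), 'reward': 0.5, 'observation': (0.3,), 'action': 0},
]

epochs_per_step = 10
# Default max_epochs when not set: epochs_per_step times the trajectory length.
max_epochs = epochs_per_step * len(expert_trajectory)
```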
get_next_epoch(runner)[source]

    Wrapper around garage's runner.step_epochs() generator that handles initialization to the correct trajectory state.

    Parameters:
        runner (garage.experiment.LocalRunner) – LocalRunner is passed to give the algorithm access to runner.step_epochs(), which provides services such as snapshotting and sampler control.

    Yields:
        - runner.step_itr (int) – The current epoch number.
        - runner.obtain_samples(runner.step_itr) (list[dict]) – A list of sampled rollouts for the current epoch.
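The wrapping pattern can be sketched in isolation. This is a conceptual illustration, not the toolbox's implementation: StubRunner stands in for garage.experiment.LocalRunner, and reset_fn stands in for whatever per-epoch state initialization the algorithm performs.

```python
# Minimal stand-in for garage.experiment.LocalRunner (an assumption for
# illustration only; the real class does far more).
class StubRunner:
    def __init__(self, n_epochs):
        self.step_itr = 0
        self.n_epochs = n_epochs

    def step_epochs(self):
        # The real generator also handles snapshotting and sampler control.
        for _ in range(self.n_epochs):
            yield self.step_itr

    def obtain_samples(self, itr):
        return [{'itr': itr}]  # placeholder for sampled rollouts


def get_next_epoch(runner, reset_fn):
    """Wrap runner.step_epochs(), re-initializing state before each epoch."""
    for _ in runner.step_epochs():
        reset_fn()  # e.g. reset the env to the current expert-trajectory step
        yield runner.step_itr, runner.obtain_samples(runner.step_itr)
        runner.step_itr += 1
```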
set_env_to_expert_trajectory_step()[source]

    Updates the algorithm to use the data from expert_trajectory up to the current step.
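One way to picture what data is relevant at a given backward step: the simulator state at that step becomes the rollout's starting point, and the demonstration's remaining reward comes from that step onward. The helper name below is hypothetical, not the toolbox API.

```python
# Hypothetical illustration: select the start state and the demonstration's
# remaining reward for the current backward step (0-indexed here).
def expert_data_for_step(expert_trajectory, step):
    start_state = expert_trajectory[step]['state']
    # Reward still obtainable from this step onward in the demonstration.
    reward_to_go = sum(t['reward'] for t in expert_trajectory[step:])
    return start_state, reward_to_go
```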
train(runner)[source]

    Obtain samplers and start the actual training for each epoch.

    Parameters:
        runner (garage.experiment.LocalRunner) – LocalRunner is passed to give the algorithm access to runner.step_epochs(), which provides services such as snapshotting and sampler control.

    Returns:
        full_paths (array_like) – A list of the path data from each epoch.