ast_toolbox.samplers package

Samplers for solving AST-formulated RL problems.

class ast_toolbox.samplers.ASTVectorizedSampler(algo, env, n_envs=1, open_loop=True, sim=<ast_toolbox.simulators.example_av_simulator.example_av_simulator.ExampleAVSimulator object>, reward_function=<ast_toolbox.rewards.example_av_reward.ExampleAVReward object>)[source]

Bases: garage.sampler.on_policy_vectorized_sampler.OnPolicyVectorizedSampler

A vectorized sampler for AST to handle open-loop simulators.

Garage usually generates samples in a closed-loop process. This version of the vectorized sampler instead collects dummy data until the full rollout specification is generated, then goes back and runs the simulate function to obtain the real results. Rewards are then calculated and the path data is corrected.
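The open-loop flow described above can be sketched with toy stand-ins. All names below (sample_open_loop, simulate, reward_fn) are illustrative, not ast_toolbox APIs:

```python
# Sketch of the open-loop sampling pattern: collect the full action
# specification first, then simulate and back-fill the rewards.
# Every name here is an illustrative stand-in, not an ast_toolbox API.

def sample_open_loop(policy, max_path_length, simulate, reward_fn):
    # Phase 1: roll the policy against dummy observations to build
    # the complete action specification for one trajectory.
    dummy_obs = [0.0]
    actions = [policy(dummy_obs) for _ in range(max_path_length)]

    # Phase 2: run the (possibly black-box) simulator on the whole
    # action sequence at once to obtain the real observations.
    observations = simulate(actions)

    # Phase 3: go back and compute the true rewards, correcting the
    # placeholder path data collected in phase 1.
    rewards = [reward_fn(o, a) for o, a in zip(observations, actions)]
    return {'observations': observations,
            'actions': actions,
            'rewards': rewards}
```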

Parameters:
  • algo (garage.np.algos.base.RLAlgorithm) – The algorithm.
  • env (ast_toolbox.envs.ASTEnv) – The environment.
  • n_envs (int) – Number of parallel environments to run.
  • open_loop (bool) – True if the simulation is open-loop, meaning that AST must generate all actions ahead of time rather than outputting each action in sync with the simulator and receiving an observation back before generating the next action. False to get interactive control, which requires that blackbox_sim_state is also False.
  • sim (ast_toolbox.simulators.ASTSimulator) – The simulator wrapper, inheriting from ast_toolbox.simulators.ASTSimulator.
  • reward_function (ast_toolbox.rewards.ASTReward) – The reward function, inheriting from ast_toolbox.rewards.ASTReward.
obtain_samples(itr, batch_size=None, whole_paths=False)[source]

Sample the policy for new trajectories.

Parameters:
  • itr (int) – Iteration number.
  • batch_size (int) – Number of samples to be collected. If None, defaults to algo.max_path_length * n_envs.
  • whole_paths (bool) – Whether to return the complete paths. The total number of samples across paths may exceed batch_size; if this flag is False, the paths are truncated to batch_size.
Returns:

list[dict] – A list of sampled rollout paths. Each rollout path is a dictionary with the following keys:

  • observations (numpy.ndarray)
  • actions (numpy.ndarray)
  • rewards (numpy.ndarray)
  • agent_infos (dict)
  • env_infos (dict)
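For illustration, a returned path might be consumed like this. The path dict below is a fabricated toy with the documented keys, not real sampler output:

```python
import numpy as np

# A toy rollout path with the documented keys (fabricated values).
path = {
    'observations': np.zeros((4, 2)),
    'actions': np.ones((4, 1)),
    'rewards': np.array([0.0, -1.0, -1.0, 10.0]),
    'agent_infos': {},
    'env_infos': {},
}

# The undiscounted return of a path is simply the sum of its rewards,
# and the path length is the number of reward entries.
undiscounted_return = path['rewards'].sum()
path_length = len(path['rewards'])
```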

slice_dict(in_dict, slice_idx)[source]

Helper function that recursively walks a dictionary of dictionaries and arrays, slicing every array at the given index.

Parameters:
  • in_dict (dict) – Dictionary where the values are arrays or other dictionaries that follow this stipulation.
  • slice_idx (int) – Index to slice each array at.
Returns:

dict – Dictionary where arrays at every level are sliced.
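A minimal reimplementation of the described behavior might look like the following. This is a sketch of the documented semantics (slicing each array up to slice_idx), not the toolbox's actual code:

```python
import numpy as np

def slice_dict(in_dict, slice_idx):
    # Recursively walk the dictionary: recurse into sub-dictionaries,
    # and slice every array (or list) value at slice_idx.
    out = {}
    for key, value in in_dict.items():
        if isinstance(value, dict):
            out[key] = slice_dict(value, slice_idx)
        else:
            out[key] = value[:slice_idx]
    return out
```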

class ast_toolbox.samplers.BatchSampler(algo, env, n_envs=1, open_loop=True, batch_simulate=False, sim=<ast_toolbox.simulators.example_av_simulator.example_av_simulator.ExampleAVSimulator object>, reward_function=<ast_toolbox.rewards.example_av_reward.ExampleAVReward object>)[source]

Bases: garage.sampler.base.BaseSampler

Collects samples in parallel using a stateful pool of workers.

Parameters:
  • algo (garage.np.algos.base.RLAlgorithm) – The algorithm.
  • env (ast_toolbox.envs.ASTEnv) – The environment.
  • n_envs (int) – Number of parallel environments to run.
  • open_loop (bool) – True if the simulation is open-loop, meaning that AST must generate all actions ahead of time rather than outputting each action in sync with the simulator and receiving an observation back before generating the next action. False to get interactive control, which requires that blackbox_sim_state is also False.
  • batch_simulate (bool) – When obtain_samples is called with open_loop == True, the sampler calls self.sim.batch_simulate_paths if batch_simulate is True, and self.sim.simulate otherwise.
  • sim (ast_toolbox.simulators.ASTSimulator) – The simulator wrapper, inheriting from ast_toolbox.simulators.ASTSimulator.
  • reward_function (ast_toolbox.rewards.ASTReward) – The reward function, inheriting from ast_toolbox.rewards.ASTReward.
obtain_samples(itr, batch_size=None, whole_paths=True)[source]

Collect samples for the given iteration number.

Parameters:
  • itr (int) – Iteration number.
  • batch_size (int, optional) – How many simulation steps to run in each epoch.
  • whole_paths (bool, optional) – Whether to return the full rollout paths data.
shutdown_worker()[source]

Terminate workers if necessary.

slice_dict(in_dict, slice_idx)[source]

Helper function that recursively walks a dictionary of dictionaries and arrays, slicing every array at the given index.

Parameters:
  • in_dict (dict) – Dictionary where the values are arrays or other dictionaries that follow this stipulation.
  • slice_idx (int) – Index to slice each array at.
Returns:

dict – Dictionary where arrays at every level are sliced.

start_worker()[source]

Initialize the sampler.