ast_toolbox.samplers.batch_sampler module

Module for sampling a batch of rollouts in parallel.

class ast_toolbox.samplers.batch_sampler.BatchSampler(algo, env, n_envs=1, open_loop=True, batch_simulate=False, sim=<ast_toolbox.simulators.example_av_simulator.example_av_simulator.ExampleAVSimulator object>, reward_function=<ast_toolbox.rewards.example_av_reward.ExampleAVReward object>)

Bases: garage.sampler.base.BaseSampler

Collects samples in parallel using a stateful pool of workers.

Parameters:
  • algo (garage.np.algos.base.RLAlgorithm) – The algorithm.
  • env (ast_toolbox.envs.ASTEnv) – The environment.
  • n_envs (int) – Number of parallel environments to run.
  • open_loop (bool) – True if the simulation is open-loop, meaning that AST must generate all actions ahead of time rather than outputting each action in sync with the simulator and receiving an observation before generating the next one. False for interactive (closed-loop) control, which also requires blackbox_sim_state to be False.
  • batch_simulate (bool) – Only used when open_loop is True: obtain_samples calls self.sim.batch_simulate_paths if batch_simulate is True, and self.sim.simulate otherwise.
  • sim (ast_toolbox.simulators.ASTSimulator) – The simulator wrapper, inheriting from ast_toolbox.simulators.ASTSimulator.
  • reward_function (ast_toolbox.rewards.ASTReward) – The reward function, inheriting from ast_toolbox.rewards.ASTReward.
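Together, open_loop, batch_simulate and blackbox_sim_state determine how rollouts reach the simulator. The sketch below restates those rules in plain Python. StubSimulator, run_open_loop and check_config are hypothetical names, and the simplified simulate/batch_simulate_paths signatures are assumptions rather than the ASTSimulator API; blackbox_sim_state is treated as a plain boolean.

```python
from typing import List, Sequence


class StubSimulator:
    """Hypothetical stand-in for an ASTSimulator; records which entry point ran."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def simulate(self, actions: Sequence[float]) -> int:
        # Assumed simplified signature: run one open-loop rollout from a
        # pre-generated action sequence and return its length.
        self.calls.append("simulate")
        return len(actions)

    def batch_simulate_paths(self, batch: Sequence[Sequence[float]]) -> List[int]:
        # Assumed simplified signature: run every open-loop rollout in one call.
        self.calls.append("batch_simulate_paths")
        return [len(actions) for actions in batch]


def check_config(open_loop: bool, blackbox_sim_state: bool) -> None:
    """Closed-loop control requires blackbox_sim_state to be False."""
    if not open_loop and blackbox_sim_state:
        raise ValueError("open_loop=False requires blackbox_sim_state=False")


def run_open_loop(
    sim: StubSimulator, batch: Sequence[Sequence[float]], batch_simulate: bool
) -> List[int]:
    """Dispatch rule for open_loop=True (a sketch, not the library code)."""
    if batch_simulate:
        # One call for the whole batch of pre-generated action sequences.
        return sim.batch_simulate_paths(batch)
    # Otherwise simulate each pre-generated action sequence separately.
    return [sim.simulate(actions) for actions in batch]
```

With batch_simulate=True the stub records a single batch_simulate_paths call. With batch_simulate=False it records one simulate call per rollout.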

obtain_samples(itr, batch_size=None, whole_paths=True)

Collect samples for the given iteration number.

Parameters:
  • itr (int) – Iteration number.
  • batch_size (int, optional) – Number of simulation steps to run in each epoch.
  • whole_paths (bool, optional) – Whether to return the full rollout paths data.

shutdown_worker()

Terminate workers if necessary.

slice_dict(in_dict, slice_idx)

Helper function that recursively traverses a dictionary whose values are arrays or nested dictionaries, slicing every array at the given index.

Parameters:
  • in_dict (dict) – Dictionary whose values are arrays or dictionaries of the same form.
  • slice_idx (int) – Index to slice each array at.
Returns:

dict – Dictionary where arrays at every level are sliced.
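The recursion can be sketched in a few lines. This is a minimal sketch that assumes "slice at an index" means truncating each array to its first slice_idx entries along the first axis and returning a new dictionary. The actual method may slice differently or modify in_dict in place. The leaf branch works for Python lists and NumPy arrays alike.

```python
from typing import Any, Dict


def slice_dict(in_dict: Dict[str, Any], slice_idx: int) -> Dict[str, Any]:
    """Return a copy of in_dict with every leaf array cut to its first slice_idx items.

    Assumption: slicing means value[:slice_idx]; the library may differ.
    """
    out: Dict[str, Any] = {}
    for key, value in in_dict.items():
        if isinstance(value, dict):
            # Recurse into nested dictionaries, preserving their structure.
            out[key] = slice_dict(value, slice_idx)
        else:
            # Leaf: any sequence or NumPy array that supports basic slicing.
            out[key] = value[:slice_idx]
    return out


path = {
    "rewards": [1.0, 2.0, 3.0, 4.0],
    "env_infos": {"d": [0, 1, 2, 3], "nested": {"x": [9, 8, 7, 6]}},
}
print(slice_dict(path, 2))
# → {'rewards': [1.0, 2.0], 'env_infos': {'d': [0, 1], 'nested': {'x': [9, 8]}}}
```

Returning a new dictionary leaves the original path data untouched. An in-place variant would avoid copying large nested structures.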

start_worker()

Initialize the sampler.
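start_worker, obtain_samples and shutdown_worker form a lifecycle. The sampler is started once, sampled on every iteration, and shut down even if training fails. In garage, the training runner usually drives these calls. The ToySampler class below is hypothetical and only illustrates that ordering. It is not BatchSampler, and its obtain_samples returns dummy step records.

```python
from typing import Dict, List


class ToySampler:
    """Hypothetical sampler mirroring the start/obtain/shutdown interface."""

    def __init__(self) -> None:
        self.running = False

    def start_worker(self) -> None:
        # Acquire resources (e.g. a worker pool) once, before training starts.
        self.running = True

    def obtain_samples(self, itr: int, batch_size: int = 4) -> List[Dict[str, int]]:
        if not self.running:
            raise RuntimeError("call start_worker() before obtain_samples()")
        # Return batch_size dummy step records tagged with the iteration number.
        return [{"itr": itr, "step": t} for t in range(batch_size)]

    def shutdown_worker(self) -> None:
        # Release resources; calling this more than once is harmless.
        self.running = False


def train(sampler: ToySampler, n_itr: int) -> int:
    """Start once, sample every iteration, and always shut down."""
    total = 0
    sampler.start_worker()
    try:
        for itr in range(n_itr):
            samples = sampler.obtain_samples(itr)
            total += len(samples)
    finally:
        # Runs even if obtain_samples raises, so workers are never leaked.
        sampler.shutdown_worker()
    return total
```

The try/finally block guarantees that shutdown_worker runs even when an iteration raises.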

ast_toolbox.samplers.batch_sampler.worker_init_tf(g)

Initialize the tf.Session on a worker.

Parameters: g (garage.sampler.stateful_pool.SharedGlobal) – SharedGlobal class from garage.sampler.stateful_pool.

ast_toolbox.samplers.batch_sampler.worker_init_tf_vars(g)

Initialize the policy parameters on a worker.

Parameters: g (garage.sampler.stateful_pool.SharedGlobal) – SharedGlobal class from garage.sampler.stateful_pool.