# Source code for ast_toolbox.samplers.batch_sampler

"""Module for parallel sampling a batch of rollouts"""
import numpy as np
import tensorflow as tf
from garage.sampler.base import BaseSampler
from garage.sampler.stateful_pool import singleton_pool
from garage.sampler.utils import truncate_paths

from ast_toolbox.rewards import ExampleAVReward
from ast_toolbox.samplers import parallel_sampler
from ast_toolbox.simulators import ExampleAVSimulator


def worker_init_tf(g):
    """Create a TensorFlow session for a worker and enter its context.

    The new session is stored on ``g.sess`` so later worker calls can reuse
    it. The session is entered via ``__enter__`` and no matching ``__exit__``
    is issued here.

    Parameters
    ----------
    g : :py:class:`garage.sampler.stateful_pool.SharedGlobal`
        SharedGlobal class from :py:mod:`garage.sampler.stateful_pool`.

    """
    session = tf.compat.v1.Session()
    g.sess = session
    session.__enter__()
def worker_init_tf_vars(g):
    """Run the global-variables initializer in the worker's session.

    Expects ``g.sess`` to have been set beforehand (``worker_init_tf`` does
    this).

    Parameters
    ----------
    g : :py:class:`garage.sampler.stateful_pool.SharedGlobal`
        SharedGlobal class from :py:mod:`garage.sampler.stateful_pool`.

    """
    session = g.sess
    init_op = tf.compat.v1.global_variables_initializer()
    session.run(init_op)
class BatchSampler(BaseSampler):
    """Collects samples in parallel using a stateful pool of workers.

    Parameters
    ----------
    algo : :py:class:`garage.np.algos.base.RLAlgorithm`
        The algorithm.
    env : :py:class:`ast_toolbox.envs.ASTEnv`
        The environment.
    n_envs : int
        Number of parallel environments to run.
    open_loop : bool
        True if the simulation is open-loop, meaning that AST must generate all actions ahead
        of time, instead of being able to output an action in sync with the simulator, getting
        an observation back before the next action is generated. False to get interactive
        control, which requires that `blackbox_sim_state` is also False.
    batch_simulate : bool
        When in `obtain_samples` with `open_loop == True`, the sampler will call
        `self.sim.batch_simulate_paths` if `batch_simulate` is True, and `self.sim.simulate`
        if False.
    sim : :py:class:`ast_toolbox.simulators.ASTSimulator`, optional
        The simulator wrapper, inheriting from `ast_toolbox.simulators.ASTSimulator`.
        When omitted (``None``), a new `ExampleAVSimulator` is created for this sampler.
    reward_function : :py:class:`ast_toolbox.rewards.ASTReward`, optional
        The reward function, inheriting from `ast_toolbox.rewards.ASTReward`.
        When omitted (``None``), a new `ExampleAVReward` is created for this sampler.

    """

    def __init__(self,
                 algo,
                 env,
                 n_envs=1,
                 open_loop=True,
                 batch_simulate=False,
                 sim=None,
                 reward_function=None):
        super().__init__(algo, env)
        self.n_envs = n_envs
        self.open_loop = open_loop
        self.batch_simulate = batch_simulate
        # Build the defaults per instance: instances used as default argument
        # values are created once at import time and would be shared by every
        # sampler that relies on them.
        self.sim = ExampleAVSimulator() if sim is None else sim
        self.reward_function = (ExampleAVReward()
                                if reward_function is None else reward_function)

    def start_worker(self):
        """Initialize the sampler.

        Requires ``singleton_pool`` to be initialized already. When the pool
        has more than one worker, a TF session is created on each worker
        before the task is populated, and the TF variables are initialized on
        each worker afterwards.
        """
        assert singleton_pool.initialized, (
            'Use singleton_pool.initialize(n_parallel) to setup workers.')
        if singleton_pool.n_parallel > 1:
            singleton_pool.run_each(worker_init_tf)
        parallel_sampler.populate_task(self.env, self.algo.policy)
        if singleton_pool.n_parallel > 1:
            singleton_pool.run_each(worker_init_tf_vars)

    def shutdown_worker(self):
        """Terminate workers if necessary."""
        parallel_sampler.terminate_task(scope=self.algo.scope)

    def obtain_samples(self, itr, batch_size=None, whole_paths=True):
        """Collect samples for the given iteration number.

        Parameters
        ----------
        itr : int
            Iteration number.
        batch_size : int, optional
            How many simulation steps to run in each epoch. A falsy value
            (``None`` or ``0``) selects ``algo.max_path_length * n_envs``.
        whole_paths : bool, optional
            Whether to return the full rollout paths data.

        Returns
        -------
        list
            The sampled paths as produced by ``parallel_sampler.sample_paths``
            (post-processed by the simulator when ``open_loop`` is True). When
            ``whole_paths`` is False the result of
            ``truncate_paths(paths, batch_size)`` is returned instead.

        """
        if not batch_size:
            batch_size = self.algo.max_path_length * self.n_envs

        cur_policy_params = self.algo.policy.get_param_values()
        cur_env_params = self.algo.env.get_param_values()
        paths = parallel_sampler.sample_paths(
            policy_params=cur_policy_params,
            max_samples=batch_size,
            max_path_length=self.algo.max_path_length,
            env_params=cur_env_params,
            scope=self.algo.scope,
        )

        # TODO: Doing the path correction here means the simulations will not be parallel.
        # Need to make own parallel sampler and put it there to make that work
        if self.open_loop:
            if self.batch_simulate:
                paths = self.sim.batch_simulate_paths(
                    paths=paths, reward_function=self.reward_function)
            else:
                # NOTE(review): the nesting of the statements in this loop was
                # reconstructed from a whitespace-collapsed listing; confirm
                # against the upstream file that the reward/cache assignments
                # run for every path and not only when ``end_idx >= 0``.
                for path in paths:
                    s_0 = path["observations"][0]
                    actions = path['actions']
                    end_idx, info = self.sim.simulate(actions=actions, s_0=s_0)
                    if end_idx >= 0:
                        # Zero out everything recorded after the step at which
                        # the simulator reported the rollout ended.
                        self.slice_dict(path, end_idx)
                    # With a negative ``end_idx`` the indexing below counts
                    # from the end of the arrays (e.g. -1 is the last step).
                    rewards = self.reward_function.give_reward(
                        action=actions[end_idx],
                        info=self.sim.get_reward_info()
                    )
                    path["rewards"][end_idx] = rewards
                    path['env_infos']['cache'] = np.zeros_like(path["rewards"])

        if whole_paths:
            return paths
        return truncate_paths(paths, batch_size)

    def slice_dict(self, in_dict, slice_idx):
        """Zero-fill, in place, every array entry located after ``slice_idx``.

        Walks a dictionary whose values are either arrays or further
        dictionaries of the same form. Arrays keep their length; only the
        entries at positions ``slice_idx + 1`` and later along the first axis
        are overwritten with zeros.

        Parameters
        ----------
        in_dict : dict
            Dictionary where the values are arrays or other dictionaries that follow this
            stipulation.
        slice_idx : int
            Last index to keep; everything after it is zeroed.

        Returns
        -------
        dict
            The same ``in_dict`` object, modified in place at every level.

        """
        for key, value in in_dict.items():
            if isinstance(value, dict):
                in_dict[key] = self.slice_dict(value, slice_idx)
            else:
                tail = value[slice_idx + 1:, ...]
                value[slice_idx + 1:, ...] = np.zeros_like(tail)
        return in_dict