training

Training infrastructure: synchronous multi-env trainer, configs, return estimators, and validation strategies.

The trainer and base config that every concrete agent extends:

  • :class:SyncMultiEnvTrainer — the main training loop.
  • :class:TrainerConfig — base hyperparameter dataclass.

Per-trainer configs (PPOTrainerConfig, RNDTrainerConfig, ...) live next to their trainer class in rlib/<Agent>/trainer.py.

Internals exposed for advanced use / testing:

  • :class:Returns — enum of return / advantage estimators.
  • :func:GAE, :func:lambda_return, :func:nstep_return — the underlying free functions.
  • :class:Validator, :class:AsyncValidator, :class:SyncValidator, :func:make_validator — validation strategy implementations.

TrainerConfig dataclass

TrainerConfig(train_mode: TrainMode = TrainMode.NSTEP, returns: Returns = Returns.NSTEP, total_steps: int = 50000000, nsteps: int = 5, gamma: float = 0.99, lambda_: float = 0.95, validate_freq: int = 1000000, num_val_episodes: int = 50, max_val_steps: int = 10000, log_dir: str = 'logs/', model_dir: str = 'models/', save_freq: int = 0, log_scalars: bool = True, update_target_freq: int = 0, render_freq: int = 0)

All hyperparameters for :class:rlib.training.SyncMultiEnvTrainer.

Attributes:

  • train_mode (TrainMode): whether to dispatch to the multi-step (:attr:TrainMode.NSTEP) or one-step (:attr:TrainMode.ONESTEP) training loop.
  • returns (Returns): return / advantage estimator (:class:Returns enum).
  • total_steps (int): total environment steps across all parallel envs.
  • nsteps (int): length of each n-step rollout.
  • gamma (float): discount factor.
  • lambda_ (float): GAE / λ-return weighting (ignored by Returns.NSTEP).
  • validate_freq (int): env steps between validation passes; 0 disables validation.
  • num_val_episodes (int): episodes averaged per validation pass.
  • max_val_steps (int): per-episode step cap during validation.
  • log_dir (str): directory for tensorboard scalars.
  • model_dir (str): directory for checkpoints.
  • save_freq (int): env steps between checkpoints; 0 disables saving.
  • log_scalars (bool): whether to write tensorboard scalars at all.
  • update_target_freq (int): env steps between target-net syncs (off-policy agents only); 0 disables.
  • render_freq (int): multiple of validate_freq between renders; 0 disables rendering.
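Per-trainer configs are described as extending the base dataclass. The sketch below illustrates that pattern with stand-in classes rather than the real rlib ones: BaseConfigSketch copies a few documented fields and defaults, while PPOConfigSketch and its clip_ratio / num_epochs fields are hypothetical names chosen for illustration.

```python
from dataclasses import dataclass, replace


@dataclass
class BaseConfigSketch:
    """Stand-in for TrainerConfig carrying a few of the documented fields."""

    total_steps: int = 50_000_000
    nsteps: int = 5
    gamma: float = 0.99
    lambda_: float = 0.95
    validate_freq: int = 1_000_000
    save_freq: int = 0


@dataclass
class PPOConfigSketch(BaseConfigSketch):
    """Hypothetical per-trainer config; the extra fields are illustrative only."""

    clip_ratio: float = 0.2
    num_epochs: int = 4


# Override only what differs from the defaults; 0 disables validation.
cfg = PPOConfigSketch(nsteps=128, validate_freq=0)

# dataclasses.replace copies the config with selected fields changed.
short_run = replace(cfg, total_steps=1_000_000)
```

Keeping the shared fields in one base dataclass means a trainer can read gamma, nsteps, and the frequency fields without knowing which concrete config it was handed.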

TrainMode

Bases: StrEnum

Whether the trainer dispatches to _train_nstep or _train_onestep.

Returns

Bases: Enum

Return / advantage estimators.

The enum name is the canonical CLI / log string ("NSTEP", "GAE", "LAMBDA"). The value is the wrapped callable, so members are directly callable::

    advantages = Returns.GAE(rewards, values, last_values, dones, 0.99, 0.95)

The :func:enum.member wrapper around each value is required so Python's enum metaclass treats the function as an enum member rather than as an instance method bound on the class.

SyncMultiEnvTrainer

SyncMultiEnvTrainer(envs: BatchEnv | DummyBatchEnv, agent: Agent, val_envs: list | BatchEnv | DummyBatchEnv, config: TrainerConfig)

Synchronous multi-env training framework for any :class:rlib.networks.Model.

Build a synchronous multi-env training loop.

Parameters:

  • envs (BatchEnv | DummyBatchEnv, required): training environments (BatchEnv or DummyBatchEnv).
  • agent (Agent, required): an :class:rlib.agent.Agent subclass.
  • val_envs (list | BatchEnv | DummyBatchEnv, required): validation envs, given as a list (uses threading), a BatchEnv (multiprocessing), or a DummyBatchEnv (in-process).
  • config (TrainerConfig, required): a :class:TrainerConfig carrying all training hyperparameters.

Source code in rlib/training/trainer.py
def __init__(
    self,
    envs: BatchEnv | DummyBatchEnv,
    agent: Agent,
    val_envs: list | BatchEnv | DummyBatchEnv,
    config: TrainerConfig,
) -> None:
    '''Build a synchronous multi-env training loop.

    Args:
        envs: training environments (``BatchEnv`` or ``DummyBatchEnv``).
        agent: an :class:`rlib.agent.Agent` subclass.
        val_envs: validation envs — a ``list`` (uses threading),
            a ``BatchEnv`` (multiprocessing), or a ``DummyBatchEnv``
            (in-process).
        config: a :class:`TrainerConfig` carrying all training
            hyperparameters.
    '''
    self.config = config

    self.env = envs
    self.validator = make_validator(val_envs)
    assert config.num_val_episodes >= len(val_envs), (
        f'number of validation episodes {config.num_val_episodes} must be greater than or '
        f'equal to the number of validation envs {len(val_envs)}'
    )
    self.num_envs = len(envs)
    self.env_id = envs.spec.id
    self.val_envs = val_envs
    self.agent = agent

    # Mirror config fields onto ``self`` for ergonomics — most of
    # the legacy training-loop code reads ``self.gamma`` etc. directly.
    self.train_mode = config.train_mode
    self.total_steps = config.total_steps
    self.nsteps = config.nsteps
    self.returns = config.returns
    self.gamma = config.gamma
    self.lambda_ = config.lambda_
    self.validate_freq = config.validate_freq
    self.num_val_episodes = config.num_val_episodes
    self.val_steps = config.max_val_steps
    self.save_freq = config.save_freq
    self.render_freq = config.render_freq
    self.target_freq = config.update_target_freq
    self.log_scalars = config.log_scalars
    self.log_dir = config.log_dir
    self.model_dir = config.model_dir

    self.lock = threading.Lock()
    # Kept for backwards-compatibility: agents with custom recurrent
    # validation loops (A2C-LSTM, UNREAL-LSTM, VIN) push per-episode
    # scores onto this list under ``self.lock``.
    self.validate_rewards: list[Any] = []
    self.s = 0  # number of saves made
    self.t = 1  # number of updates done
    self.states = self.env.reset()

    if config.log_scalars:
        self.train_log_dir = config.log_dir + '/train'
        self.train_writer = SummaryWriter(self.train_log_dir)
        self._log_hyperparameters()

    if not os.path.exists(self.model_dir) and config.save_freq > 0:
        os.makedirs(self.model_dir)

rollout

rollout() -> Any

Collect self.nsteps of experience and return whatever the agent's training loop expects.

Source code in rlib/training/trainer.py
def rollout(self) -> Any:
    """Collect ``self.nsteps`` of experience and return whatever the agent's training loop expects."""
    raise NotImplementedError(f'{type(self).__name__} does not implement rollout')

update_target

update_target() -> None

Hook called every update_target_freq steps. The base implementation raises NotImplementedError.

Off-policy agents (e.g. DDQN) override this to copy weights from self.agent to a target network. Pure on-policy agents leave update_target_freq=0, so the hook is never called.

Source code in rlib/training/trainer.py
def update_target(self) -> None:
    """Hook called every ``update_target_freq`` steps.

    The base implementation raises ``NotImplementedError``. Off-policy
    agents (e.g. DDQN) override this to copy weights from ``self.agent``
    to a target network. Pure on-policy agents leave
    ``update_target_freq=0``, so the hook is never called.
    """
    raise NotImplementedError(f'{type(self).__name__} does not implement update_target')

get_action

get_action(state: ndarray) -> Any

Hook used by the validator to pick an action during evaluation.

Concrete trainers must override this if validation is enabled (it's called by every :class:~rlib.training.validation.Validator).

Source code in rlib/training/trainer.py
def get_action(self, state: np.ndarray) -> Any:
    """Hook used by the validator to pick an action during evaluation.

    Concrete trainers must override this if validation is enabled
    (it's called by every :class:`~rlib.training.validation.Validator`).
    """
    raise NotImplementedError(
        'get_action method is required when validation is enabled, '
        'check that this is implemented properly'
    )

AsyncValidator

AsyncValidator(envs: list)

Validate against a list of envs using one daemon thread per env.

Each thread runs its share of episodes and pushes the per-episode total reward into a shared list (guarded by self._lock). After every thread joins, the mean across all collected scores is returned.

Source code in rlib/training/validation.py
def __init__(self, envs: list) -> None:
    if not isinstance(envs, list):
        raise TypeError(f"AsyncValidator expects a list of envs, got {type(envs).__name__}")
    self.envs = envs
    self._lock = threading.Lock()
    self._scores: list[float] = []
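The thread-per-env pattern described above can be sketched without rlib. Everything here is illustrative: CountdownEnv is a hypothetical toy env with a simplified step() signature, run_threaded_validation is not the library's method, and the even split of episodes across envs is an assumption about how each thread's share might be chosen.

```python
import threading
from statistics import mean


class CountdownEnv:
    """Toy env (hypothetical): reward 1.0 per step, done after `length` steps."""

    def __init__(self, length: int) -> None:
        self.length = length
        self.t = 0

    def reset(self) -> int:
        self.t = 0
        return self.t

    def step(self, action: int):
        self.t += 1
        return self.t, 1.0, self.t >= self.length


def run_threaded_validation(envs, get_action, num_episodes, max_steps):
    scores: list[float] = []
    lock = threading.Lock()

    def worker(env, episodes):
        for _ in range(episodes):
            state, total = env.reset(), 0.0
            for _ in range(max_steps):
                state, reward, done = env.step(get_action(state))
                total += reward
                if done:
                    break
            with lock:  # guard the shared list across threads
                scores.append(total)

    # Split episodes as evenly as possible; the first `extra` envs take one more.
    base, extra = divmod(num_episodes, len(envs))
    threads = [
        threading.Thread(target=worker, args=(env, base + (i < extra)), daemon=True)
        for i, env in enumerate(envs)
    ]
    for th in threads:
        th.start()
    for th in threads:
        th.join()  # wait for every thread before averaging
    return mean(scores), len(scores)


mean_score, episodes_run = run_threaded_validation(
    [CountdownEnv(3), CountdownEnv(5)],
    get_action=lambda state: 0,
    num_episodes=5,
    max_steps=100,
)
```

Under this splitting rule every env runs at least one episode only when num_episodes is at least len(envs), which matches the assertion in the trainer's constructor.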

SyncValidator

SyncValidator(envs: BatchEnv | DummyBatchEnv)

Validate against a single batched env (BatchEnv / DummyBatchEnv).

Source code in rlib/training/validation.py
def __init__(self, envs: BatchEnv | DummyBatchEnv) -> None:
    self.envs = envs
    self.num_envs = len(envs)

Validator

Bases: Protocol

Strategy for running validation episodes and returning a mean score.

run

run(get_action: ActionFn, num_episodes: int, max_steps: int, render: bool = False) -> float

Run num_episodes of validation; return the mean total reward.

Source code in rlib/training/validation.py
def run(
    self,
    get_action: ActionFn,
    num_episodes: int,
    max_steps: int,
    render: bool = False,
) -> float:
    """Run ``num_episodes`` of validation; return the mean total reward."""
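Because Validator is a Protocol, a strategy only needs a run method with a matching signature and does not have to inherit from anything. A minimal sketch under stated assumptions: ValidatorSketch restates the protocol locally, ActionFn is assumed to be a callable taking a state, and ConstantRewardValidator is a hypothetical strategy that fakes episodes instead of stepping a real env.

```python
from typing import Any, Callable, Protocol

ActionFn = Callable[[Any], Any]  # assumed shape of the alias


class ValidatorSketch(Protocol):
    """Local restatement of the protocol, for illustration only."""

    def run(
        self,
        get_action: ActionFn,
        num_episodes: int,
        max_steps: int,
        render: bool = False,
    ) -> float: ...


class ConstantRewardValidator:
    """Hypothetical strategy: every step of every episode pays a fixed reward."""

    def __init__(self, reward_per_step: float) -> None:
        self.reward_per_step = reward_per_step

    def run(self, get_action, num_episodes, max_steps, render=False) -> float:
        totals = []
        for _ in range(num_episodes):
            state, total = [0.0, 0.0], 0.0
            for _ in range(max_steps):
                get_action(state)  # the policy is queried once per step
                total += self.reward_per_step
            totals.append(total)
        return sum(totals) / len(totals)


def validate(validator: ValidatorSketch, get_action: ActionFn) -> float:
    # Accepts any object whose run() matches the protocol structurally.
    return validator.run(get_action, num_episodes=3, max_steps=10)


score = validate(ConstantRewardValidator(0.5), get_action=lambda state: 0)
```

A static type checker accepts ConstantRewardValidator wherever ValidatorSketch is expected, which is what lets a factory hand back different strategies behind one return type.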

GAE

GAE(rewards: ndarray, values: ndarray, last_values: ndarray, dones: ndarray, gamma: float = 0.99, lambda_: float = 0.95, clip: bool = False) -> np.ndarray

Generalised Advantage Estimation (Schulman et al. 2015).

Returns the advantage sequence (A_t); the value targets are recovered as A_t + V_t.

Source code in rlib/training/returns.py
def GAE(
    rewards: np.ndarray,
    values: np.ndarray,
    last_values: np.ndarray,
    dones: np.ndarray,
    gamma: float = 0.99,
    lambda_: float = 0.95,
    clip: bool = False,
) -> np.ndarray:
    """Generalised Advantage Estimation (Schulman et al. 2015).

    Returns the *advantage* sequence (``A_t``); the value targets are
    recovered as ``A_t + V_t``.
    """
    if clip:
        rewards = np.clip(rewards, -1, 1)
    Adv = np.zeros_like(rewards)
    Adv[-1] = rewards[-1] + gamma * last_values * (1 - dones[-1]) - values[-1]
    T = len(rewards)
    for t in reversed(range(T - 1)):
        delta = rewards[t] + gamma * values[t + 1] * (1 - dones[t]) - values[t]
        Adv[t] = delta + gamma * lambda_ * Adv[t + 1] * (1 - dones[t])
    return Adv
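Two properties of the recursion are easy to check numerically. The sketch below restates the listing as a local gae function (reward clipping omitted) and runs it on made-up arrays of shape (T, num_envs); the toy values are arbitrary and no episode ends in the rollout.

```python
import numpy as np


def gae(rewards, values, last_values, dones, gamma=0.99, lambda_=0.95):
    """Same recursion as the GAE listing above, without reward clipping."""
    adv = np.zeros_like(rewards)
    adv[-1] = rewards[-1] + gamma * last_values * (1 - dones[-1]) - values[-1]
    for t in reversed(range(len(rewards) - 1)):
        delta = rewards[t] + gamma * values[t + 1] * (1 - dones[t]) - values[t]
        adv[t] = delta + gamma * lambda_ * adv[t + 1] * (1 - dones[t])
    return adv


# Toy rollout: T = 3 steps, 2 parallel envs, no episode ends.
rewards = np.array([[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]])
values = np.array([[0.5, 0.2], [0.4, 0.1], [0.3, 0.6]])
last_values = np.array([0.7, 0.9])
dones = np.zeros_like(rewards)
gamma = 0.9

# lambda_ = 0: each advantage is the one-step TD error.
next_values = np.vstack([values[1:], last_values])
td_errors = rewards + gamma * next_values - values
adv_td = gae(rewards, values, last_values, dones, gamma, lambda_=0.0)

# lambda_ = 1: A_t + V_t equals the discounted reward sum plus the bootstrap.
adv_mc = gae(rewards, values, last_values, dones, gamma, lambda_=1.0)
value_targets = adv_mc + values
discounted = np.zeros_like(rewards)
running = last_values
for t in reversed(range(len(rewards))):
    running = rewards[t] + gamma * running
    discounted[t] = running
```

Intermediate values of lambda_ interpolate between these two extremes, trading the bias of the TD error against the variance of the full discounted sum.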

lambda_return

lambda_return(rewards: ndarray, values: ndarray, last_values: ndarray, dones: ndarray, gamma: float = 0.99, lambda_: float = 0.8, clip: bool = False) -> np.ndarray

λ-return :math:R^\lambda_t = r_t + \gamma((1-\lambda) V_{t+1} + \lambda R^\lambda_{t+1}).

With lambda_ == 1.0 this collapses to the n-step return; with lambda_ == 0.0 it collapses to the one-step TD target.

Source code in rlib/training/returns.py
def lambda_return(
    rewards: np.ndarray,
    values: np.ndarray,
    last_values: np.ndarray,
    dones: np.ndarray,
    gamma: float = 0.99,
    lambda_: float = 0.8,
    clip: bool = False,
) -> np.ndarray:
    r"""λ-return :math:`R^\lambda_t = r_t + \gamma((1-\lambda) V_{t+1} + \lambda R^\lambda_{t+1})`.

    With ``lambda_ == 1.0`` this collapses to the n-step return; with
    ``lambda_ == 0.0`` it collapses to the one-step TD target.
    """
    if clip:
        rewards = np.clip(rewards, -1, 1)
    T = len(rewards)
    R = np.zeros_like(rewards)
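    # Seed the last slot with the bootstrap value only (zero if the final
    # step ended an episode); rewards[-1] is not added to it.
    R[-1] = last_values * (1 - dones[-1])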
    for t in reversed(range(T - 1)):
        R[t] = rewards[t] + gamma * (lambda_ * R[t + 1] + (1.0 - lambda_) * values[t + 1]) * (
            1 - dones[t]
        )
    return R
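The two limiting cases named in the docstring follow by substituting into the recursion. The block below is a short worked derivation in the docstring's own symbols; it leaves out the (1 - dones[t]) factor that the listing applies to the bracketed term.

```latex
\begin{aligned}
R^\lambda_t &= r_t + \gamma\bigl((1-\lambda)\,V_{t+1} + \lambda\,R^\lambda_{t+1}\bigr) \\
\lambda = 1:\quad R^\lambda_t &= r_t + \gamma\,R^\lambda_{t+1} && \text{(n-step recursion)} \\
\lambda = 0:\quad R^\lambda_t &= r_t + \gamma\,V_{t+1} && \text{(one-step TD target)}
\end{aligned}
```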

nstep_return

nstep_return(rewards: ndarray, last_values: ndarray, dones: ndarray, gamma: float = 0.99, clip: bool = False) -> np.ndarray

N-step bootstrapped return :math:R_t = r_t + \gamma R_{t+1}.

The recursion is reset to zero whenever dones[t] is set so the next episode's rewards don't bleed into the current one.

Source code in rlib/training/returns.py
def nstep_return(
    rewards: np.ndarray,
    last_values: np.ndarray,
    dones: np.ndarray,
    gamma: float = 0.99,
    clip: bool = False,
) -> np.ndarray:
    r"""N-step bootstrapped return :math:`R_t = r_t + \gamma R_{t+1}`.

    The recursion is reset to zero whenever ``dones[t]`` is set so the
    next episode's rewards don't bleed into the current one.
    """
    if clip:
        rewards = np.clip(rewards, -1, 1)

    T = len(rewards)
    R = np.zeros_like(rewards)
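    # Seed the last slot with the bootstrap value only (zero if the final
    # step ended an episode); rewards[-1] is not added to it.
    R[-1] = last_values * (1 - dones[-1])

    for i in reversed(range(T - 1)):
        # Cut the recursion when dones[i] is set: BatchEnv resets automatically
        # at episode end, so R[i + 1] already belongs to the next episode.
        R[i] = rewards[i] + gamma * R[i + 1] * (1 - dones[i])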

    return R
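The effect of the dones reset is easiest to see on a tiny single-env rollout. This sketch restates the listing as a local function (reward clipping omitted) and uses made-up numbers, with gamma = 0.5 chosen so the arithmetic is easy to follow.

```python
import numpy as np


def nstep_return(rewards, last_values, dones, gamma=0.99):
    """Same recursion as the nstep_return listing above, without reward clipping."""
    R = np.zeros_like(rewards)
    R[-1] = last_values * (1 - dones[-1])
    for i in reversed(range(len(rewards) - 1)):
        R[i] = rewards[i] + gamma * R[i + 1] * (1 - dones[i])
    return R


rewards = np.array([1.0, 1.0, 1.0, 1.0])
dones = np.array([0.0, 1.0, 0.0, 0.0])  # the episode ends at step 1
returns = nstep_return(rewards, last_values=10.0, dones=dones, gamma=0.5)
# returns is [1.5, 1.0, 6.0, 10.0]
```

returns[1] is just its own reward because the done flag cuts off everything after it, and returns[0] discounts only that value. As listed, the final slot returns[3] holds only the bootstrap value.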

make_validator

make_validator(val_envs: list | BatchEnv | DummyBatchEnv) -> Validator

Pick the appropriate :class:Validator for the given env(s).

Source code in rlib/training/validation.py
def make_validator(val_envs: list | BatchEnv | DummyBatchEnv) -> Validator:
    """Pick the appropriate :class:`Validator` for the given env(s)."""
    if isinstance(val_envs, list):
        return AsyncValidator(val_envs)
    return SyncValidator(val_envs)