training ¶
Training infrastructure: synchronous multi-env trainer, configs, return estimators, and validation strategies.
The trainer and base config that every concrete agent extends:
- :class:`SyncMultiEnvTrainer`: the main training loop.
- :class:`TrainerConfig`: base hyperparameter dataclass.
Per-trainer configs (PPOTrainerConfig, RNDTrainerConfig, ...)
live next to their trainer class in rlib/<Agent>/trainer.py.
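A per-agent config can be sketched as a dataclass subclass that adds its own fields on top of the shared ones. The `TrainerConfig` below is only a stand-in that mirrors a few of the documented defaults, and `MyAgentTrainerConfig` and `entropy_coeff` are hypothetical names used for illustration; neither is taken from rlib.

```python
from dataclasses import dataclass, replace


@dataclass
class TrainerConfig:
    # Stand-in mirroring a subset of the documented defaults; not the real class.
    total_steps: int = 50_000_000
    nsteps: int = 5
    gamma: float = 0.99
    lambda_: float = 0.95


@dataclass
class MyAgentTrainerConfig(TrainerConfig):
    # Hypothetical agent-specific hyperparameter added by the subclass.
    entropy_coeff: float = 0.01


# Override one shared field; everything else keeps its default.
cfg = MyAgentTrainerConfig(nsteps=20)

# dataclasses.replace copies the config with selected fields changed.
short_run = replace(cfg, total_steps=1_000_000)
```

Because the subclass is still a `TrainerConfig`, code that only reads the shared fields can accept either type.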
Internals exposed for advanced use / testing:
- :class:`Returns`: enum of return / advantage estimators.
- :func:`GAE`, :func:`lambda_return`, :func:`nstep_return`: the underlying free functions.
- :class:`Validator`, :class:`AsyncValidator`, :class:`SyncValidator`, :func:`make_validator`: validation strategy implementations.
TrainerConfig dataclass ¶
TrainerConfig(train_mode: TrainMode = TrainMode.NSTEP, returns: Returns = Returns.NSTEP, total_steps: int = 50000000, nsteps: int = 5, gamma: float = 0.99, lambda_: float = 0.95, validate_freq: int = 1000000, num_val_episodes: int = 50, max_val_steps: int = 10000, log_dir: str = 'logs/', model_dir: str = 'models/', save_freq: int = 0, log_scalars: bool = True, update_target_freq: int = 0, render_freq: int = 0)
All hyperparameters for :class:`rlib.training.SyncMultiEnvTrainer`.
Attributes:
| Name | Type | Description |
|---|---|---|
| `train_mode` | `TrainMode` | Whether to dispatch to the multi-step (`_train_nstep`) or one-step (`_train_onestep`) training loop. |
| `returns` | `Returns` | Return / advantage estimator (a `Returns` member). |
| `total_steps` | `int` | Total environment steps across all parallel envs. |
| `nsteps` | `int` | Length of each n-step rollout. |
| `gamma` | `float` | Discount factor. |
| `lambda_` | `float` | GAE / λ-return weighting (ignored by the n-step estimator). |
| `validate_freq` | `int` | Env steps between validation passes. |
| `num_val_episodes` | `int` | Episodes averaged per validation pass. |
| `max_val_steps` | `int` | Per-episode step cap during validation. |
| `log_dir` | `str` | Directory for tensorboard scalars. |
| `model_dir` | `str` | Directory for checkpoints. |
| `save_freq` | `int` | Env steps between checkpoints. |
| `log_scalars` | `bool` | Whether to write tensorboard scalars at all. |
| `update_target_freq` | `int` | Env steps between target-net syncs (off-policy agents only); on-policy agents leave it at 0. |
| `render_freq` | `int` | Multiple of … |
TrainMode ¶
Bases: StrEnum
Whether the trainer dispatches to _train_nstep or _train_onestep.
Returns ¶
Bases: Enum
Return / advantage estimators.
The enum name is the canonical CLI / log string ("NSTEP",
"GAE", "LAMBDA"). The value is the wrapped callable, so
members are directly callable::

    targets = Returns.GAE(rewards, values, last_values, dones, 0.99, 0.95)

The :func:`enum.member` wrapper around each value is required so
Python's enum metaclass treats the function as an enum member rather
than as an instance method bound on the class.
SyncMultiEnvTrainer ¶
SyncMultiEnvTrainer(envs: BatchEnv | DummyBatchEnv, agent: Agent, val_envs: list | BatchEnv | DummyBatchEnv, config: TrainerConfig)
Synchronous multi-env training framework for any :class:`rlib.networks.Model`.
Build a synchronous multi-env training loop.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `envs` | `BatchEnv \| DummyBatchEnv` | Training environments. | required |
| `agent` | `Agent` | An `Agent` instance. | required |
| `val_envs` | `list \| BatchEnv \| DummyBatchEnv` | Validation envs: a list of envs or a single batched env. | required |
| `config` | `TrainerConfig` | A `TrainerConfig` instance. | required |
Source code in rlib/training/trainer.py
rollout ¶
rollout() -> Any
Collect self.nsteps steps of experience and return whatever the agent's training loop expects.
Source code in rlib/training/trainer.py
update_target ¶
update_target() -> None
Hook called every update_target_freq steps. No-op by default.
Off-policy agents (e.g. DDQN) override this to copy weights from
self.agent to a target network. Pure on-policy agents leave
update_target_freq=0 and never call it.
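The hook pattern described above can be sketched as follows. The classes are stand-ins, not the rlib trainer: `SketchTrainer`, `SketchDDQNTrainer`, `advance`, and the dict-of-lists "weights" are hypothetical, and the exact point at which the real loop checks `update_target_freq` is an assumption.

```python
import copy


class SketchTrainer:
    """Stand-in trainer showing only the target-update hook."""

    def __init__(self, weights, update_target_freq=0):
        self.weights = weights
        self.update_target_freq = update_target_freq
        self.env_steps = 0

    def update_target(self) -> None:
        """No-op by default; off-policy subclasses override it."""

    def advance(self, steps: int) -> None:
        # Assumed scheduling: fire the hook every update_target_freq steps,
        # and never when the frequency is 0.
        for _ in range(steps):
            self.env_steps += 1
            freq = self.update_target_freq
            if freq > 0 and self.env_steps % freq == 0:
                self.update_target()


class SketchDDQNTrainer(SketchTrainer):
    def __init__(self, weights, update_target_freq):
        super().__init__(weights, update_target_freq)
        self.target_weights = copy.deepcopy(weights)
        self.sync_count = 0

    def update_target(self) -> None:
        # Copy the online weights into the target network.
        self.target_weights = copy.deepcopy(self.weights)
        self.sync_count += 1


trainer = SketchDDQNTrainer({"w": [0.0]}, update_target_freq=4)
trainer.weights["w"][0] = 1.0            # pretend a gradient step happened
trainer.advance(3)
stale = trainer.target_weights["w"][0]   # target not yet synced
trainer.advance(1)
fresh = trainer.target_weights["w"][0]   # synced on the 4th step
```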
Source code in rlib/training/trainer.py
get_action ¶
get_action(state: ndarray) -> Any
Hook used by the validator to pick an action during evaluation.
Concrete trainers must override this if validation is enabled
(it's called by every :class:`~rlib.training.validation.Validator`).
Source code in rlib/training/trainer.py
AsyncValidator ¶
AsyncValidator(envs: list)
Validate against a list of envs using one daemon thread per env.
Each thread runs its share of episodes and pushes the per-episode
total reward into a shared list (guarded by self._lock). After
every thread joins, the mean across all collected scores is
returned.
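A minimal sketch of the thread-per-env strategy described above, using only the standard library. `CountdownEnv` is a hypothetical toy environment with a simplified `reset` / `step` interface, `run_threaded_validation` is not the rlib implementation, and the even split of episodes across envs is an assumption.

```python
import threading
from statistics import mean


class CountdownEnv:
    """Toy env: reward 1.0 per step, episode ends after `length` steps."""

    def __init__(self, length):
        self.length = length
        self.t = 0

    def reset(self):
        self.t = 0
        return self.t

    def step(self, action):
        self.t += 1
        return self.t, 1.0, self.t >= self.length


def run_threaded_validation(envs, get_action, num_episodes, max_steps):
    scores = []
    lock = threading.Lock()

    def worker(env, episodes):
        for _ in range(episodes):
            state, total = env.reset(), 0.0
            for _ in range(max_steps):
                state, reward, done = env.step(get_action(state))
                total += reward
                if done:
                    break
            with lock:  # guard the shared score list
                scores.append(total)

    # Split episodes as evenly as possible across the envs (assumption).
    base, extra = divmod(num_episodes, len(envs))
    threads = []
    for i, env in enumerate(envs):
        share = base + (1 if i < extra else 0)
        thread = threading.Thread(target=worker, args=(env, share), daemon=True)
        thread.start()
        threads.append(thread)
    for thread in threads:
        thread.join()
    return mean(scores)


score = run_threaded_validation(
    [CountdownEnv(3), CountdownEnv(5)], lambda state: 0,
    num_episodes=4, max_steps=100,
)
```

With two episodes per env, the collected totals are 3.0, 3.0, 5.0, and 5.0, so `score` is 4.0.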
Source code in rlib/training/validation.py
SyncValidator ¶
SyncValidator(envs: BatchEnv | DummyBatchEnv)
Validate against a single batched env (BatchEnv / DummyBatchEnv).
Source code in rlib/training/validation.py
Validator ¶
Bases: Protocol
Strategy for running validation episodes and returning a mean score.
run ¶
run(get_action: ActionFn, num_episodes: int, max_steps: int, render: bool = False) -> float
Run num_episodes of validation; return the mean total reward.
Source code in rlib/training/validation.py
GAE ¶
GAE(rewards: ndarray, values: ndarray, last_values: ndarray, dones: ndarray, gamma: float = 0.99, lambda_: float = 0.95, clip: bool = False) -> np.ndarray
Generalised Advantage Estimation (Schulman et al. 2015).
Returns the advantage sequence (A_t); the value targets are
recovered as A_t + V_t.
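For reference, the standard GAE recursion can be sketched in NumPy as below. This is not the rlib source: it assumes time-major arrays of shape `(T, num_envs)`, `last_values` of shape `(num_envs,)`, and that `dones[t]` marks the transition at step `t` as terminal. The documented `clip` argument is left out because its behavior is not described here.

```python
import numpy as np


def gae_sketch(rewards, values, last_values, dones, gamma=0.99, lambda_=0.95):
    rewards = np.asarray(rewards, dtype=np.float64)
    values = np.asarray(values, dtype=np.float64)
    not_done = 1.0 - np.asarray(dones, dtype=np.float64)

    advantages = np.zeros_like(rewards)
    next_value = np.asarray(last_values, dtype=np.float64)
    running = np.zeros_like(next_value)

    # Walk backwards: A_t = delta_t + gamma * lambda * A_{t+1},
    # with delta_t = r_t + gamma * V_{t+1} - V_t, masked at episode ends.
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * next_value * not_done[t] - values[t]
        running = delta + gamma * lambda_ * not_done[t] * running
        advantages[t] = running
        next_value = values[t]
    return advantages


rewards = np.ones((3, 1))
values = np.full((3, 1), 0.5)
last_values = np.array([0.5])
dones = np.zeros((3, 1))

adv = gae_sketch(rewards, values, last_values, dones, gamma=1.0, lambda_=1.0)
value_targets = adv + values  # A_t + V_t, as described above
```

With `gamma=1.0` and `lambda_=1.0` the advantages in this toy rollout are 3, 2, and 1, and the value targets are 3.5, 2.5, and 1.5.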
Source code in rlib/training/returns.py
lambda_return ¶
lambda_return(rewards: ndarray, values: ndarray, last_values: ndarray, dones: ndarray, gamma: float = 0.99, lambda_: float = 0.8, clip: bool = False) -> np.ndarray
λ-return :math:`R^\lambda_t = r_t + \gamma((1-\lambda) V_{t+1} + \lambda R^\lambda_{t+1})`.
With lambda_ == 1.0 this collapses to the n-step return; with
lambda_ == 0.0 it collapses to one-step TD.
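The recursion and its two limiting cases can be checked with a short NumPy sketch. It is not the rlib source: it assumes time-major arrays of shape `(T, num_envs)`, bootstraps both the return and the value from `last_values` after the final step, masks the bootstrap wherever `dones[t]` is set, and omits the documented `clip` argument.

```python
import numpy as np


def lambda_return_sketch(rewards, values, last_values, dones,
                         gamma=0.99, lambda_=0.8):
    rewards = np.asarray(rewards, dtype=np.float64)
    values = np.asarray(values, dtype=np.float64)
    not_done = 1.0 - np.asarray(dones, dtype=np.float64)

    returns = np.zeros_like(rewards)
    next_value = np.asarray(last_values, dtype=np.float64)
    next_return = next_value.copy()

    for t in reversed(range(len(rewards))):
        # Blend the bootstrapped value with the recursive return.
        blended = (1.0 - lambda_) * next_value + lambda_ * next_return
        returns[t] = rewards[t] + gamma * not_done[t] * blended
        next_return = returns[t]
        next_value = values[t]
    return returns


rewards = np.ones((3, 1))
values = np.full((3, 1), 0.5)
last_values = np.array([0.5])
dones = np.zeros((3, 1))

# lambda_ = 1.0: pure recursion R_t = r_t + gamma * R_{t+1} (n-step return).
nstep_like = lambda_return_sketch(rewards, values, last_values, dones, 0.9, 1.0)
# lambda_ = 0.0: one-step TD target r_t + gamma * V_{t+1}.
td_like = lambda_return_sketch(rewards, values, last_values, dones, 0.9, 0.0)
```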
Source code in rlib/training/returns.py
nstep_return ¶
nstep_return(rewards: ndarray, last_values: ndarray, dones: ndarray, gamma: float = 0.99, clip: bool = False) -> np.ndarray
N-step bootstrapped return :math:`R_t = r_t + \gamma R_{t+1}`.
The recursion is reset to zero whenever dones[t] is set so the
next episode's rewards don't bleed into the current one.
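A small NumPy sketch of this recursion, including the reset at episode boundaries. It is not the rlib source: it assumes time-major arrays of shape `(T, num_envs)`, seeds the recursion with `last_values`, and omits the documented `clip` argument.

```python
import numpy as np


def nstep_return_sketch(rewards, last_values, dones, gamma=0.99):
    rewards = np.asarray(rewards, dtype=np.float64)
    not_done = 1.0 - np.asarray(dones, dtype=np.float64)

    returns = np.zeros_like(rewards)
    # Bootstrap from the value estimate of the state after the last step.
    running = np.asarray(last_values, dtype=np.float64)

    for t in reversed(range(len(rewards))):
        # A done flag zeroes the carried-over return, so rewards from the
        # next episode do not leak into the current one.
        running = rewards[t] + gamma * not_done[t] * running
        returns[t] = running
    return returns


rewards = np.ones((3, 1))
last_values = np.array([10.0])
dones = np.array([[0.0], [1.0], [0.0]])  # the episode ends at t = 1

returns = nstep_return_sketch(rewards, last_values, dones, gamma=0.5)
```

In this toy rollout the returns are 1.5, 1.0, and 6.0: the step after the boundary bootstraps from `last_values`, while the terminal step keeps only its own reward.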
Source code in rlib/training/returns.py
make_validator ¶
make_validator(val_envs: list | BatchEnv | DummyBatchEnv) -> Validator
Pick the appropriate :class:`Validator` for the given env(s).
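Given the constructor signatures documented on this page (`AsyncValidator(envs: list)` and `SyncValidator(envs: BatchEnv | DummyBatchEnv)`), the selection plausibly depends on whether a list was passed. The sketch below shows that dispatch with stand-in classes; the `isinstance` check is an assumption, and all names ending in `Sketch` are hypothetical.

```python
class AsyncValidatorSketch:
    """Stand-in for a validator that takes a list of envs."""

    def __init__(self, envs: list):
        self.envs = envs


class SyncValidatorSketch:
    """Stand-in for a validator that takes one batched env."""

    def __init__(self, envs):
        self.envs = envs


class DummyBatchEnvSketch:
    """Placeholder for a batched environment object."""


def make_validator_sketch(val_envs):
    # Assumed rule: a plain list means one env per worker thread,
    # anything else is treated as a single batched env.
    if isinstance(val_envs, list):
        return AsyncValidatorSketch(val_envs)
    return SyncValidatorSketch(val_envs)


from_list = make_validator_sketch([object(), object()])
from_batch = make_validator_sketch(DummyBatchEnvSketch())
```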
Source code in rlib/training/validation.py