Agent base
agent
Common abstract base class for trainable agents in rlib.
Almost every agent in `rlib` reimplements the same boilerplate:
- store the same set of LR / grad-clip hyperparameters,
- build a polynomial-decay LR scheduler on top of an optimiser, and
- run an identical
`loss.backward → clip_grad_norm_ → optimiser.step → zero_grad → scheduler.step` sequence at the end of every `backprop` call.
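The shared end-of-`backprop` sequence in the last bullet can be sketched as a free function. This is a minimal sketch, not `Agent._train_step` itself (its body is not shown on this page): the name `train_step`, its parameters, and the assumption that `grad_clip=None` skips clipping are hypothetical. Returning the loss as a NumPy array follows the `np.ndarray` return type documented for `backprop`.

```python
from __future__ import annotations

import numpy as np
import torch
from torch import nn


def train_step(
    model: nn.Module,
    optimiser: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LambdaLR,
    loss: torch.Tensor,
    grad_clip: float | None,
) -> np.ndarray:
    """Hypothetical stand-in for the shared end-of-backprop sequence."""
    loss.backward()                      # 1. compute gradients
    if grad_clip is not None:            # assumption: None means "no clipping"
        nn.utils.clip_grad_norm_(model.parameters(), grad_clip)  # 2. clip
    optimiser.step()                     # 3. apply the update
    optimiser.zero_grad()                # 4. clear gradients for the next call
    scheduler.step()                     # 5. advance the LR schedule by one step
    return loss.detach().cpu().numpy()   # numpy loss scalar out


# Usage on a toy regression problem.
torch.manual_seed(0)
model = nn.Linear(4, 1)
optimiser = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimiser, lambda step: 1.0)  # constant LR
loss = model(torch.randn(8, 4)).pow(2).mean()
out = train_step(model, optimiser, scheduler, loss, grad_clip=0.5)
```

Calling `scheduler.step()` after `optimiser.step()` is the order PyTorch expects; reversing them makes PyTorch skip the first value of the schedule and emit a warning.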
`Agent` factors that out into a single abstract base so agent
implementations can focus on what is actually agent-specific (their
forward pass, `evaluate` signature, loss formulation, and `backprop`
signature).
The base class deliberately does not know about any specific RL
algorithm. Algorithm-specific loss functions live on per-algorithm
subclasses (e.g. `rlib.A2C.A2CModel`) which their concrete
variants (feed-forward, recurrent, ...) can inherit and reuse.
Subclasses are expected to:
- Call `super().__init__(config=...)` with a `ModelConfig` (or subclass).
- Build their network heads (policy/value/Q/etc.) attached to `self.device`.
- Optionally call `Agent._build_optimiser` once they have all their parameters in place. Composite agents that delegate optimisation to a child can simply skip this step.
- Implement `forward` (inherited from `torch.nn.Module`) and the abstract `backprop` and `evaluate` methods.
- End each `backprop` method with `return self._train_step(loss)`.
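The subclass contract above can be illustrated end to end. Because `Agent`'s own source is not reproduced on this page, the sketch defines a reduced, hypothetical stand-in base (`MiniAgent`, configured by a hypothetical `MiniConfig` that carries only three of `ModelConfig`'s fields, defaults to CPU, and has no LR scheduler) plus a toy `ValueAgent`. The numbered comments map to the five bullets. It is a minimal sketch under those assumptions, not rlib code.

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any

import numpy as np
import torch
from torch import nn


@dataclass
class MiniConfig:
    # Hypothetical stand-in for a subset of ModelConfig's fields.
    lr: float = 1e-3
    grad_clip: float | None = 0.5
    device: str = "cpu"


class MiniAgent(nn.Module, ABC):
    """Hypothetical stand-in for rlib's Agent, reduced to the contract above."""

    def __init__(self, config: MiniConfig):
        super().__init__()
        self.config = config
        self.device = torch.device(config.device)

    def _build_optimiser(self) -> None:
        # Must run after all parameters exist, or they would not be optimised.
        self.optimiser = torch.optim.Adam(self.parameters(), lr=self.config.lr)

    def _train_step(self, loss: torch.Tensor) -> np.ndarray:
        loss.backward()
        if self.config.grad_clip is not None:
            nn.utils.clip_grad_norm_(self.parameters(), self.config.grad_clip)
        self.optimiser.step()
        self.optimiser.zero_grad()
        return loss.detach().cpu().numpy()

    @abstractmethod
    def evaluate(self, *args: Any, **kwargs: Any) -> Any: ...

    @abstractmethod
    def backprop(self, *args: Any, **kwargs: Any) -> np.ndarray: ...


class ValueAgent(MiniAgent):
    """Toy concrete agent with a single linear state-value head."""

    def __init__(self, config: MiniConfig, obs_dim: int):
        super().__init__(config=config)                          # 1. pass the config up
        self.value_head = nn.Linear(obs_dim, 1).to(self.device)  # 2. heads on self.device
        self._build_optimiser()                                  # 3. optional, params now exist

    def forward(self, obs: torch.Tensor) -> torch.Tensor:        # 4. forward ...
        return self.value_head(obs).squeeze(-1)

    def evaluate(self, obs: np.ndarray) -> np.ndarray:           # 4. ... evaluate ...
        with torch.no_grad():
            v = self(torch.as_tensor(obs, dtype=torch.float32, device=self.device))
        return v.cpu().numpy()

    def backprop(self, obs: np.ndarray, returns: np.ndarray) -> np.ndarray:  # 4. ... backprop
        obs_t = torch.as_tensor(obs, dtype=torch.float32, device=self.device)
        R = torch.as_tensor(returns, dtype=torch.float32, device=self.device)
        loss = 0.5 * (R - self(obs_t)).pow(2).mean()
        return self._train_step(loss)                            # 5. end with _train_step


agent = ValueAgent(MiniConfig(), obs_dim=4)
values = agent.evaluate(np.zeros((3, 4), dtype=np.float32))
loss = agent.backprop(np.ones((3, 4), dtype=np.float32), np.ones(3, dtype=np.float32))
```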
Per-agent `ModelConfig` subclasses (`A2CConfig`, `PPOConfig`, ...)
live next to their respective agent class in `rlib/<Agent>/model.py`.
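Since `ModelConfig` is a dataclass, a per-agent config can presumably extend it by ordinary dataclass inheritance. The sketch below redeclares `ModelConfig` from its documented signature so it runs on its own; `ExampleA2CConfig` and its `value_coef` / `entropy_coef` fields are hypothetical and are not the real `A2CConfig`.

```python
from __future__ import annotations

from dataclasses import dataclass, fields


@dataclass
class ModelConfig:
    # Mirrors the documented ModelConfig signature and defaults.
    lr: float = 0.001
    lr_final: float = 0.0
    decay_steps: int = 600_000
    grad_clip: float | None = 0.5
    device: str = "cuda"


@dataclass
class ExampleA2CConfig(ModelConfig):
    # Hypothetical algorithm-specific fields. They need defaults because
    # every inherited field already has one.
    value_coef: float = 0.5
    entropy_coef: float = 0.01


cfg = ExampleA2CConfig(lr=7e-4, entropy_coef=0.02)
names = [f.name for f in fields(cfg)]  # inherited fields come first
```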
ModelConfig (dataclass)
ModelConfig(lr: float = 0.001, lr_final: float = 0.0, decay_steps: int = 600000, grad_clip: float | None = 0.5, device: str = 'cuda')
Shared hyperparameters for every `Agent`.
Attributes:
| Name | Type | Description |
|---|---|---|
| `lr` | `float` | Initial learning rate. |
| `lr_final` | `float` | Final learning rate the polynomial scheduler decays to. |
| `decay_steps` | `int` | Number of optimiser steps over which the LR decays. |
| `grad_clip` | `float \| None` | Maximum gradient norm. |
| `device` | `str` | Torch device string (default `'cuda'`). |
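The page does not state the polynomial's exponent or what happens after `decay_steps`. A minimal sketch, assuming the common form `(lr - lr_final) * (1 - t / decay_steps) ** power + lr_final` with `power=1.0` (linear decay) and the LR held at `lr_final` once `decay_steps` is reached; `polynomial_lr` and `power` are hypothetical names.

```python
def polynomial_lr(
    step: int,
    lr: float = 0.001,
    lr_final: float = 0.0,
    decay_steps: int = 600_000,
    power: float = 1.0,
) -> float:
    """Hypothetical polynomial decay from lr to lr_final over decay_steps."""
    # Clamp progress to [0, 1] so the LR stays at lr_final after decay_steps.
    progress = min(step, decay_steps) / decay_steps
    return (lr - lr_final) * (1.0 - progress) ** power + lr_final


# Sample the default schedule at a few optimiser steps.
schedule = [polynomial_lr(s) for s in (0, 300_000, 600_000, 900_000)]
```

To drive a PyTorch optimiser with such a schedule, one option is `torch.optim.lr_scheduler.LambdaLR(optimiser, lambda t: polynomial_lr(t) / 0.001)`, since `LambdaLR` multiplies the optimiser's initial LR by the returned factor.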
Agent
Agent(config: ModelConfig)
Bases: `Module`, `ABC`
Abstract base class for trainable rlib agent models.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `ModelConfig` | A `ModelConfig` (or subclass) holding the shared hyperparameters. | required |
Attributes:
| Name | Type | Description |
|---|---|---|
| `config` | `ModelConfig` | The original config object (immutable). |
| `lr`, `lr_final`, `decay_steps`, `grad_clip`, `device` | as in `ModelConfig` | Convenience mirrors of the corresponding config fields. |
Source code in rlib/agent.py
evaluate (abstractmethod)
evaluate(*args: Any, **kwargs: Any) -> Any
Numpy-in / numpy-out inference call used by agent rollouts.
The exact signature is agent-specific (feed-forward agents take
a single observation array, recurrent agents also take a
hidden-state tuple, etc.), so each concrete subclass declares
its own. Implementations should run `forward` under
`torch.no_grad()` and convert torch tensors back to numpy.
Source code in rlib/agent.py
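The numpy-in / numpy-out pattern for a feed-forward agent can be sketched with a hypothetical helper, `evaluate_numpy`, applied to a toy policy network. Running under `torch.no_grad()` is what makes the final `.numpy()` call safe: calling `.numpy()` on a tensor that requires grad raises a `RuntimeError`. A recurrent agent would additionally thread its hidden-state tuple through the call, which this sketch omits.

```python
from __future__ import annotations

import numpy as np
import torch
from torch import nn


def evaluate_numpy(model: nn.Module, obs: np.ndarray, device: str = "cpu") -> np.ndarray:
    """Hypothetical helper: numpy in, numpy out, no autograd graph built."""
    with torch.no_grad():
        x = torch.as_tensor(obs, dtype=torch.float32, device=device)
        out = model(x)
    return out.cpu().numpy()  # move to host memory before converting


policy = nn.Sequential(nn.Linear(4, 2), nn.Softmax(dim=-1))
probs = evaluate_numpy(policy, np.zeros((5, 4), dtype=np.float32))
```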
backprop (abstractmethod)
backprop(*args: Any, **kwargs: Any) -> np.ndarray
Run a full training step (numpy in / numpy loss-scalar out).
Implementations typically: convert numpy inputs to tensors,
call `forward`, compute the agent-specific loss, then
`return self._train_step(loss)`.
Source code in rlib/agent.py
value_loss (staticmethod)
value_loss(R: Tensor, V: Tensor) -> torch.Tensor
Half mean-squared error between targets `R` and values `V`.
Used by every value-based agent in this library (A2C critic,
PPO critic, DQN, Q-aux, ...). Algorithm-specific losses (A2C
actor-critic loss, PPO clipped objective, ...) belong on the
relevant per-algorithm `Model` subclass.
Source code in rlib/agent.py
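The docstring defines the loss only as half the mean squared error. A minimal sketch, assuming the mean is taken over all elements, i.e. `0.5 * mean((R - V) ** 2)`, written as a free function rather than the library's static method. It agrees with `0.5 * torch.nn.functional.mse_loss(V, R)`, whose default reduction is the mean.

```python
import torch
import torch.nn.functional as F


def value_loss(R: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    """Half mean-squared error between targets R and values V."""
    return 0.5 * torch.mean((R - V) ** 2)


R = torch.tensor([1.0, 2.0, 3.0])
V = torch.tensor([1.0, 0.0, 3.0])
loss = value_loss(R, V)                # 0.5 * mean([0, 4, 0]) = 2/3
reference = 0.5 * F.mse_loss(V, R)     # same value via the built-in loss
```

Because the real method is a `staticmethod`, it can be called on the class without an instance, e.g. `Agent.value_loss(R, V)`.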