PPO
Proximal Policy Optimisation.
PPO
PPO(model, input_shape, action_size, config: PPOConfig, *, value_coeff: float = 1.0, build_optimiser: bool = True, optim: type[Optimizer] = torch.optim.Adam, optim_args: dict | None = None, **model_args)
Bases: PPOModel
Single-critic PPO actor-critic.
Source code in rlib/PPO/model.py
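`PPO` is the single-critic member of the family, and its `value_coeff` argument (default `1.0`) weights the critic's loss against the shared policy loss. The sketch below assumes the common actor-critic combination `policy_loss + value_coeff * value_loss - entropy_coeff * entropy`, with the value loss taken as the mean squared error between predicted values and returns. `combined_ppo_loss` and `mean_squared_error` are hypothetical helpers, and this is not necessarily how rlib's `PPO` computes its loss internally.

```python
def mean_squared_error(predictions: list[float], targets: list[float]) -> float:
    """Mean squared error between predicted state values and return targets."""
    if len(predictions) != len(targets) or not predictions:
        raise ValueError("predictions and targets must be non-empty and equal length")
    return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(predictions)


def combined_ppo_loss(
    policy_loss: float,
    entropy: float,
    values: list[float],
    returns: list[float],
    value_coeff: float = 1.0,
    entropy_coeff: float = 0.01,
) -> float:
    """Hypothetical single-critic PPO objective, to be minimised.

    Assumed weighting: the policy loss, plus value_coeff times the critic's
    squared error, minus entropy_coeff times the entropy bonus (subtracting
    it rewards more exploratory policies).
    """
    value_loss = mean_squared_error(values, returns)
    return policy_loss + value_coeff * value_loss - entropy_coeff * entropy


# Toy numbers chosen so the arithmetic is easy to check by hand.
loss = combined_ppo_loss(policy_loss=0.2, entropy=1.5, values=[1.0, 2.0], returns=[1.5, 1.0])
print(round(loss, 6))  # 0.2 + 1.0 * 0.625 - 0.01 * 1.5 = 0.81
```

The defaults mirror `value_coeff=1.0` from the `PPO` signature and `entropy_coeff=0.01` from `PPOConfig`. Keeping the weighting outside the shared policy loss fits the base-class design, where each subclass decides how to combine the policy terms with its own value loss(es).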
PPOConfig (dataclass)
PPOConfig(lr: float = 0.001, lr_final: float = 0.0, decay_steps: int = 600000, grad_clip: float | None = 0.5, device: str = 'cuda', entropy_coeff: float = 0.01, policy_clip: float = 0.1)
Bases: ModelConfig
Hyperparameters for clipped-objective PPO-family agents (PPO / RND / DAAC policy).
Attributes:
| Name | Type | Description |
|---|---|---|
| entropy_coeff | float | Coefficient on the policy entropy bonus. |
| policy_clip | float | Clipping parameter for PPO's clipped objective. |
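To make `policy_clip` concrete: in the standard PPO clipped objective, the probability ratio r = π_new(a|s) / π_old(a|s) is clamped to [1 − ε, 1 + ε], where ε plays the role of `policy_clip`, and the per-sample objective is min(r·A, clip(r)·A) for advantage A. The sketch below evaluates that textbook formula for a single sample; it illustrates the objective itself and is not rlib's code. `clipped_surrogate` is a hypothetical helper name.

```python
def clip(x: float, low: float, high: float) -> float:
    """Clamp x into the closed interval [low, high]."""
    return max(low, min(x, high))


def clipped_surrogate(ratio: float, advantage: float, policy_clip: float = 0.1) -> float:
    """Per-sample PPO clipped objective, to be maximised.

    ratio is pi_new(a|s) / pi_old(a|s); policy_clip plays the role of epsilon.
    """
    unclipped = ratio * advantage
    clipped = clip(ratio, 1.0 - policy_clip, 1.0 + policy_clip) * advantage
    # Taking the minimum gives a pessimistic bound: moving the ratio outside
    # [1 - eps, 1 + eps] can never make the objective look better.
    return min(unclipped, clipped)


# Positive advantage: the gain stops growing once the ratio passes 1 + eps.
print(clipped_surrogate(1.05, advantage=2.0))  # inside the clip range
print(clipped_surrogate(1.50, advantage=2.0))  # capped at (1 + 0.1) * 2.0
# Negative advantage with a ratio above 1 + eps: the larger penalty is kept.
print(clipped_surrogate(1.50, advantage=-1.0))
```

Because only improvements are capped while penalties are not, a small `policy_clip` such as the default `0.1` keeps each update close to the policy that collected the data.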
PPOModel
PPOModel(action_size: int, config: PPOConfig)
Bases: Agent
PPO-family base class: defines the clipped-objective policy loss.
Concrete subclasses (single-critic PPO, twin-critic PPOIntrinsic, DAAC's policy head, ...) only need to implement `forward`, `evaluate` and `backprop`; they all share the same clipped policy loss and entropy bonus via `ppo_clipped_policy_loss`.
Source code in rlib/PPO/model.py
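The split described for `PPOModel` (the base class owns the shared policy loss, subclasses supply `forward`, `evaluate` and `backprop`) is a template-method design. The framework-free skeleton below is a hypothetical sketch of that structure only: `PPOModelSketch` and `ToyPPO` are invented names, the method signatures are assumptions, and the loss is reduced to one sample for brevity. It is not rlib's `PPOModel` or `Agent` code.

```python
from abc import ABC, abstractmethod


class PPOModelSketch(ABC):
    """Hypothetical base: owns the shared loss, defers the rest to subclasses."""

    def __init__(self, action_size: int, policy_clip: float = 0.1):
        self.action_size = action_size
        self.policy_clip = policy_clip

    @abstractmethod
    def forward(self, state): ...

    @abstractmethod
    def evaluate(self, state): ...

    @abstractmethod
    def backprop(self, *batch): ...

    def ppo_clipped_policy_loss(self, ratio: float, advantage: float) -> float:
        # Shared by every subclass; negated because optimisers minimise.
        low, high = 1.0 - self.policy_clip, 1.0 + self.policy_clip
        clipped = max(low, min(ratio, high))
        return -min(ratio * advantage, clipped * advantage)


class ToyPPO(PPOModelSketch):
    """Minimal concrete subclass: a fixed uniform policy and a zero critic."""

    def forward(self, state):
        probs = [1.0 / self.action_size] * self.action_size
        return probs, 0.0  # (policy, value)

    def evaluate(self, state):
        return self.forward(state)

    def backprop(self, ratio: float, advantage: float) -> float:
        # A real agent would add value and entropy terms and step an optimiser.
        return self.ppo_clipped_policy_loss(ratio, advantage)


agent = ToyPPO(action_size=4)
policy, value = agent.forward(state=None)
print(policy, value)
print(agent.backprop(ratio=1.5, advantage=2.0))
```

Python's `abc` machinery refuses to instantiate `PPOModelSketch` directly, so a subclass that forgets one of the three methods fails at construction time rather than mid-training.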
ppo_clipped_policy_loss
ppo_clipped_policy_loss(policy, old_policy, action_onehot, advantage)
PPO clipped-objective policy loss + entropy bonus.
Returns the (policy_loss, entropy) pair so subclasses can
combine them with whatever value-function loss(es) and
coefficients their agent uses.
Source code in rlib/PPO/model.py
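One plausible reading of the `ppo_clipped_policy_loss(policy, old_policy, action_onehot, advantage)` signature is that `policy` and `old_policy` hold per-action probability vectors, `action_onehot` selects the action actually taken, and `advantage` holds one estimate per sample. The pure-Python sketch below computes the standard clipped surrogate loss and the mean policy entropy under those assumptions. The batch layout, the mean reduction and the explicit `policy_clip` argument (which rlib presumably reads from `PPOConfig`) are all assumptions, not confirmed details of the library.

```python
import math


def ppo_clipped_policy_loss(policy, old_policy, action_onehot, advantage, policy_clip=0.1):
    """Sketch: return (policy_loss, entropy), each averaged over the batch.

    policy, old_policy, action_onehot: lists of per-action vectors, one per sample.
    advantage: list of scalars. The caller weights and combines the two outputs.
    """
    surrogates, entropies = [], []
    for probs, old_probs, onehot, adv in zip(policy, old_policy, action_onehot, advantage):
        # Probability of the action actually taken, under the new and old policies.
        p_new = sum(prob * a for prob, a in zip(probs, onehot))
        p_old = sum(prob * a for prob, a in zip(old_probs, onehot))
        ratio = p_new / p_old
        clipped_ratio = max(1.0 - policy_clip, min(ratio, 1.0 + policy_clip))
        surrogates.append(min(ratio * adv, clipped_ratio * adv))
        # Shannon entropy of the current action distribution (0 * log 0 treated as 0).
        entropies.append(-sum(prob * math.log(prob) for prob in probs if prob > 0.0))
    n = len(surrogates)
    # Negate: optimisers minimise, but the clipped surrogate is to be maximised.
    policy_loss = -sum(surrogates) / n
    entropy = sum(entropies) / n
    return policy_loss, entropy


policy = [[0.6, 0.4], [0.5, 0.5]]
old_policy = [[0.5, 0.5], [0.5, 0.5]]
action_onehot = [[1.0, 0.0], [0.0, 1.0]]
advantage = [1.0, -1.0]
loss, ent = ppo_clipped_policy_loss(policy, old_policy, action_onehot, advantage)
print(loss, ent)
```

Returning the entropy unweighted matches the docstring's intent: each agent applies its own `entropy_coeff` when combining it with the policy and value losses.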
PPOTrainer
PPOTrainer(envs, agent: PPO, val_envs, config: PPOTrainerConfig)
Bases: SyncMultiEnvTrainer
Trainer for the clipped-objective PPO model.
Source code in rlib/PPO/trainer.py
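PPO trainers conventionally reuse each rollout for several optimisation passes. Assuming `PPOTrainer` follows that convention with the `num_epochs` and `num_minibatches` fields of `PPOTrainerConfig`, the sketch below shows how a rollout of `nsteps × num_envs` transitions could be reshuffled each epoch and split into minibatches. `minibatch_indices` is a hypothetical helper, and `num_envs` is an assumed name for the number of parallel environments; neither comes from rlib.

```python
import random


def minibatch_indices(batch_size: int, num_minibatches: int, num_epochs: int, seed: int = 0):
    """Yield (epoch, index list) pairs that cover the rollout once per epoch.

    Each epoch reshuffles all transition indices and splits them into
    num_minibatches contiguous chunks of (nearly) equal size.
    """
    rng = random.Random(seed)
    indices = list(range(batch_size))
    for epoch in range(num_epochs):
        rng.shuffle(indices)
        for k in range(num_minibatches):
            # Integer boundaries spread any remainder across the minibatches.
            start = k * batch_size // num_minibatches
            stop = (k + 1) * batch_size // num_minibatches
            yield epoch, indices[start:stop]  # slicing copies, so later shuffles are safe


nsteps, num_envs = 5, 8  # nsteps is the config default; num_envs is assumed
batches = list(minibatch_indices(nsteps * num_envs, num_minibatches=4, num_epochs=4))
print(len(batches), [len(idx) for _, idx in batches[:4]])
```

With the defaults `num_epochs=4` and `num_minibatches=4`, every transition is used in four gradient steps per rollout, once per epoch, in a different minibatch each time.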
PPOTrainerConfig (dataclass)
PPOTrainerConfig(train_mode: TrainMode = TrainMode.NSTEP, returns: Returns = Returns.NSTEP, total_steps: int = 50000000, nsteps: int = 5, gamma: float = 0.99, lambda_: float = 0.95, validate_freq: int = 1000000, num_val_episodes: int = 50, max_val_steps: int = 10000, log_dir: str = 'logs/', model_dir: str = 'models/', save_freq: int = 0, log_scalars: bool = True, update_target_freq: int = 0, render_freq: int = 0, num_epochs: int = 4, num_minibatches: int = 4)
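The `returns`, `gamma` and `lambda_` fields control how training targets are built from each `nsteps`-long rollout. The sketch below shows the two standard estimators those names usually refer to: bootstrapped n-step discounted returns, and generalised advantage estimation (GAE) with λ given by `lambda_`. Which estimator each `Returns` member selects, and how rlib treats episode boundaries, are assumptions here; `nstep_returns` and `gae_advantages` are hypothetical single-environment helpers.

```python
def nstep_returns(rewards, dones, bootstrap_value, gamma=0.99):
    """Discounted returns for one env, bootstrapped from the value of the next state.

    dones[t] = True means the episode ended at step t, so nothing is carried
    back across that boundary.
    """
    returns = [0.0] * len(rewards)
    running = bootstrap_value
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running * (1.0 - float(dones[t]))
        returns[t] = running
    return returns


def gae_advantages(rewards, values, dones, bootstrap_value, gamma=0.99, lambda_=0.95):
    """Generalised advantage estimates; adding values back gives lambda-returns."""
    advantages = [0.0] * len(rewards)
    gae = 0.0
    next_value = bootstrap_value
    for t in reversed(range(len(rewards))):
        not_done = 1.0 - float(dones[t])
        # One-step TD error, then an exponentially weighted sum of future TD errors.
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        gae = delta + gamma * lambda_ * not_done * gae
        advantages[t] = gae
        next_value = values[t]
    return advantages


rewards = [1.0, 1.0, 1.0]
dones = [False, True, False]
print(nstep_returns(rewards, dones, bootstrap_value=10.0, gamma=0.5))  # [1.5, 1.0, 6.0]
print(gae_advantages(rewards, [0.3, -0.2, 0.5], dones, bootstrap_value=10.0, gamma=0.5))
```

With `lambda_ = 1` the GAE estimate telescopes to the n-step return minus the value baseline, while smaller values trade some bias for lower variance; the default `lambda_=0.95` sits close to the low-bias end.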