models¶
MLP ¶
MLP(input_shape, hidden_size: int = 64, activation: type[Module] = torch.nn.Tanh)
Bases: Module
Two-layer MLP body for low-dimensional (vector) observations.
Accepts either a 1-D shape tuple (in_dim,) or an int in_dim
so it slots into both the input_size and input_shape
interfaces used across the agent-model wrappers.
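As an illustration of the int-or-tuple handling described above, here is a minimal sketch. The class name `MLPSketch`, the `out_dim` attribute, and the exact layer layout (two `Linear` layers, each followed by the activation) are assumptions for illustration only; the real implementation lives in rlib/models.py and may differ.

```python
import torch
from torch import nn


class MLPSketch(nn.Module):
    """Hypothetical stand-in for MLP; not the rlib implementation."""

    def __init__(self, input_shape, hidden_size: int = 64,
                 activation: type[nn.Module] = nn.Tanh):
        super().__init__()
        # Accept either an int in_dim or a 1-D shape tuple (in_dim,).
        if isinstance(input_shape, int):
            in_dim = input_shape
        else:
            if len(input_shape) != 1:
                raise ValueError("expected a 1-D shape tuple (in_dim,)")
            in_dim = int(input_shape[0])
        # Assumed layout: two linear layers, each followed by the activation.
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_size),
            activation(),
            nn.Linear(hidden_size, hidden_size),
            activation(),
        )
        self.out_dim = hidden_size  # hypothetical convenience attribute

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


# Both call styles build the same architecture.
body_from_int = MLPSketch(4)
body_from_shape = MLPSketch((4,))
features = body_from_shape(torch.zeros(3, 4))  # shape [3, 64]
```

Normalising the argument once in the constructor means callers that pass input_size (an int) and callers that pass input_shape (a tuple) can share one body class.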
Source code in rlib/models.py
MaskedRNN ¶
MaskedRNN(cell, time_major=True)
Bases: Module
RNN with dynamic hidden-state masking, for sequences that reset part-way through an observation sequence, e.g. rollouts in A2C.
Args:
cell - a recurrent cell module (e.g. torch.nn.LSTMCell or torch.nn.GRUCell)
time_major - bool flag that sets the index order of the input tensor (default True)
Inputs passed to forward:
X - tensor of rank [time, batch, hidden] if time_major == True (default), or [batch, time, hidden] if time_major == False
hidden_init - tensor holding the initial cell hidden state
mask - boolean tensor of length time, used for hidden state masking, e.g. [True, False, False] will mask the first hidden state
Source code in rlib/models.py
forward ¶
forward(x, hidden=None, mask=None)
Args:
x: tensor of rank [time, batch, hidden] if time_major == True (default), or [batch, time, hidden] if time_major == False.
hidden: optional tensor holding the initial cell hidden state (default None).
mask: tensor of rank [time], used for hidden state masking, e.g. [True, False, False] will mask the first hidden state.
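To make the masking idea concrete, the sketch below unrolls a cell over time and applies the mask before each step. It is not the rlib code: the function name `masked_unroll` is hypothetical, it assumes that "masking" means zeroing the hidden state carried into a step, that a missing `hidden` starts from zeros, and it only handles single-tensor states such as those of `torch.nn.GRUCell` (an `LSTMCell` would need the same treatment for both `h` and `c`).

```python
import torch
from torch import nn


def masked_unroll(cell, x, hidden=None, mask=None, time_major=True):
    """Hypothetical sketch of a masked unroll; not the rlib implementation."""
    if not time_major:
        x = x.transpose(0, 1)  # [batch, time, feat] -> [time, batch, feat]
    steps, batch = x.shape[0], x.shape[1]
    if hidden is None:
        # Assumption: a missing initial state starts from zeros.
        hidden = x.new_zeros(batch, cell.hidden_size)
    if mask is None:
        mask = torch.zeros(steps, dtype=torch.bool, device=x.device)

    outputs = []
    for t in range(steps):
        # Assumption: a True entry zeroes the state carried into step t.
        if bool(mask[t]):
            hidden = torch.zeros_like(hidden)
        hidden = cell(x[t], hidden)
        outputs.append(hidden)

    out = torch.stack(outputs, dim=0)  # [time, batch, hidden]
    if not time_major:
        out = out.transpose(0, 1)      # back to [batch, time, hidden]
    return out, hidden


# Usage: 4 time steps, batch of 2, 3 input features, 5 hidden units.
cell = nn.GRUCell(3, 5)
x = torch.randn(4, 2, 3)
mask = torch.tensor([False, False, True, False])  # reset before step 2
out, last_hidden = masked_unroll(cell, x, mask=mask)
```

Under these assumptions, the outputs from the masked step onward match what a fresh unroll over the remaining steps would produce, which is the behaviour needed when an episode ends in the middle of a rollout.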
Source code in rlib/models.py