models

MLP

MLP(input_shape, hidden_size: int = 64, activation: type[Module] = torch.nn.Tanh)

Bases: Module

Two-layer MLP body for low-dimensional (vector) observations.

Accepts either a 1-D shape tuple (in_dim,) or an int in_dim so it slots into both the input_size and input_shape interfaces used across the agent-model wrappers.

Source code in rlib/models.py
def __init__(
    self,
    input_shape,
    hidden_size: int = 64,
    activation: type[torch.nn.Module] = torch.nn.Tanh,
) -> None:
    super().__init__()
    in_dim = int(input_shape[0]) if hasattr(input_shape, "__len__") else int(input_shape)
    self.dense_size = hidden_size
    self.net = torch.nn.Sequential(
        torch.nn.Linear(in_dim, hidden_size),
        activation(),
        torch.nn.Linear(hidden_size, hidden_size),
        activation(),
    )
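MLP.__init__ reduces input_shape to a single feature count in one expression. The helper below is a hypothetical stand-alone copy of that expression (resolve_in_dim is not part of rlib). It is a sketch for checking which inputs are accepted, not the library's API:

```python
from typing import Sequence, Union


def resolve_in_dim(input_shape: Union[int, Sequence[int]]) -> int:
    """Return the feature count the same way MLP.__init__ computes in_dim."""
    if hasattr(input_shape, "__len__"):
        # Shape-like value (tuple, list, torch.Size): its first entry is the feature dimension.
        return int(input_shape[0])
    # Plain int (or int-like scalar): the value itself is the feature count.
    return int(input_shape)


print(resolve_in_dim((8,)))         # 8
print(resolve_in_dim(8))            # 8
print(resolve_in_dim((3, 84, 84)))  # 3
```

Only the first entry of a shape is read, so a multi-dimensional shape such as (3, 84, 84) is not rejected; it silently gives in_dim = 3. This body is intended for flat vector observations.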

MaskedRNN

MaskedRNN(cell, time_major=True)

Bases: Module

Dynamic masked-hidden-state RNN for sequences whose hidden state must be reset partway through, e.g. at episode boundaries in A2C rollouts.

Args:
cell: a recurrent cell module, called once per time step as cell(x_t, hidden, mask_t) and expected to return (output, hidden). torch.nn.LSTMCell and torch.nn.GRUCell only take (input, hx), so they need a thin wrapper that applies the mask before they can be used here.
time_major: bool flag for the index order of the input tensor: [time, batch, features] if True (default), [batch, time, features] if False.

The input sequence x, the initial hidden state, and the mask are arguments of forward, not of the constructor. The mask is a boolean tensor with time as its first dimension, used for hidden-state masking: e.g. [True, False, False] masks the first hidden state.

Source code in rlib/models.py
def __init__(self, cell, time_major=True):
    super().__init__()
    self.cell = cell
    self.time_major = time_major
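The forward loop calls the cell as cell(x_t, hidden, mask_t) and unpacks (output, hidden), a convention torch.nn.GRUCell does not follow on its own. A minimal sketch of an adapter, assuming a truthy mask entry means "zero the incoming hidden state for that batch element before the step"; MaskedGRUCell is a hypothetical name, not part of rlib:

```python
import torch


class MaskedGRUCell(torch.nn.Module):
    """Hypothetical adapter: gives torch.nn.GRUCell the (x_t, hidden, mask_t) call
    convention used by MaskedRNN.forward. Not part of rlib."""

    def __init__(self, input_size: int, hidden_size: int) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.gru = torch.nn.GRUCell(input_size, hidden_size)

    def forward(self, x_t, hidden, mask_t):
        # x_t: [batch, input_size]; hidden: [batch, hidden_size] or None;
        # mask_t: [batch] or a scalar, truthy where the state must be reset.
        if hidden is None:
            hidden = x_t.new_zeros(x_t.shape[0], self.hidden_size)
        # keep is 0 where the mask is set, 1 elsewhere; shape [batch, 1] (or [1, 1])
        # so it broadcasts over the hidden units.
        keep = (1.0 - mask_t.to(x_t.dtype)).reshape(-1, 1)
        hidden = self.gru(x_t, hidden * keep)
        # A GRU's per-step output is its new hidden state.
        return hidden, hidden


torch.manual_seed(0)
cell = MaskedGRUCell(input_size=4, hidden_size=8)
x = torch.randn(3, 2, 4)                  # [time, batch, features]
mask = torch.tensor([[False, False],
                     [True, False],
                     [False, False]])     # reset batch element 0 at t = 1
hidden, outputs = None, []
for t in range(x.shape[0]):               # same loop as MaskedRNN.forward
    out, hidden = cell(x[t], hidden, mask[t])
    outputs.append(out)
outputs = torch.stack(outputs, dim=0)     # [3, 2, 8]
```

Because the loop is the one MaskedRNN.forward runs, MaskedRNN(MaskedGRUCell(4, 8)) with the default time_major=True should behave the same way. An LSTM adapter would additionally have to pack (h, c) into the single hidden value the loop threads through.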

forward

forward(x, hidden=None, mask=None)

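Args:
x: tensor of shape [time, batch, features] if time_major == True (default), or [batch, time, features] if time_major == False.
hidden: initial hidden state, passed unchanged to the cell at the first step.
mask: tensor of shape [time] or [time, batch] for hidden-state masking, e.g. [True, False, False] masks the first hidden state. It is indexed by time along its first dimension regardless of time_major. Defaults to all zeros (no resets).

Returns: (outputs, hidden), where outputs has the same layout as x and hidden is the cell's final state.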

Source code in rlib/models.py
def forward(self, x, hidden=None, mask=None):
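    '''Args:
    x: tensor of shape [time, batch, features] if time_major == True (default); or [batch, time, features] if time_major == False.
    hidden: initial hidden state, passed unchanged to the cell at the first step.
    mask: tensor of shape [time] or [time, batch], for hidden state masking, e.g. [True, False, False] masks the first hidden state.
        Always indexed by time along its first dimension, regardless of time_major.
    '''

    # Work internally in time-major order: [time, batch, features].
    # torch.Tensor.transpose swaps exactly two dimensions (unlike NumPy's permutation form).
    if not self.time_major:
        x = x.transpose(0, 1)

    # Default mask: never reset the hidden state.
    if mask is None:
        mask = torch.zeros(x.shape[0], x.shape[1], device=x.device)

    outputs = []
    for t in range(x.shape[0]):
        # The cell receives this step's mask and is responsible for resetting hidden where it is set.
        output, hidden = self.cell(x[t], hidden, mask[t])
        outputs.append(output)

    outputs = torch.stack(outputs, dim=0)
    # Restore the caller's layout.
    outputs = outputs if self.time_major else outputs.transpose(0, 1)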
    return outputs, hidden
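To make the masking semantics concrete without tensors, the sketch below runs the same per-step loop with a toy "cell" whose hidden state is a running sum of its inputs. It assumes, as the docstring's "[True, False, False] masks the first hidden state" suggests, that a True mask zeroes the incoming state before that step's update; in the real module this behaviour is decided by the cell you pass in. running_sum_cell and masked_unroll are hypothetical names used only for illustration:

```python
from typing import Callable, List, Optional, Sequence, Tuple

Cell = Callable[[float, Optional[float], bool], Tuple[float, float]]


def running_sum_cell(x_t: float, hidden: Optional[float], mask_t: bool) -> Tuple[float, float]:
    # Toy cell: the state accumulates inputs. A True mask (or no state yet)
    # starts from zero, like resetting at an episode boundary.
    h = 0.0 if hidden is None or mask_t else hidden
    h = h + x_t
    return h, h  # (output, new hidden)


def masked_unroll(cell: Cell, xs: Sequence[float], masks: Sequence[bool],
                  hidden: Optional[float] = None) -> Tuple[List[float], Optional[float]]:
    # Same shape as MaskedRNN.forward: one cell call per time step,
    # threading the hidden state and handing the cell that step's mask.
    outputs = []
    for x_t, m_t in zip(xs, masks):
        output, hidden = cell(x_t, hidden, m_t)
        outputs.append(output)
    return outputs, hidden


outs, last = masked_unroll(running_sum_cell, [1.0, 2.0, 3.0, 4.0],
                           [False, False, True, False], hidden=10.0)
print(outs, last)  # [11.0, 13.0, 3.0, 7.0] 7.0
```

The reset at t = 2 discards the accumulated 13.0, so the sum restarts from that step's input.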