diff --git a/pufferlib/models.py b/pufferlib/models.py index fa43d7071..3b9fc6363 100644 --- a/pufferlib/models.py +++ b/pufferlib/models.py @@ -111,7 +111,9 @@ def __init__(self, env, policy, input_size=128, hidden_size=128): self.hidden_size = hidden_size self.is_continuous = self.policy.is_continuous - for name, param in self.named_parameters(): + self.lstm = nn.LSTM(input_size, hidden_size) + + for name, param in self.lstm.named_parameters(): if 'layer_norm' in name: continue if "bias" in name: @@ -119,8 +121,6 @@ def __init__(self, env, policy, input_size=128, hidden_size=128): elif "weight" in name and param.ndim >= 2: nn.init.orthogonal_(param, 1.0) - self.lstm = nn.LSTM(input_size, hidden_size) - self.cell = torch.nn.LSTMCell(input_size, hidden_size) self.cell.weight_ih = self.lstm.weight_ih_l0 self.cell.weight_hh = self.lstm.weight_hh_l0