In class projection_MLP(nn.Module), the third layer has a dimension mismatch.
self.layer3 = nn.Sequential( nn.Linear(hidden_dim, out_dim), nn.BatchNorm1d(hidden_dim) )
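This works with the default parameter values because hidden_dim == out_dim, but it breaks as soon as out_dim is set to a different value. nn.Linear(hidden_dim, out_dim) produces out_dim features, while nn.BatchNorm1d(hidden_dim) expects hidden_dim features, so the forward pass raises a RuntimeError. The BatchNorm layer should be nn.BatchNorm1d(out_dim).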
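A minimal sketch of the corrected head follows. Only layer3 comes from the report. The layout of layer1 and layer2 (Linear, BatchNorm1d, ReLU) and the default sizes of 2048 are assumptions for illustration and may not match the actual projection_MLP. The point is that the BatchNorm1d in layer3 must take out_dim, so the head works for any out_dim.

```python
import torch
import torch.nn as nn


class projection_MLP(nn.Module):
    """Sketch of a 3-layer projection head. layer1/layer2 and the defaults are assumed."""

    def __init__(self, in_dim, hidden_dim=2048, out_dim=2048):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
        )
        self.layer2 = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
        )
        # Fix: normalize the out_dim features produced by the Linear layer,
        # instead of nn.BatchNorm1d(hidden_dim).
        self.layer3 = nn.Sequential(
            nn.Linear(hidden_dim, out_dim),
            nn.BatchNorm1d(out_dim),
        )

    def forward(self, x):
        return self.layer3(self.layer2(self.layer1(x)))


# out_dim != hidden_dim: the original layer3 would raise here.
model = projection_MLP(in_dim=512, hidden_dim=2048, out_dim=256)
z = model(torch.randn(8, 512))
print(z.shape)  # torch.Size([8, 256])
```

A batch size greater than 1 is used because BatchNorm1d in training mode cannot compute batch statistics from a single sample.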