Since the model receives input of shape [batch, seq, vars, lat (96), lon (144)], lat_size should be y.shape[-2] rather than y.shape[-1], which is the longitude size (144).
This holds when channels_last is set to False.
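import numpy as np

def LLWeighted_RMSE_WheatherBench(preds: np.ndarray, y: np.ndarray):
    """
    Weighted RMSE taken from WeatherBench.
    Weighting accounts for the decreasing grid-cell area towards the poles.
    rmse = mean over forecasts and time of sqrt(mean over lat and lon of L(lat_j) * (preds - y)**2)
    weights L(lat_j) = cos(lat_j) / mean(cos(lat))
    Expects arrays shaped [..., lat, lon] (channels_last=False).
    """
    lat_size = y.shape[-2]  # latitude is the second-to-last axis
    lats = np.deg2rad(np.linspace(-90, 90, lat_size))  # np.cos expects radians
    weights = np.cos(lats) / np.cos(lats).mean()  # normalise so the weights average to 1
    weights = weights[:, None]  # shape (lat, 1) so it broadcasts across longitude
    rmse = np.sqrt(np.mean(weights * (preds - y) ** 2, axis=(-1, -2))).mean()
    return rmse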
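A quick way to sanity-check the weighting is a constant error: because the weights are normalised to average 1, a uniform error of 0.5 should give an RMSE of 0.5 regardless of latitude. The sketch below is self-contained and hedged: the name lat_weighted_rmse and the random toy data are hypothetical, and it assumes the same evenly spaced -90..90 latitude grid used in the function above, with arrays shaped [batch, seq, vars, lat, lon].

```python
import numpy as np

def lat_weighted_rmse(preds: np.ndarray, y: np.ndarray) -> float:
    """Hypothetical latitude-weighted RMSE for arrays shaped [..., lat, lon]."""
    lat_size = y.shape[-2]  # latitude axis when channels_last=False
    lats = np.deg2rad(np.linspace(-90, 90, lat_size))  # degrees -> radians
    weights = np.cos(lats) / np.cos(lats).mean()  # normalised: mean weight is 1
    weights = weights[:, None]  # (lat, 1): broadcast across longitude
    mse = np.mean(weights * (preds - y) ** 2, axis=(-1, -2))  # per batch/seq/var
    return float(np.sqrt(mse).mean())  # average RMSE over the remaining axes

# Toy data with the shape from the issue: [batch, seq, vars, lat, lon]
rng = np.random.default_rng(0)
y = rng.standard_normal((2, 3, 4, 96, 144))
preds = y + 0.5  # constant error of 0.5 everywhere

print(lat_weighted_rmse(preds, y))  # approximately 0.5, since the weights average to 1
```

Note that if the weights were built from y.shape[-1] (144) and left one-dimensional, NumPy would broadcast them against the longitude axis without raising an error, so the latitude weighting would silently be applied to the wrong dimension.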