-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
105 lines (81 loc) · 3.2 KB
/
utils.py
File metadata and controls
105 lines (81 loc) · 3.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import xarray as xr
from glob import glob
import gcsfs
from tqdm.notebook import tqdm
def make_dir(path):
    """Create directory *path* (including parents) if it does not exist.

    Args:
        path (str): directory path to create.
    """
    # exist_ok=True makes this atomic and idempotent, avoiding the
    # check-then-create race of the original `os.path.exists` guard.
    os.makedirs(path, exist_ok=True)
def open_dataset(file_path):
    """Flexible opener that can handle both local files (legacy) and cloud urls.

    IMPORTANT: For this to work the `file_path` must be provided without
    extension; ``.zarr`` (cloud) or ``.nc`` (local) is appended here.

    Args:
        file_path (str): path or ``gs://`` url, without file extension.

    Returns:
        xr.Dataset: the opened dataset, with ``file_name`` recorded in attrs.
    """
    if 'gs://' in file_path:
        store = f"{file_path}.zarr"
        ds = xr.open_dataset(store, engine='zarr')
    else:
        ds = xr.open_dataset(f"{file_path}.nc")
    # Record the source so downstream code can sort and label datasets.
    # The original line was a bare lookup (`ds.attrs['file_name']`) which
    # either raised KeyError or did nothing; assign the value instead.
    ds.attrs['file_name'] = file_path
    return ds
def prepare_predictor(data_sets, data_path, time_reindex=True):
    """Load ``inputs_<name>`` datasets and concatenate them along ``time``.

    Args:
        data_sets (str or list(str)): names of datasets.
        data_path (str): prefix prepended to ``inputs_<name>`` (the opener
            appends the file extension).
        time_reindex (bool): if True, replace the ``time`` coordinate with a
            simple 0..N-1 range so concatenated datasets don't collide.

    Returns:
        tuple: (X, length_all) where ``X`` is the concatenated xr.Dataset and
        ``length_all`` is an np.ndarray of per-dataset time lengths, so the
        concatenation can be split back apart later.
    """
    # Accept a bare name as a convenience for single-dataset use.
    if isinstance(data_sets, str):
        data_sets = [data_sets]
    X_all = []
    length_all = []
    for file in tqdm(data_sets):
        data = open_dataset(f"{data_path}inputs_{file}")
        X_all.append(data)
        length_all.append(len(data.time))
    X = xr.concat(X_all, dim='time')
    length_all = np.array(length_all)
    if time_reindex:
        X = X.assign_coords(time=np.arange(len(X.time)))
    return X, length_all
def prepare_predictand(data_sets, data_path, time_reindex=True):
    """Load ``outputs_<name>`` datasets, ensemble-average, and concatenate.

    Args:
        data_sets (str or list(str)): names of datasets.
        data_path (str): prefix prepended to ``outputs_<name>`` (the opener
            appends the file extension).
        time_reindex (bool): if True, replace the ``time`` coordinate with a
            simple 0..N-1 range so concatenated datasets don't collide.

    Returns:
        tuple: (Y, length_all) where ``Y`` is the concatenated xr.Dataset
        averaged over ``member`` and ``length_all`` is an np.ndarray of
        per-dataset time lengths.
    """
    # Accept a bare name as a convenience for single-dataset use.
    if isinstance(data_sets, str):
        data_sets = [data_sets]
    Y_all = []
    length_all = []
    for file in tqdm(data_sets):
        data = open_dataset(f"{data_path}outputs_{file}")
        Y_all.append(data)
        length_all.append(len(data.time))
    length_all = np.array(length_all)
    # Average across the ensemble 'member' dimension.
    Y = xr.concat(Y_all, dim='time').mean('member')
    # Normalize coordinate names/order to match the predictor datasets;
    # 'quantile' is a leftover coordinate that is not needed downstream.
    Y = Y.rename({'lon': 'longitude', 'lat': 'latitude'}).transpose('time', 'latitude', 'longitude').drop(['quantile'])
    if time_reindex:
        Y = Y.assign_coords(time=np.arange(len(Y.time)))
    return Y, length_all
def get_rmse(truth, pred):
    """Latitude-weighted RMSE between two fields, averaged over any remaining dims.

    Grid cells are weighted by cos(latitude) to account for their true area.
    """
    lat_weights = np.cos(np.deg2rad(truth.latitude))
    squared_error = (truth - pred) ** 2
    weighted_mse = squared_error.weighted(lat_weights).mean(['latitude', 'longitude'])
    return np.sqrt(weighted_mse).data.mean()
def plot_history(history):
    """Plot training and validation loss curves from a fit history object.

    Expects ``history`` to expose ``.epoch`` and a ``.history`` dict with
    'loss' and 'val_loss' entries (e.g. a Keras History).
    """
    epochs = history.epoch
    train_loss = np.array(history.history['loss'])
    val_loss = np.array(history.history['val_loss'])

    plt.figure()
    plt.xlabel('Epoch')
    plt.ylabel('Mean squared error')
    plt.plot(epochs, train_loss, label='Train Loss')
    plt.plot(epochs, val_loss, label='Val loss')
    plt.legend()
# Utilities for normalizing the input data
def normalize(data, var, meanstd_dict):
    """Standardize *data* with the (mean, std) pair stored for *var*.

    Args:
        data: values to normalize (scalar or array-like).
        var (str): variable name used as the lookup key.
        meanstd_dict (dict): maps variable name -> (mean, std).

    Returns:
        The z-score normalized data: (data - mean) / std.
    """
    mean, std = meanstd_dict[var][0], meanstd_dict[var][1]
    return (data - mean) / std
def mean_std_plot(data, color, label, ax):
    """Plot the spatial-mean time series with a ±1 std shaded band.

    Args:
        data: field with 'latitude', 'longitude' and 'time' dimensions.
        color: line/band color.
        label: legend label for the mean line.
        ax: matplotlib axes to draw on.

    Returns:
        tuple: (time values, spatial mean series).
    """
    spatial_dims = ['latitude', 'longitude']
    mean = data.mean(spatial_dims)
    std = data.std(spatial_dims)
    yr = data.time.values

    ax.plot(yr, mean, color=color, label=label, linewidth=4)
    ax.fill_between(yr, mean + std, mean - std, facecolor=color, alpha=0.4)
    return yr, mean