Skip to content

Commit 4cecce0

Browse files
authored
Merge pull request #64 from HYPERNETS/writer_fill
fill nans on write
2 parents 3b96b91 + 9e1bb31 commit 4cecce0

File tree

2 files changed

+37
-0
lines changed

2 files changed

+37
-0
lines changed

hypernets_processor/data_io/hypernets_writer.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44

55
from hypernets_processor.version import __version__
66
import os
7+
import numpy as np
78
from datetime import datetime
89

10+
911
"""___Authorship___"""
1012
__author__ = "Sam Hunt"
1113
__created__ = "12/2/2020"
@@ -117,6 +119,8 @@ def _write_netcdf(ds, path, compression_level=None):
117119
var_encoding.update(ds[var_name].encoding)
118120
encoding.update({var_name: var_encoding})
119121

122+
ds = HypernetsWriter.fill_ds(ds)
123+
120124
ds.to_netcdf(path, format="netCDF4", engine="netcdf4", encoding=encoding)
121125

122126
@staticmethod
@@ -140,6 +144,24 @@ def _write_csv(ds, path):
140144
for meta_name in ds.attrs.keys():
141145
f.write(meta_name + ": " + ds.attrs[meta_name] + "\n")
142146

147+
@staticmethod
148+
def fill_ds(ds):
149+
"""
150+
Fill nan's in ds will fillValue
151+
152+
:type ds: xarray.Dataset
153+
:param ds: dataset
154+
155+
:return: filled data
156+
:rtype: xarray.Dataset
157+
"""
158+
159+
for variable in ds.variables.keys():
160+
idx = np.where(np.isnan(ds[variable]))
161+
ds[variable][idx] = ds[variable]._FillValue
162+
163+
return ds
164+
143165

144166
if __name__ == "__main__":
145167
pass

hypernets_processor/data_io/tests/test_hypernets_writer.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44

55
import unittest
66
from unittest.mock import patch, MagicMock
7+
from hypernets_processor.data_io.dataset_util import DatasetUtil
78
from hypernets_processor.data_io.hypernets_writer import HypernetsWriter
89
from hypernets_processor.test.test_functions import setup_test_context
910
from hypernets_processor.version import __version__
1011
from xarray import Dataset
12+
import numpy as np
1113

1214

1315
'''___Authorship___'''
@@ -81,6 +83,19 @@ def test__write_netcdf(self):
8183

8284
ds.to_netcdf.assert_called_once_with(path, encoding={}, engine='netcdf4', format='netCDF4')
8385

86+
def test_fill_ds(self):
87+
ds = Dataset()
88+
ds["array_variable1"] = DatasetUtil.create_variable([7, 8], np.float32)
89+
ds["array_variable2"] = DatasetUtil.create_variable([7, 8], np.float32)
90+
91+
ds["array_variable1"][2, 3] = np.nan
92+
ds["array_variable2"][2, 3] = np.nan
93+
94+
HypernetsWriter.fill_ds(ds)
95+
96+
self.assertTrue(np.all(ds["array_variable1"] == 9.96921E36))
97+
self.assertTrue(np.all(ds["array_variable2"] == 9.96921E36))
98+
8499

85100
if __name__ == '__main__':
86101
unittest.main()

0 commit comments

Comments
 (0)