Skip to content

Commit a0bd6fa

Browse files
committed
add flag handling methods
1 parent 200c3c1 commit a0bd6fa

File tree

2 files changed

+301
-2
lines changed

2 files changed

+301
-2
lines changed

hypernets_processor/data_io/dataset_util.py

Lines changed: 147 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from hypernets_processor.version import __version__
66
import string
7-
from xarray import Variable, DataArray
7+
from xarray import Variable, DataArray, Dataset
88
import numpy as np
99

1010

@@ -220,6 +220,152 @@ def get_default_fill_value(dtype):
220220
elif dtype == np.float64:
221221
return np.float64(9.969209968386869E36)
222222

223+
@staticmethod
def _get_flag_encoding(da):
    """
    Returns flag encoding for flag type data array

    :type da: xarray.DataArray
    :param da: data array carrying "flag_meanings" and "flag_masks" attributes

    :return: flag meanings and flag masks, as parsed from the attributes
    :rtype: tuple(list, list)

    :raises KeyError: if either attribute is missing from da
    """

    try:
        # flag_meanings is split on whitespace, flag_masks on commas
        flag_meanings = da.attrs["flag_meanings"].split()
        flag_masks = [int(fm) for fm in da.attrs["flag_masks"].split(",")]
    except KeyError as err:
        # da.name is None for an unnamed array, so convert it rather than
        # concatenating directly (which would raise TypeError instead)
        raise KeyError(str(da.name) + " not a flag variable") from err

    return flag_meanings, flag_masks
245+
246+
@staticmethod
def unpack_flags(da):
    """
    Breaks down flag data array into dataset of boolean masks for each flag

    :type da: xarray.DataArray
    :param da: flag data array

    :return: dataset with one boolean variable per flag meaning, True where that flag's mask bit is set
    :rtype: xarray.Dataset
    """

    flag_meanings, flag_masks = DatasetUtil._get_flag_encoding(da)

    ds = Dataset()
    for flag_meaning, flag_mask in zip(flag_meanings, flag_masks):
        # Non-zero after masking means the bit is set. The result is assigned
        # directly; no placeholder variable needs to be created beforehand.
        ds[flag_meaning] = (da & flag_mask).astype(bool)

    return ds
266+
267+
@staticmethod
def set_flag(da, flag_name, error_if_set=False):
    """
    Sets named flag for elements in data array

    :type da: xarray.DataArray
    :param da: flag data array

    :type flag_name: str
    :param flag_name: name of flag to set

    :type error_if_set: bool
    :param error_if_set: raises error if chosen flag is already set for any element

    :return: new data array with the flag's mask bit set for every element
    :rtype: xarray.DataArray

    :raises KeyError: if da is not a flag variable or flag_name is not one of its flags
    :raises ValueError: if error_if_set is True and the flag is already set for any element
    """

    # Find flag mask (parsed once and used for both the check and the update)
    flag_meanings, flag_masks = DatasetUtil._get_flag_encoding(da)
    if flag_name not in flag_meanings:
        raise KeyError(flag_name)
    flag_mask = flag_masks[flag_meanings.index(flag_name)]

    # Only inspect the data when the caller asked for the check
    if error_if_set and bool(((da & flag_mask) != 0).any()):
        raise ValueError("Flag " + flag_name + " already set for variable " + str(da.name))

    flagged = da | flag_mask

    # NOTE(review): xarray arithmetic may drop attrs depending on its keep_attrs
    # option; copy them so the result still carries its flag encoding
    flagged.attrs.update(da.attrs)

    return flagged
293+
294+
@staticmethod
def unset_flag(da, flag_name, error_if_unset=False):
    """
    Unsets named flag for elements in data array

    :type da: xarray.DataArray
    :param da: flag data array

    :type flag_name: str
    :param flag_name: name of flag to unset

    :type error_if_unset: bool
    :param error_if_unset: raises error if chosen flag is not set for any element

    :return: new data array with the flag's mask bit cleared for every element
    :rtype: xarray.DataArray

    :raises KeyError: if da is not a flag variable or flag_name is not one of its flags
    :raises ValueError: if error_if_unset is True and the flag is not set for any element
    """

    # Find flag mask (parsed once and used for both the check and the update)
    flag_meanings, flag_masks = DatasetUtil._get_flag_encoding(da)
    if flag_name not in flag_meanings:
        raise KeyError(flag_name)
    flag_mask = flag_masks[flag_meanings.index(flag_name)]

    # Only inspect the data when the caller asked for the check
    if error_if_unset and bool(((da & flag_mask) == 0).any()):
        raise ValueError("Flag " + flag_name + " not set for variable " + str(da.name))

    unflagged = da & ~flag_mask

    # NOTE(review): xarray arithmetic may drop attrs depending on its keep_attrs
    # option; copy them so the result still carries its flag encoding
    unflagged.attrs.update(da.attrs)

    return unflagged
320+
321+
@staticmethod
def get_set_flags(da):
    """
    Lists the names of the flags that are set in a single element data array

    :type da: xarray.DataArray
    :param da: single element data array (shape must be ``()``)

    :return: meanings of the flags whose mask bit is set
    :rtype: list

    :raises ValueError: if da has any dimensions
    """

    # Only a zero-dimensional selection is accepted
    if da.shape != ():
        raise ValueError("Must pass single element data array")

    meanings, masks = DatasetUtil._get_flag_encoding(da)

    # Keep each meaning whose masked value is truthy (non-zero)
    return [meaning for meaning, mask in zip(meanings, masks) if (da & mask)]
344+
345+
@staticmethod
def check_flag_set(da, flag_name):
    """
    Reports whether the named flag is set for a single element data array

    :type da: xarray.DataArray
    :param da: single element data array (shape must be ``()``)

    :type flag_name: str
    :param flag_name: name of flag to check

    :return: True if flag_name is among the set flags, otherwise False
    :rtype: bool

    :raises ValueError: if da has any dimensions
    """

    if da.shape != ():
        raise ValueError("Must pass single element data array")

    return flag_name in DatasetUtil.get_set_flags(da)
368+
223369

224370
if __name__ == "__main__":
225371
pass

hypernets_processor/data_io/tests/test_dataset_util.py

Lines changed: 154 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import unittest
66
import numpy as np
7-
from xarray import DataArray, Variable
7+
from xarray import DataArray, Variable, Dataset
88
from hypernets_processor.data_io.dataset_util import DatasetUtil
99
from hypernets_processor.version import __version__
1010

@@ -241,6 +241,159 @@ def test_get_default_fill_value(self):
241241
self.assertEqual(np.float32(9.96921E36), DatasetUtil.get_default_fill_value(np.float32))
242242
self.assertEqual(9.969209968386869E36, DatasetUtil.get_default_fill_value(np.float64))
243243

244+
def test__get_flag_encoding(self):
    """Flag meanings and masks read back from a flags variable match what was requested."""

    expected_meanings = ["flag%d" % i for i in range(1, 9)]
    expected_masks = [1, 2, 4, 8, 16, 32, 64, 128]

    ds = Dataset()
    ds["flags"] = DatasetUtil.create_flags_variable([5], expected_meanings, dim_names=["dim1"],
                                                    attributes={"standard_name": "std"})

    encoding = DatasetUtil._get_flag_encoding(ds["flags"])

    self.assertCountEqual(expected_meanings, encoding[0])
    self.assertCountEqual(expected_masks, encoding[1])
258+
259+
def test__get_flag_encoding_not_flag_var(self):
    """A plain array variable is rejected with KeyError."""

    ds = Dataset()
    ds["array_variable"] = DatasetUtil.create_variable([7, 8, 3], np.int8, attributes={"standard_name": "std"})

    with self.assertRaises(KeyError):
        DatasetUtil._get_flag_encoding(ds["array_variable"])
264+
265+
def test_unpack_flags(self):
    """Unpacking yields one boolean mask per flag, True only where that flag was set."""

    meanings = ["flag1", "flag2", "flag3", "flag4", "flag5", "flag6", "flag7", "flag8"]

    ds = Dataset()
    ds["flags"] = DatasetUtil.create_flags_variable([2, 3], meanings, dim_names=["dim1", "dim2"],
                                                    attributes={"standard_name": "std"})
    # Set mask value 8 on element [0, 0]; the assertions below expect this to be flag4
    ds["flags"][0, 0] = ds["flags"][0, 0] | 8

    flags = DatasetUtil.unpack_flags(ds["flags"])

    for meaning in meanings:
        expected = np.zeros((2, 3), bool)
        if meaning == "flag4":
            expected[0, 0] = True
        # Pass the flag name as the message so a failure identifies the flag
        self.assertTrue((flags[meaning].data == expected).all(), meaning)
290+
291+
def test_get_set_flags(self):
    """Both flags set on one element are reported by name."""

    flag_names = ["flag%d" % i for i in range(1, 9)]

    ds = Dataset()
    ds["flags"] = DatasetUtil.create_flags_variable([5], flag_names, dim_names=["dim1"],
                                                    attributes={"standard_name": "std"})
    # Apply the two mask values to element 3, one after the other
    for mask_value in (8, 32):
        ds["flags"][3] = ds["flags"][3] | mask_value

    reported = DatasetUtil.get_set_flags(ds["flags"][3])

    self.assertCountEqual(reported, ["flag4", "flag6"])
304+
305+
def test_get_set_flags_2d(self):
    """Passing a whole (non-scalar) flags variable raises ValueError."""

    flag_names = ["flag%d" % i for i in range(1, 9)]

    ds = Dataset()
    ds["flags"] = DatasetUtil.create_flags_variable([5], flag_names, dim_names=["dim1"],
                                                    attributes={"standard_name": "std"})

    with self.assertRaises(ValueError):
        DatasetUtil.get_set_flags(ds["flags"])
314+
315+
def test_check_flag_set_true(self):
    """A flag that was set on the element is reported as set."""

    flag_names = ["flag%d" % i for i in range(1, 9)]

    ds = Dataset()
    ds["flags"] = DatasetUtil.create_flags_variable([5], flag_names, dim_names=["dim1"],
                                                    attributes={"standard_name": "std"})
    # Apply the two mask values to element 3, one after the other
    for mask_value in (8, 32):
        ds["flags"][3] = ds["flags"][3] | mask_value

    self.assertTrue(DatasetUtil.check_flag_set(ds["flags"][3], "flag6"))
327+
328+
def test_check_flag_set_false(self):
    """A flag that was not set on the element is reported as not set."""

    flag_names = ["flag%d" % i for i in range(1, 9)]

    ds = Dataset()
    ds["flags"] = DatasetUtil.create_flags_variable([5], flag_names, dim_names=["dim1"],
                                                    attributes={"standard_name": "std"})
    # Only mask value 8 is applied, so "flag6" stays clear
    ds["flags"][3] = ds["flags"][3] | 8

    self.assertFalse(DatasetUtil.check_flag_set(ds["flags"][3], "flag6"))
339+
340+
def test_check_flag_set_2d(self):
    """Passing a whole (non-scalar) flags variable raises ValueError."""

    flag_names = ["flag%d" % i for i in range(1, 9)]

    ds = Dataset()
    ds["flags"] = DatasetUtil.create_flags_variable([5], flag_names, dim_names=["dim1"],
                                                    attributes={"standard_name": "std"})

    with self.assertRaises(ValueError):
        DatasetUtil.check_flag_set(ds["flags"], "flag6")
348+
349+
def test_set_flag(self):
    """Setting a flag on a fresh variable leaves every element equal to that flag's mask value."""

    flag_names = ["flag%d" % i for i in range(1, 9)]

    ds = Dataset()
    ds["flags"] = DatasetUtil.create_flags_variable([5, 4], flag_names, dim_names=["dim1", "dim2"],
                                                    attributes={"standard_name": "std"})

    ds["flags"] = DatasetUtil.set_flag(ds["flags"], "flag4")

    # Every element is expected to hold the value 8
    expected = np.full(ds["flags"].shape, 8)

    self.assertTrue((ds["flags"].data == expected).all())
362+
363+
def test_set_flag_error_if_set(self):
    """With error_if_set=True, setting a flag already present on some element raises ValueError."""

    flag_names = ["flag%d" % i for i in range(1, 9)]

    ds = Dataset()
    ds["flags"] = DatasetUtil.create_flags_variable([5], flag_names, dim_names=["dim1"],
                                                    attributes={"standard_name": "std"})
    # Pre-set mask value 8 on a single element
    ds["flags"][3] = ds["flags"][3] | 8

    with self.assertRaises(ValueError):
        DatasetUtil.set_flag(ds["flags"], "flag4", error_if_set=True)
372+
373+
def test_unset_flag(self):
    """Unsetting a flag that was set on every element returns all elements to zero."""

    flag_names = ["flag%d" % i for i in range(1, 9)]

    ds = Dataset()
    ds["flags"] = DatasetUtil.create_flags_variable([5], flag_names, dim_names=["dim1"],
                                                    attributes={"standard_name": "std"})
    # Set mask value 8 on every element first
    ds["flags"][:] = ds["flags"][:] | 8

    ds["flags"] = DatasetUtil.unset_flag(ds["flags"], "flag4")

    expected = np.zeros(ds["flags"].shape)

    self.assertTrue((ds["flags"].data == expected).all())
387+
388+
def test_set_flag_error_if_unset(self):
    """With error_if_unset=True, unsetting a flag that is not set raises ValueError."""

    flag_names = ["flag%d" % i for i in range(1, 9)]

    ds = Dataset()
    ds["flags"] = DatasetUtil.create_flags_variable([5], flag_names, dim_names=["dim1"],
                                                    attributes={"standard_name": "std"})

    # No flag has been set on the fresh variable
    with self.assertRaises(ValueError):
        DatasetUtil.unset_flag(ds["flags"], "flag4", error_if_unset=True)
396+
244397

245398
if __name__ == '__main__':
246399
unittest.main()

0 commit comments

Comments
 (0)