55import itertools
66import re
77import warnings
8- from collections import ChainMap
8+ from collections import ChainMap , namedtuple
99from datetime import datetime
1010from typing import (
1111 Any ,
5858 parse_cf_standard_name_table ,
5959)
6060
61+ FlagParam = namedtuple ("FlagParam" , ["flag_mask" , "flag_value" ])
62+
6163#: Classes wrapped by cf_xarray.
6264_WRAPPED_CLASSES = (Resample , GroupBy , Rolling , Coarsen , Weighted )
6365
@@ -1057,18 +1059,39 @@ def __getattr__(self, attr):
10571059 )
10581060
10591061
1060- def create_flag_dict (da ):
1062+ def create_flag_dict (da ) -> Mapping [Hashable , FlagParam ]:
1063+ """
1064+ Return possible flag meanings and associated bitmask/values.
1065+
1066+ The mapping values are a tuple containing a bitmask and a value. Either
1067+ can be None.
1068+ If only a bitmask: Independent flags.
1069+ If only a value: Mutually exclusive flags.
1070+ If both: Mix of independent and mutually exclusive flags.
1071+ """
10611072 if not da .cf .is_flag_variable :
10621073 raise ValueError (
1063- "Comparisons are only supported for DataArrays that represent CF flag variables."
1064- ".attrs must contain 'flag_values' and 'flag_meanings'"
1074+ "Comparisons are only supported for DataArrays that represent "
1075+ "CF flag variables. .attrs must contain 'flag_meanings' and "
1076+ "'flag_values' or 'flag_masks'."
10651077 )
10661078
10671079 flag_meanings = da .attrs ["flag_meanings" ].split (" " )
1068- flag_values = da .attrs ["flag_values" ]
1069- # TODO: assert flag_values is iterable
1070- assert len (flag_values ) == len (flag_meanings )
1071- return dict (zip (flag_meanings , flag_values ))
1080+ n_flag = len (flag_meanings )
1081+
1082+ flag_values = da .attrs .get ("flag_values" , [None ] * n_flag )
1083+ flag_masks = da .attrs .get ("flag_masks" , [None ] * n_flag )
1084+
1085+ if not (n_flag == len (flag_values ) == len (flag_masks )):
1086+ raise ValueError (
1087+ "Not as many flag meanings as values or masks. "
1088+ "Please check the flag_meanings, flag_values, flag_masks attributes "
1089+ )
1090+
1091+ flag_params = tuple (
1092+ FlagParam (mask , value ) for mask , value in zip (flag_masks , flag_values )
1093+ )
1094+ return dict (zip (flag_meanings , flag_params ))
10721095
10731096
10741097class CFAccessor :
@@ -1084,36 +1107,40 @@ def __setstate__(self, d):
10841107 self .__dict__ = d
10851108
10861109 def _assert_valid_other_comparison (self , other ):
1110+ # TODO cache this property
10871111 flag_dict = create_flag_dict (self ._obj )
10881112 if other not in flag_dict :
10891113 raise ValueError (
10901114 f"Did not find flag value meaning [{ other } ] in known flag meanings: [{ flag_dict .keys ()!r} ]"
10911115 )
1116+ if flag_dict [other ].flag_mask is not None :
1117+ raise NotImplementedError (
1118+ "Only equals and not-equals comparisons with flag masks are supported."
1119+ " Please open an issue."
1120+ )
10921121 return flag_dict
10931122
1094- def __eq__ (self , other ):
1123+ def __eq__ (self , other ) -> DataArray : # type: ignore
10951124 """
10961125 Compare flag values against `other`.
10971126
10981127 `other` must be in the 'flag_meanings' attribute.
10991128 `other` is mapped to the corresponding value in the 'flag_values' attribute, and then
11001129 compared.
11011130 """
1102- flag_dict = self ._assert_valid_other_comparison (other )
1103- return self ._obj == flag_dict [other ]
1131+ return self ._extract_flags ([other ])[other ].rename (self ._obj .name )
11041132
1105- def __ne__ (self , other ):
1133+ def __ne__ (self , other ) -> DataArray : # type: ignore
11061134 """
11071135 Compare flag values against `other`.
11081136
11091137 `other` must be in the 'flag_meanings' attribute.
11101138 `other` is mapped to the corresponding value in the 'flag_values' attribute, and then
11111139 compared.
11121140 """
1113- flag_dict = self ._assert_valid_other_comparison (other )
1114- return self ._obj != flag_dict [other ]
1141+ return ~ self ._extract_flags ([other ])[other ].rename (self ._obj .name )
11151142
1116- def __lt__ (self , other ):
1143+ def __lt__ (self , other ) -> DataArray :
11171144 """
11181145 Compare flag values against `other`.
11191146
@@ -1122,9 +1149,9 @@ def __lt__(self, other):
11221149 compared.
11231150 """
11241151 flag_dict = self ._assert_valid_other_comparison (other )
1125- return self ._obj < flag_dict [other ]
1152+ return self ._obj < flag_dict [other ]. flag_value
11261153
1127- def __le__ (self , other ):
1154+ def __le__ (self , other ) -> DataArray :
11281155 """
11291156 Compare flag values against `other`.
11301157
@@ -1133,9 +1160,9 @@ def __le__(self, other):
11331160 compared.
11341161 """
11351162 flag_dict = self ._assert_valid_other_comparison (other )
1136- return self ._obj <= flag_dict [other ]
1163+ return self ._obj <= flag_dict [other ]. flag_value
11371164
1138- def __gt__ (self , other ):
1165+ def __gt__ (self , other ) -> DataArray :
11391166 """
11401167 Compare flag values against `other`.
11411168
@@ -1144,9 +1171,9 @@ def __gt__(self, other):
11441171 compared.
11451172 """
11461173 flag_dict = self ._assert_valid_other_comparison (other )
1147- return self ._obj > flag_dict [other ]
1174+ return self ._obj > flag_dict [other ]. flag_value
11481175
1149- def __ge__ (self , other ):
1176+ def __ge__ (self , other ) -> DataArray :
11501177 """
11511178 Compare flag values against `other`.
11521179
@@ -1155,9 +1182,9 @@ def __ge__(self, other):
11551182 compared.
11561183 """
11571184 flag_dict = self ._assert_valid_other_comparison (other )
1158- return self ._obj >= flag_dict [other ]
1185+ return self ._obj >= flag_dict [other ]. flag_value
11591186
1160- def isin (self , test_elements ):
1187+ def isin (self , test_elements ) -> DataArray :
11611188 """Test each value in the array for whether it is in test_elements.
11621189
11631190 Parameters
@@ -1177,14 +1204,15 @@ def isin(self, test_elements):
11771204 raise ValueError (
11781205 ".cf.isin is only supported on DataArrays that contain CF flag attributes."
11791206 )
1207+ # TODO cache this property
11801208 flag_dict = create_flag_dict (self ._obj )
11811209 mapped_test_elements = []
11821210 for elem in test_elements :
11831211 if elem not in flag_dict :
11841212 raise ValueError (
11851213 f"Did not find flag value meaning [{ elem } ] in known flag meanings: [{ flag_dict .keys ()!r} ]"
11861214 )
1187- mapped_test_elements .append (flag_dict [elem ])
1215+ mapped_test_elements .append (flag_dict [elem ]. flag_value )
11881216 return self ._obj .isin (mapped_test_elements )
11891217
11901218 def _drop_missing_variables (self , variables : list [Hashable ]) -> list [Hashable ]:
@@ -2753,22 +2781,104 @@ def __getitem__(self, key: Hashable | Iterable[Hashable]) -> DataArray:
27532781
27542782 return _getitem (self , key )
27552783
2784+ @property
2785+ def flags (self ) -> Dataset :
2786+ """
2787+ Dataset containing boolean masks of available flags.
2788+ """
2789+ return self ._extract_flags ()
2790+
2791+ def _extract_flags (self , flags : Sequence [Hashable ] | None = None ) -> Dataset :
2792+ """
2793+ Return dataset of boolean mask(s) corresponding to `flags`.
2794+
2795+ Parameters
2796+ ----------
2797+ flags: Sequence[str]
2798+ Flags to extract. If empty (string or list), return all flags in
2799+ `flag_meanings`.
2800+ """
2801+ # TODO cache this property
2802+ flag_dict = create_flag_dict (self ._obj )
2803+
2804+ if flags is None :
2805+ flags = tuple (flag_dict .keys ())
2806+
2807+ out = {} # Output arrays
2808+
2809+ masks = [] # Bitmasks and values for asked flags
2810+ values = []
2811+ flags_reduced = [] # Flags left after removing mutually excl. flags
2812+ for flag in flags :
2813+ if flag not in flag_dict :
2814+ raise ValueError (
2815+ f"Did not find flag value meaning [{ flag } ] in known flag meanings:"
2816+ f" [{ flag_dict .keys ()!r} ]"
2817+ )
2818+ mask , value = flag_dict [flag ]
2819+ if mask is None :
2820+ out [flag ] = self ._obj == value
2821+ else :
2822+ masks .append (mask )
2823+ values .append (value )
2824+ flags_reduced .append (flag )
2825+
2826+ if len (masks ) > 0 : # If independant masks are left
2827+ # We cast both masks and flag variable as integers to make the
2828+ # bitwise comparison. We could probably restrict the integer size
2829+ # but it's difficult to make it safely for mixed type flags.
2830+ bit_mask = DataArray (masks , dims = ["_mask" ]).astype ("i" )
2831+ x = self ._obj .astype ("i" )
2832+ bit_comp = x & bit_mask
2833+
2834+ for i , (flag , value ) in enumerate (zip (flags_reduced , values )):
2835+ bit = bit_comp .isel (_mask = i )
2836+ if value is not None :
2837+ out [flag ] = bit == value
2838+ else :
2839+ out [flag ] = bit .astype (bool )
2840+
2841+ return Dataset (out )
2842+
2843+ def isin (self , test_elements ):
2844+ """
2845+ Test each value in the array for whether it is in test_elements.
2846+
2847+ Parameters
2848+ ----------
2849+ test_elements : array_like, 1D
2850+ The values against which to test each value of `element`.
2851+
2852+ Returns
2853+ -------
2854+ isin : DataArray
2855+ Has the same type and shape as this object, but with a bool dtype.
2856+ """
2857+ flags_masks = self .flags .drop_vars (
2858+ [v for v in self .flags .data_vars if v not in test_elements ]
2859+ )
2860+ if len (flags_masks ) == 0 :
2861+ out = self .copy ().astype (bool )
2862+ out .attrs = {}
2863+ out [:] = False
2864+ return out
2865+ # Merge into a single DataArray
2866+ flags_masks = xr .concat (flags_masks .data_vars .values (), dim = "_flags" )
2867+ return flags_masks .any (dim = "_flags" ).rename (self ._obj .name )
2868+
27562869 @property
27572870 def is_flag_variable (self ) -> bool :
27582871 """
27592872 Returns True if the DataArray satisfies CF conventions for flag variables.
27602873
2761- .. warning::
2762- Flag masks are not supported yet.
2763-
27642874 Returns
27652875 -------
27662876 bool
27672877 """
27682878 if (
27692879 isinstance (self ._obj , DataArray )
27702880 and "flag_meanings" in self ._obj .attrs
2771- and "flag_values" in self ._obj .attrs
2881+ and ( "flag_values" in self ._obj .attrs or "flag_masks" in self . _obj . attrs )
27722882 ):
27732883 return True
27742884 else :
0 commit comments