Skip to content

Commit faf60e4

Browse files
committed
add string static filters
1 parent 603bfb4 commit faf60e4

1 file changed

Lines changed: 203 additions & 16 deletions

File tree

  • datafusion/physical-expr/src/expressions

datafusion/physical-expr/src/expressions/in_list.rs

Lines changed: 203 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -177,13 +177,161 @@ fn instantiate_static_filter(
177177
// Float primitive types (use ordered wrappers for Hash/Eq)
178178
DataType::Float32 => Ok(Arc::new(Float32StaticFilter::try_new(&in_array)?)),
179179
DataType::Float64 => Ok(Arc::new(Float64StaticFilter::try_new(&in_array)?)),
180+
DataType::Utf8 => Ok(Arc::new(Utf8StaticFilter::try_new(&in_array)?)),
181+
DataType::LargeUtf8 => Ok(Arc::new(LargeUtf8StaticFilter::try_new(&in_array)?)),
182+
DataType::Utf8View => Ok(Arc::new(Utf8ViewStaticFilter::try_new(&in_array)?)),
180183
_ => {
181184
/* fall through to generic implementation for unsupported types (Struct, etc.) */
182185
Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?))
183186
}
184187
}
185188
}
186189

190+
macro_rules! string_static_filter {
191+
($Name:ident, $ArrayType:ty, $TryDowncast:ident $(::<$offset:ty>)?) => {
192+
struct $Name {
193+
null_count: usize,
194+
in_array: ArrayRef,
195+
state: RandomState,
196+
map: HashMap<usize, (), ()>,
197+
}
198+
199+
impl $Name {
200+
fn try_new(in_array: &ArrayRef) -> Result<Self> {
201+
let array = Arc::clone(in_array);
202+
let in_array = array.$TryDowncast$(::<$offset>)?().ok_or_else(|| {
203+
exec_datafusion_err!(
204+
"Failed to downcast an array to a '{}' array",
205+
stringify!($ArrayType)
206+
)
207+
})?;
208+
209+
let null_count = in_array.null_count();
210+
let state = RandomState::default();
211+
let mut map: HashMap<usize, (), ()> =
212+
HashMap::with_capacity_and_hasher(in_array.len() - null_count, ());
213+
214+
with_hashes([&array], &state, |hashes| -> Result<()> {
215+
let insert_value = |idx: usize| {
216+
let hash = hashes[idx];
217+
if let RawEntryMut::Vacant(v) = map
218+
.raw_entry_mut()
219+
.from_hash(hash, |x| in_array.value(*x) == in_array.value(idx))
220+
{
221+
v.insert_with_hasher(hash, idx, (), |x| hashes[*x]);
222+
}
223+
};
224+
225+
match in_array.nulls() {
226+
Some(nulls) => BitIndexIterator::new(
227+
nulls.validity(),
228+
nulls.offset(),
229+
nulls.len(),
230+
)
231+
.for_each(insert_value),
232+
None => (0..in_array.len()).for_each(insert_value),
233+
}
234+
235+
Ok(())
236+
})?;
237+
238+
Ok(Self {
239+
null_count,
240+
in_array: array,
241+
state,
242+
map,
243+
})
244+
}
245+
}
246+
247+
impl StaticFilter for $Name {
248+
fn null_count(&self) -> usize {
249+
self.null_count
250+
}
251+
252+
fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray> {
253+
downcast_dictionary_array! {
254+
v => {
255+
let values_contains = self.contains(v.values().as_ref(), negated)?;
256+
let result = take(&values_contains, v.keys(), None)?;
257+
return Ok(downcast_array(result.as_ref()))
258+
}
259+
_ => {}
260+
}
261+
262+
let needle = v.$TryDowncast$(::<$offset>)?().ok_or_else(|| {
263+
exec_datafusion_err!(
264+
"Failed to downcast an array to a '{}' array",
265+
stringify!($ArrayType)
266+
)
267+
})?;
268+
let haystack = self
269+
.in_array
270+
.$TryDowncast$(::<$offset>)?()
271+
.ok_or_else(|| {
272+
exec_datafusion_err!(
273+
"Failed to downcast an array to a '{}' array",
274+
stringify!($ArrayType)
275+
)
276+
})?;
277+
278+
let haystack_has_nulls = self.null_count > 0;
279+
let needle_nulls = needle.nulls();
280+
let needle_has_nulls = needle.null_count() > 0;
281+
282+
let contains_buffer =
283+
with_hashes([v as &dyn Array], &self.state, |hashes| {
284+
Ok(BooleanBuffer::collect_bool(needle.len(), |i| {
285+
let contains = self
286+
.map
287+
.raw_entry()
288+
.from_hash(hashes[i], |idx| {
289+
needle.value(i) == haystack.value(*idx)
290+
})
291+
.is_some();
292+
contains ^ negated
293+
}))
294+
})?;
295+
296+
let result_nulls = match (needle_has_nulls, haystack_has_nulls) {
297+
(false, false) => None,
298+
(true, false) => needle_nulls.cloned(),
299+
(false, true) => {
300+
let validity = if negated {
301+
!&contains_buffer
302+
} else {
303+
contains_buffer.clone()
304+
};
305+
Some(NullBuffer::new(validity))
306+
}
307+
(true, true) => {
308+
let needle_validity = needle_nulls
309+
.map(|n| n.inner().clone())
310+
.unwrap_or_else(|| BooleanBuffer::new_set(v.len()));
311+
let haystack_validity = if negated {
312+
!&contains_buffer
313+
} else {
314+
contains_buffer.clone()
315+
};
316+
let combined_validity = &needle_validity & &haystack_validity;
317+
Some(NullBuffer::new(combined_validity))
318+
}
319+
};
320+
321+
Ok(BooleanArray::new(contains_buffer, result_nulls))
322+
}
323+
}
324+
};
325+
}
326+
327+
string_static_filter!(Utf8StaticFilter, StringArray, as_string_opt::<i32>);
328+
string_static_filter!(
329+
LargeUtf8StaticFilter,
330+
LargeStringArray,
331+
as_string_opt::<i64>
332+
);
333+
string_static_filter!(Utf8ViewStaticFilter, StringViewArray, as_string_view_opt);
334+
187335
impl ArrayStaticFilter {
188336
/// Computes a [`StaticFilter`] for the provided [`Array`] if there
189337
/// are nulls present or there are more than the configured number of
@@ -3962,24 +4110,40 @@ mod tests {
39624110
);
39634111
}
39644112

3965-
// Utf8 (falls through to ArrayStaticFilter)
3966-
let utf8_in = Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef;
3967-
let utf8_needle = Arc::new(StringArray::from(vec!["a", "d", "b"])) as ArrayRef;
4113+
let string_cases = vec![
4114+
(
4115+
"Utf8",
4116+
Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef,
4117+
Arc::new(StringArray::from(vec!["a", "d", "b"])) as ArrayRef,
4118+
),
4119+
(
4120+
"LargeUtf8",
4121+
Arc::new(LargeStringArray::from(vec!["a", "b", "c"])) as ArrayRef,
4122+
Arc::new(LargeStringArray::from(vec!["a", "d", "b"])) as ArrayRef,
4123+
),
4124+
(
4125+
"Utf8View",
4126+
Arc::new(StringViewArray::from(vec!["a", "b", "c"])) as ArrayRef,
4127+
Arc::new(StringViewArray::from(vec!["a", "d", "b"])) as ArrayRef,
4128+
),
4129+
];
39684130

3969-
// Utf8 in_array, Utf8 needle
3970-
assert_eq!(
3971-
expected,
3972-
eval_in_list_from_array(Arc::clone(&utf8_needle), Arc::clone(&utf8_in),)?
3973-
);
4131+
for (name, in_array, needle) in string_cases {
4132+
assert_eq!(
4133+
expected,
4134+
eval_in_list_from_array(Arc::clone(&needle), Arc::clone(&in_array),)?,
4135+
"same-type failed for {name}"
4136+
);
39744137

3975-
// Utf8 in_array, Dict(Utf8) needle
3976-
assert_eq!(
3977-
expected,
3978-
eval_in_list_from_array(
3979-
wrap_in_dict(Arc::clone(&utf8_needle)),
3980-
Arc::clone(&utf8_in),
3981-
)?
3982-
);
4138+
assert_eq!(
4139+
expected,
4140+
eval_in_list_from_array(wrap_in_dict(needle), in_array)?,
4141+
"dict-needle failed for {name}"
4142+
);
4143+
}
4144+
4145+
let utf8_in = Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef;
4146+
let utf8_needle = Arc::new(StringArray::from(vec!["a", "d", "b"])) as ArrayRef;
39834147

39844148
// Dict(Utf8) in_array, Dict(Utf8) needle: the #20937 bug
39854149
assert_eq!(
@@ -4084,4 +4248,27 @@ mod tests {
40844248

40854249
Ok(())
40864250
}
4251+
4252+
#[test]
4253+
fn test_utf8_static_filter_avoids_string_copies() -> Result<()> {
4254+
let in_array = Arc::new(StringArray::from(vec![
4255+
Some("alpha"),
4256+
Some("beta"),
4257+
Some("alpha"),
4258+
None,
4259+
])) as ArrayRef;
4260+
4261+
let filter = Utf8StaticFilter::try_new(&in_array)?;
4262+
4263+
assert_eq!(filter.null_count(), 1);
4264+
assert_eq!(filter.map.len(), 2);
4265+
4266+
let needle = Arc::new(StringArray::from(vec![Some("alpha"), Some("gamma"), None]))
4267+
as ArrayRef;
4268+
4269+
let result = filter.contains(needle.as_ref(), false)?;
4270+
assert_eq!(result, BooleanArray::from(vec![Some(true), None, None]));
4271+
4272+
Ok(())
4273+
}
40874274
}

0 commit comments

Comments
 (0)