@@ -177,13 +177,161 @@ fn instantiate_static_filter(
177177 // Float primitive types (use ordered wrappers for Hash/Eq)
178178 DataType :: Float32 => Ok ( Arc :: new ( Float32StaticFilter :: try_new ( & in_array) ?) ) ,
179179 DataType :: Float64 => Ok ( Arc :: new ( Float64StaticFilter :: try_new ( & in_array) ?) ) ,
180+ DataType :: Utf8 => Ok ( Arc :: new ( Utf8StaticFilter :: try_new ( & in_array) ?) ) ,
181+ DataType :: LargeUtf8 => Ok ( Arc :: new ( LargeUtf8StaticFilter :: try_new ( & in_array) ?) ) ,
182+ DataType :: Utf8View => Ok ( Arc :: new ( Utf8ViewStaticFilter :: try_new ( & in_array) ?) ) ,
180183 _ => {
181184 /* fall through to generic implementation for unsupported types (Struct, etc.) */
182185 Ok ( Arc :: new ( ArrayStaticFilter :: try_new ( in_array) ?) )
183186 }
184187 }
185188}
186189
190+ macro_rules! string_static_filter {
191+ ( $Name: ident, $ArrayType: ty, $TryDowncast: ident $( :: <$offset: ty>) ?) => {
192+ struct $Name {
193+ null_count: usize ,
194+ in_array: ArrayRef ,
195+ state: RandomState ,
196+ map: HashMap <usize , ( ) , ( ) >,
197+ }
198+
199+ impl $Name {
200+ fn try_new( in_array: & ArrayRef ) -> Result <Self > {
201+ let array = Arc :: clone( in_array) ;
202+ let in_array = array. $TryDowncast$( :: <$offset>) ?( ) . ok_or_else( || {
203+ exec_datafusion_err!(
204+ "Failed to downcast an array to a '{}' array" ,
205+ stringify!( $ArrayType)
206+ )
207+ } ) ?;
208+
209+ let null_count = in_array. null_count( ) ;
210+ let state = RandomState :: default ( ) ;
211+ let mut map: HashMap <usize , ( ) , ( ) > =
212+ HashMap :: with_capacity_and_hasher( in_array. len( ) - null_count, ( ) ) ;
213+
214+ with_hashes( [ & array] , & state, |hashes| -> Result <( ) > {
215+ let insert_value = |idx: usize | {
216+ let hash = hashes[ idx] ;
217+ if let RawEntryMut :: Vacant ( v) = map
218+ . raw_entry_mut( )
219+ . from_hash( hash, |x| in_array. value( * x) == in_array. value( idx) )
220+ {
221+ v. insert_with_hasher( hash, idx, ( ) , |x| hashes[ * x] ) ;
222+ }
223+ } ;
224+
225+ match in_array. nulls( ) {
226+ Some ( nulls) => BitIndexIterator :: new(
227+ nulls. validity( ) ,
228+ nulls. offset( ) ,
229+ nulls. len( ) ,
230+ )
231+ . for_each( insert_value) ,
232+ None => ( 0 ..in_array. len( ) ) . for_each( insert_value) ,
233+ }
234+
235+ Ok ( ( ) )
236+ } ) ?;
237+
238+ Ok ( Self {
239+ null_count,
240+ in_array: array,
241+ state,
242+ map,
243+ } )
244+ }
245+ }
246+
247+ impl StaticFilter for $Name {
248+ fn null_count( & self ) -> usize {
249+ self . null_count
250+ }
251+
252+ fn contains( & self , v: & dyn Array , negated: bool ) -> Result <BooleanArray > {
253+ downcast_dictionary_array! {
254+ v => {
255+ let values_contains = self . contains( v. values( ) . as_ref( ) , negated) ?;
256+ let result = take( & values_contains, v. keys( ) , None ) ?;
257+ return Ok ( downcast_array( result. as_ref( ) ) )
258+ }
259+ _ => { }
260+ }
261+
262+ let needle = v. $TryDowncast$( :: <$offset>) ?( ) . ok_or_else( || {
263+ exec_datafusion_err!(
264+ "Failed to downcast an array to a '{}' array" ,
265+ stringify!( $ArrayType)
266+ )
267+ } ) ?;
268+ let haystack = self
269+ . in_array
270+ . $TryDowncast$( :: <$offset>) ?( )
271+ . ok_or_else( || {
272+ exec_datafusion_err!(
273+ "Failed to downcast an array to a '{}' array" ,
274+ stringify!( $ArrayType)
275+ )
276+ } ) ?;
277+
278+ let haystack_has_nulls = self . null_count > 0 ;
279+ let needle_nulls = needle. nulls( ) ;
280+ let needle_has_nulls = needle. null_count( ) > 0 ;
281+
282+ let contains_buffer =
283+ with_hashes( [ v as & dyn Array ] , & self . state, |hashes| {
284+ Ok ( BooleanBuffer :: collect_bool( needle. len( ) , |i| {
285+ let contains = self
286+ . map
287+ . raw_entry( )
288+ . from_hash( hashes[ i] , |idx| {
289+ needle. value( i) == haystack. value( * idx)
290+ } )
291+ . is_some( ) ;
292+ contains ^ negated
293+ } ) )
294+ } ) ?;
295+
296+ let result_nulls = match ( needle_has_nulls, haystack_has_nulls) {
297+ ( false , false ) => None ,
298+ ( true , false ) => needle_nulls. cloned( ) ,
299+ ( false , true ) => {
300+ let validity = if negated {
301+ !& contains_buffer
302+ } else {
303+ contains_buffer. clone( )
304+ } ;
305+ Some ( NullBuffer :: new( validity) )
306+ }
307+ ( true , true ) => {
308+ let needle_validity = needle_nulls
309+ . map( |n| n. inner( ) . clone( ) )
310+ . unwrap_or_else( || BooleanBuffer :: new_set( v. len( ) ) ) ;
311+ let haystack_validity = if negated {
312+ !& contains_buffer
313+ } else {
314+ contains_buffer. clone( )
315+ } ;
316+ let combined_validity = & needle_validity & & haystack_validity;
317+ Some ( NullBuffer :: new( combined_validity) )
318+ }
319+ } ;
320+
321+ Ok ( BooleanArray :: new( contains_buffer, result_nulls) )
322+ }
323+ }
324+ } ;
325+ }
326+
327+ string_static_filter ! ( Utf8StaticFilter , StringArray , as_string_opt:: <i32 >) ;
328+ string_static_filter ! (
329+ LargeUtf8StaticFilter ,
330+ LargeStringArray ,
331+ as_string_opt:: <i64 >
332+ ) ;
333+ string_static_filter ! ( Utf8ViewStaticFilter , StringViewArray , as_string_view_opt) ;
334+
187335impl ArrayStaticFilter {
188336 /// Computes a [`StaticFilter`] for the provided [`Array`] if there
189337 /// are nulls present or there are more than the configured number of
@@ -3962,24 +4110,40 @@ mod tests {
39624110 ) ;
39634111 }
39644112
3965- // Utf8 (falls through to ArrayStaticFilter)
3966- let utf8_in = Arc :: new ( StringArray :: from ( vec ! [ "a" , "b" , "c" ] ) ) as ArrayRef ;
3967- let utf8_needle = Arc :: new ( StringArray :: from ( vec ! [ "a" , "d" , "b" ] ) ) as ArrayRef ;
4113+ let string_cases = vec ! [
4114+ (
4115+ "Utf8" ,
4116+ Arc :: new( StringArray :: from( vec![ "a" , "b" , "c" ] ) ) as ArrayRef ,
4117+ Arc :: new( StringArray :: from( vec![ "a" , "d" , "b" ] ) ) as ArrayRef ,
4118+ ) ,
4119+ (
4120+ "LargeUtf8" ,
4121+ Arc :: new( LargeStringArray :: from( vec![ "a" , "b" , "c" ] ) ) as ArrayRef ,
4122+ Arc :: new( LargeStringArray :: from( vec![ "a" , "d" , "b" ] ) ) as ArrayRef ,
4123+ ) ,
4124+ (
4125+ "Utf8View" ,
4126+ Arc :: new( StringViewArray :: from( vec![ "a" , "b" , "c" ] ) ) as ArrayRef ,
4127+ Arc :: new( StringViewArray :: from( vec![ "a" , "d" , "b" ] ) ) as ArrayRef ,
4128+ ) ,
4129+ ] ;
39684130
3969- // Utf8 in_array, Utf8 needle
3970- assert_eq ! (
3971- expected,
3972- eval_in_list_from_array( Arc :: clone( & utf8_needle) , Arc :: clone( & utf8_in) , ) ?
3973- ) ;
4131+ for ( name, in_array, needle) in string_cases {
4132+ assert_eq ! (
4133+ expected,
4134+ eval_in_list_from_array( Arc :: clone( & needle) , Arc :: clone( & in_array) , ) ?,
4135+ "same-type failed for {name}"
4136+ ) ;
39744137
3975- // Utf8 in_array, Dict(Utf8) needle
3976- assert_eq ! (
3977- expected,
3978- eval_in_list_from_array(
3979- wrap_in_dict( Arc :: clone( & utf8_needle) ) ,
3980- Arc :: clone( & utf8_in) ,
3981- ) ?
3982- ) ;
4138+ assert_eq ! (
4139+ expected,
4140+ eval_in_list_from_array( wrap_in_dict( needle) , in_array) ?,
4141+ "dict-needle failed for {name}"
4142+ ) ;
4143+ }
4144+
4145+ let utf8_in = Arc :: new ( StringArray :: from ( vec ! [ "a" , "b" , "c" ] ) ) as ArrayRef ;
4146+ let utf8_needle = Arc :: new ( StringArray :: from ( vec ! [ "a" , "d" , "b" ] ) ) as ArrayRef ;
39834147
39844148 // Dict(Utf8) in_array, Dict(Utf8) needle: the #20937 bug
39854149 assert_eq ! (
@@ -4084,4 +4248,27 @@ mod tests {
40844248
40854249 Ok ( ( ) )
40864250 }
4251+
4252+ #[ test]
4253+ fn test_utf8_static_filter_avoids_string_copies ( ) -> Result < ( ) > {
4254+ let in_array = Arc :: new ( StringArray :: from ( vec ! [
4255+ Some ( "alpha" ) ,
4256+ Some ( "beta" ) ,
4257+ Some ( "alpha" ) ,
4258+ None ,
4259+ ] ) ) as ArrayRef ;
4260+
4261+ let filter = Utf8StaticFilter :: try_new ( & in_array) ?;
4262+
4263+ assert_eq ! ( filter. null_count( ) , 1 ) ;
4264+ assert_eq ! ( filter. map. len( ) , 2 ) ;
4265+
4266+ let needle = Arc :: new ( StringArray :: from ( vec ! [ Some ( "alpha" ) , Some ( "gamma" ) , None ] ) )
4267+ as ArrayRef ;
4268+
4269+ let result = filter. contains ( needle. as_ref ( ) , false ) ?;
4270+ assert_eq ! ( result, BooleanArray :: from( vec![ Some ( true ) , None , None ] ) ) ;
4271+
4272+ Ok ( ( ) )
4273+ }
40874274}
0 commit comments