@@ -166,6 +166,11 @@ public static Tensor boolean_mask<T1, T2>(T1 tensor, T2 mask, string name = "boo
166166 throw new ValueError ( "mask cannot be scalar." ) ;
167167
168168 var leading_size = gen_math_ops . prod ( shape ( tensor_tensor ) [ $ "{ axis } :{ axis + ndims_mask } "] , ops . convert_to_tensor ( new [ ] { 0 } ) ) ;
169+ if ( leading_size . rank == 0 )
170+ {
171+ leading_size = expand_dims ( leading_size , 0 ) ;
172+ }
173+
169174 var shape1 = concat ( new [ ]
170175 {
171176 shape ( tensor_tensor ) [ $ ":{ axis } "] ,
@@ -185,7 +190,7 @@ public static Tensor boolean_mask<T1, T2>(T1 tensor, T2 mask, string name = "boo
185190
186191 private static Tensor _apply_mask_1d ( Tensor reshaped_tensor , Tensor mask , int axis = 0 )
187192 {
188- var indices = squeeze ( where ( mask ) , axis : new [ ] { 1 } ) ;
193+ var indices = squeeze ( where_v2 ( mask ) , axis : new [ ] { 1 } ) ;
189194 return gather ( reshaped_tensor , indices , axis : ops . convert_to_tensor ( axis ) ) ;
190195 }
191196
@@ -940,12 +945,12 @@ public static Tensor broadcast_static_shape(Tensor shape_x, Tensor shape_y)
940945 /// <returns></returns>
941946 public static Tensor concat ( Tensor [ ] values , Tensor axis , string name = "concat" )
942947 {
943- return tf . Context . ExecuteOp ( "ConcatV2" , name , new ExecuteOpArgs ( values , axis ) ) ;
948+ return gen_array_ops . concat_v2 ( values , axis , name : name ) ;
944949 }
945950
946- public static Tensor concat ( object [ ] values , int axis , string name = "concat" )
951+ public static Tensor concat ( Tensor [ ] values , Axis axis , string name = "concat" )
947952 {
948- return tf . Context . ExecuteOp ( "ConcatV2" , name , new ExecuteOpArgs ( values , axis ) ) ;
953+ return gen_array_ops . concat_v2 ( values , axis , name : name ) ;
949954 }
950955
951956 /// <summary>
0 commit comments