@@ -14,10 +14,15 @@ You may obtain a copy of the License at
1414 limitations under the License.
1515******************************************************************************/
1616
17+ using Serilog . Debugging ;
1718using System ;
19+ using System . Collections . Concurrent ;
1820using System . Collections . Generic ;
21+ //using System.ComponentModel.DataAnnotations;
1922using System . Text ;
23+ using System . Xml . Linq ;
2024using Tensorflow . Framework ;
25+ using Tensorflow . NumPy ;
2126using static Tensorflow . Binding ;
2227
2328namespace Tensorflow
@@ -99,5 +104,55 @@ public static RowPartition from_row_splits(Tensor row_splits,
99104 return new RowPartition ( row_splits ) ;
100105 } ) ;
101106 }
107+
108+ public static RowPartition from_row_lengths ( Tensor row_lengths ,
109+ bool validate = true ,
110+ TF_DataType dtype = TF_DataType . TF_INT32 ,
111+ TF_DataType dtype_hint = TF_DataType . TF_INT32 )
112+ {
113+ row_lengths = _convert_row_partition (
114+ row_lengths , "row_lengths" , dtype_hint : dtype_hint , dtype : dtype ) ;
115+ Tensor row_limits = math_ops . cumsum < Tensor > ( row_lengths , tf . constant ( - 1 ) ) ;
116+ Tensor row_splits = array_ops . concat ( new Tensor [ ] { tf . convert_to_tensor ( np . array ( new int [ ] { 0 } , TF_DataType . TF_INT64 ) ) , row_limits } , axis : 0 ) ;
117+ return new RowPartition ( row_splits : row_splits , row_lengths : row_lengths ) ;
118+ }
119+
120+ public static Tensor _convert_row_partition ( Tensor partition , string name , TF_DataType dtype ,
121+ TF_DataType dtype_hint = TF_DataType . TF_INT64 )
122+ {
123+ if ( partition is NDArray && partition . GetDataType ( ) == np . int32 ) partition = ops . convert_to_tensor ( partition , name : name ) ;
124+ if ( partition . GetDataType ( ) != np . int32 && partition . GetDataType ( ) != np . int64 ) throw new ValueError ( $ "{ name } must have dtype int32 or int64") ;
125+ return partition ;
126+ }
127+
128+ public Tensor nrows ( )
129+ {
130+ /*Returns the number of rows created by this `RowPartition*/
131+ if ( this . _nrows != null ) return this . _nrows ;
132+ var nsplits = tensor_shape . dimension_at_index ( this . _row_splits . shape , 0 ) ;
133+ if ( nsplits == null ) return array_ops . shape ( this . _row_splits , out_type : this . row_splits . dtype ) [ 0 ] - 1 ;
134+ else return constant_op . constant ( nsplits . value - 1 , dtype : this . row_splits . dtype ) ;
135+ }
136+
137+ public Tensor row_lengths ( )
138+ {
139+
140+ if ( this . _row_splits != null )
141+ {
142+ int nrows_plus_one = tensor_shape . dimension_value ( this . _row_splits . shape [ 0 ] ) ;
143+ return tf . constant ( nrows_plus_one - 1 ) ;
144+
145+ }
146+ if ( this . _row_lengths != null )
147+ {
148+ var nrows = tensor_shape . dimension_value ( this . _row_lengths . shape [ 0 ] ) ;
149+ return tf . constant ( nrows ) ;
150+ }
151+ if ( this . _nrows != null )
152+ {
153+ return tensor_util . constant_value ( this . _nrows ) ;
154+ }
155+ return tf . constant ( - 1 ) ;
156+ }
102157 }
103158}
0 commit comments