@@ -163,25 +163,18 @@ def _collate_fn(batch):
163163 minibatch_size = len (batch )
164164 inputs = torch .zeros (minibatch_size , 1 , freq_size , max_seqlength )
165165 input_percentages = torch .FloatTensor (minibatch_size )
166- target_sizes = torch . IntTensor (minibatch_size )
166+ target_sizes = np . zeros (minibatch_size , dtype = np . int32 )
167167
168168 # TODO: Numpy broadcasting magic
169169 targets = []
170- for x in range (minibatch_size ):
171- sample = batch [x ]
172- tensor = sample [0 ]
173- target = sample [1 ]
174- seq_length = tensor .size (1 )
175-
176- inputs [x ][0 ].narrow (1 , 0 , seq_length ).copy_ (tensor )
177170
178- input_percentages [ x ] = seq_length / float ( max_seqlength )
179- target_sizes [x ] = len ( target )
180- targets . extend ( target )
181- targets = torch . IntTensor ( targets )
182- # TODO: Numpy broadcasting magic
171+ for x in range ( minibatch_size ):
172+ inputs [x ][ 0 ]. narrow ( 1 , 0 , batch [ x ][ 0 ]. size ( 1 )). copy_ ( batch [ x ][ 0 ] )
173+ input_percentages [ x ] = batch [ x ][ 0 ]. size ( 1 ) / float ( max_seqlength )
174+ target_sizes [ x ] = len ( batch [ x ][ 1 ] )
175+ targets . extend ( batch [ x ][ 1 ])
183176
184- return inputs , targets , input_percentages , target_sizes
177+ return inputs , torch . IntTensor ( targets ) , input_percentages , torch . from_numpy ( target_sizes )
185178
186179
187180class AudioDataLoader (DataLoader ):
0 commit comments