Skip to content

Commit cc14c44

Browse files
committed
Updated to REF
1 parent fb10070 commit cc14c44

3 files changed

Lines changed: 375 additions & 25 deletions

File tree

config/train.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ train:
2626

2727
checkpoint: True # Enables checkpoint saving of model
2828
checkpoint_per_epoch: 1 # Save checkpoint per x epochs
29-
silent: False # Turn off progress tracking per iteration
29+
silent: False # Turn on progress tracking per iteration
30+
verbose: False # Turn on verbose progress tracking
3031
continue: False # Continue training with a pre-trained model
3132
finetune: False # Finetune a pre-trained model
3233

loader.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -163,25 +163,18 @@ def _collate_fn(batch):
163163
minibatch_size = len(batch)
164164
inputs = torch.zeros(minibatch_size, 1, freq_size, max_seqlength)
165165
input_percentages = torch.FloatTensor(minibatch_size)
166-
target_sizes = torch.IntTensor(minibatch_size)
166+
target_sizes = np.zeros(minibatch_size, dtype=np.int32)
167167

168168
# TODO: Numpy broadcasting magic
169169
targets = []
170-
for x in range(minibatch_size):
171-
sample = batch[x]
172-
tensor = sample[0]
173-
target = sample[1]
174-
seq_length = tensor.size(1)
175-
176-
inputs[x][0].narrow(1, 0, seq_length).copy_(tensor)
177170

178-
input_percentages[x] = seq_length / float(max_seqlength)
179-
target_sizes[x] = len(target)
180-
targets.extend(target)
181-
targets = torch.IntTensor(targets)
182-
# TODO: Numpy broadcasting magic
171+
for x in range(minibatch_size):
172+
inputs[x][0].narrow(1, 0, batch[x][0].size(1)).copy_(batch[x][0])
173+
input_percentages[x] = batch[x][0].size(1) / float(max_seqlength)
174+
target_sizes[x] = len(batch[x][1])
175+
targets.extend(batch[x][1])
183176

184-
return inputs, targets, input_percentages, target_sizes
177+
return inputs, torch.IntTensor(targets), input_percentages, torch.from_numpy(target_sizes)
185178

186179

187180
class AudioDataLoader(DataLoader):

0 commit comments

Comments (0)