Skip to content

Commit 96b8f7e

Browse files
committed
Updated configuration
1 parent ac24d1d commit 96b8f7e

9 files changed

Lines changed: 111 additions & 466 deletions

File tree

config/infer.yaml

Whitespace-only changes.

config/test.yaml

Whitespace-only changes.

config/train.yaml

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Training configuration for the DeepSpeech2 ASR model.
# Loaded by train code as config["train"] and expanded into ModelWrapper(**config_dict).
train:
  train-manifest: 'examples/manifests/train_manifest.csv'
  val-manifest: 'examples/manifests/val_manifest.csv'
  labels-path: 'examples/labels.json'   # Contains all characters for transcription
  log-dir: 'logs'                       # Location for log files
  def-dir: 'examples/checkpoints/'      # Default location to save/load models (stray trailing comma removed — invalid YAML)
  model-name: 'deepspeech_final.pth'    # File name to save the best model
  load-from: 'deepspeech_final.pth'     # File name containing a checkpoint to continue/finetune

  sample-rate: 16000        # Sample rate
  window-size: 0.02         # Window size for spectrogram in seconds
  window-stride: 0.01       # Window stride for spectrogram in seconds
  window: 'hamming'         # Window type for spectrogram generation

  batch-size: 32            # Batch size for training
  hidden-size: 800          # Hidden size of RNNs
  hidden-layers: 5          # Number of RNN layers
  rnn-type: 'gru'           # Type of the RNN unit: gru|lstm are supported

  max-epochs: 70            # Number of training epochs
  learning-rate: 3.0e-4     # Initial learning rate (mantissa dot required: PyYAML's YAML-1.1 resolver parses bare '3e-4' as a string)
  momentum: 0.9             # Momentum
  max-norm: 800             # Norm cutoff to prevent explosion of gradients
  learning-anneal: 1.1      # Annealing applied to learning rate every epoch (was '1.1n', a typo that parsed as a string)
  sortaGrad: True           # Turn on ordering of dataset on sequence length for the first epoch

  checkpoint: True          # Enables checkpoint saving of model
  checkpoint-per-epoch: 1   # Save checkpoint per x epochs
  silent: False             # Turn off progress tracking per iteration
  continue: False           # Continue training with a pre-trained model
  finetune: False           # Finetune a pre-trained model

  num-data-workers: 8       # Number of workers used in data-loading
  augment: False            # Use random tempo and gain perturbations
  shuffle: True             # Turn on shuffling and sample from dataset based on sequence length (smallest to largest)

  seed: 123456              # Seed to generators
  cuda: True                # Use cuda to train model
  half-precision: True      # Uses half precision to train a model (was 'Trues', a typo that parsed as the string "Trues")
  apex: True                # Uses mixed precision to train a model
  static-loss-scaling: False   # Static loss scale for mixed precision
  dynamic-loss-scaling: True   # Use dynamic loss scaling for mixed precision

  dist-url: 'tcp://127.0.0.1:1550'  # URL used to set up distributed training
  dist-backend: 'nccl'              # Distributed backend
  world-size: 1                     # Number of distributed processes
  rank: 0                           # The rank of the current process
  gpu-rank: 0                       # If using distributed parallel for multi-gpu, sets the GPU for the process

infer.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
"""Command-line entry point for ASR inference on a single wave file.

Reads the inference section of a YAML config, builds a ModelWrapper from it,
and prints the transcription of the wave file named by the config's
``wave_path`` key (presumably a path string — confirm against config schema).
"""
import argparse
import os
import wave
from typing import Dict

import yaml

from modelwrapper import ModelWrapper

parser = argparse.ArgumentParser(description='ASR inference')
parser.add_argument('--config', metavar='DIR',
                    help='Path to inference config file', default='config/infer.yaml')

if __name__ == '__main__':
    args = parser.parse_args()
    with open(args.config, 'r') as file:
        # safe_load: yaml.load without an explicit Loader is deprecated and can
        # execute arbitrary code on a crafted config file.
        config = yaml.safe_load(file)
    config_dict: Dict = config["infer"]
    model = ModelWrapper(**config_dict)
    # Hoist the lookup: the original called .get("wave_path") three times.
    wave_path = config_dict.get("wave_path")
    if wave_path is not None and os.path.isfile(wave_path):
        # Context manager closes the wave file; the original leaked the handle.
        with wave.open(wave_path) as sound:
            print(model.infer(sound))
    else:
        print("Wave file not found!")

loader.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
# ----------------------------------------------------------------------------
2+
# Based on SeanNaren's deepspeech.pytorch:
3+
# https://github.com/SeanNaren/deepspeech.pytorch
4+
# ----------------------------------------------------------------------------
5+
16
import math
27
import warnings
38
from typing import Tuple

models/deepspeech2.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
# ----------------------------------------------------------------------------
2+
# Based on SeanNaren's deepspeech.pytorch:
3+
# https://github.com/SeanNaren/deepspeech.pytorch
4+
# ----------------------------------------------------------------------------
5+
16
import math
27
from collections import OrderedDict
38

test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
"""Command-line entry point for evaluating an ASR model.

Reads the test section of a YAML config, builds a ModelWrapper from it,
and runs the wrapper's test routine.
"""
import argparse
from typing import Dict

import yaml

from modelwrapper import ModelWrapper

parser = argparse.ArgumentParser(description='ASR testing')
parser.add_argument('--config', metavar='DIR',
                    help='Path to test config file', default='config/test.yaml')

if __name__ == '__main__':
    args = parser.parse_args()
    with open(args.config, 'r') as file:
        # safe_load: yaml.load without an explicit Loader is deprecated and can
        # execute arbitrary code on a crafted config file.
        config = yaml.safe_load(file)
    config_dict: Dict = config["test"]
    model = ModelWrapper(**config_dict)
    model.test()

0 commit comments

Comments
 (0)