Trainer class #120

Open · wants to merge 2 commits into base: master
19 changes: 14 additions & 5 deletions howl/config.py
@@ -15,7 +15,7 @@ class AudioConfig(BaseModel):
     """Base config for loading audio file"""
 
     sample_rate: int = 16000
-    use_mono: bool = True
+    mono: bool = True
 
 
 class ContextConfig(BaseModel):
@@ -30,6 +30,8 @@ class ContextConfig(BaseModel):
     token_type: str = "word"
     # phone dictionary file path
     phone_dictionary_path: str = None
+    # if True, [BLANK] token will be added to vocab (used for CTC loss)
+    use_blank: bool = False
 
 
 class InferenceEngineConfig(BaseModel):
@@ -39,15 +41,19 @@ class InferenceEngineConfig(BaseModel):
     per_frame: bool = False
     # weighting on prediction (model output)
     inference_weights: List[float] = None
+    # window size for a single prediction
+    window_ms: int = 500
+    # stride size
+    stride_ms: int = 50
     # InferenceEngine says wake word is present
     # if a sequence of predictions from the last INFERENCE_WINDOW_MS audio data matches the target sequence
-    inference_window_ms: float = 2000
+    inference_window_ms: int = 2000
     # predictions are smoothed over SMOOTHING_WINDOW_MS before the final labels are computed
-    smoothing_window_ms: float = 50
+    smoothing_window_ms: int = 200
     # negative labels are ignored as long as they don't last for TOLERANCE_WINDOW_MS
-    tolerance_window_ms: float = 500
+    tolerance_window_ms: int = 500
     # prediction probability for positive labels must be above this threshold
-    inference_threshold: float = 0
+    inference_threshold: float = 0.5
 
 
 class AudioTransformConfig(BaseModel):
@@ -79,9 +85,11 @@ class TrainingConfig(BaseModel):
     batch_size: int = 16
     learning_rate: float = 0.01
     num_epochs: int = 10
+    eval_frequency: int = 5
     lr_decay: float = 0.955
     weight_decay: float = 0.00001
     use_noise_dataset: bool = False
+    objective: str = "frame"  # frame or ctc
     noise_datasets: List[DatasetConfig] = []
     train_datasets: List[DatasetConfig] = []
     val_datasets: List[DatasetConfig] = []
@@ -91,6 +99,7 @@ class TrainingConfig(BaseModel):
     model_config: ModelConfig = ModelConfig()
     context_config: ContextConfig = ContextConfig()
     workspace_path: str = None
+    device: str = "cpu"
 
 
 class InferenceConfig(BaseModel):
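Reviewer note: a minimal sketch of how the new and renamed config fields might be exercised, assuming the usual pydantic BaseModel keyword construction; the values below are illustrative, not the PR's defaults.

    from howl.config import InferenceEngineConfig, TrainingConfig

    # New TrainingConfig fields from this diff: eval_frequency, objective, device.
    training_config = TrainingConfig(
        num_epochs=10,
        eval_frequency=5,  # evaluate every 5 epochs
        objective="ctc",   # "frame" or "ctc"; "ctc" pairs with ContextConfig.use_blank
        device="cuda:0",   # default is "cpu"
    )

    # New InferenceEngineConfig fields: window_ms and stride_ms. The *_window_ms
    # fields are now ints, and inference_threshold defaults to 0.5.
    engine_config = InferenceEngineConfig(window_ms=500, stride_ms=50)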
1 change: 1 addition & 0 deletions howl/context.py
@@ -122,4 +122,5 @@ def load_from_config(config: ContextConfig):
             token_type=config.token_type,
             phone_dictionary_path=config.phone_dictionary_path,
             seed=config.seed,
+            use_blank=config.use_blank,
         )
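With use_blank now threaded through to the context factory, a CTC-oriented context can be built from config alone. A hedged sketch, assuming load_from_config is exposed on InferenceContext (the hunk shows only the call site, not its enclosing class):

    from howl.config import ContextConfig
    from howl.context import InferenceContext  # assumed host of load_from_config

    # use_blank=True adds a [BLANK] token to the vocab, as required by CTC loss.
    context = InferenceContext.load_from_config(ContextConfig(token_type="word", use_blank=True))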
10 changes: 0 additions & 10 deletions howl/data/common/sample.py
@@ -31,13 +31,3 @@ def labelled(self) -> bool:
     def pin_memory(self):
         """Pin audio data in memory"""
         self.audio_data.pin_memory()
-
-    def update_data(self, audio_data: torch.Tensor, label: FrameLabelData = None):
-        """Update audio data and label
-
-        Args:
-            audio_data: new audio data
-            label: new label
-        """
-        self.audio_data = audio_data
-        self.label = label
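With update_data removed, call sites assign the fields directly, which is the pattern the new augment_sample hooks in transform.py follow (variable names here are illustrative):

    # before (removed helper):
    # sample.update_data(shifted_audio, new_label)

    # after: mutate the Sample in place
    sample.audio_data = shifted_audio
    sample.label = new_label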
2 changes: 1 addition & 1 deletion howl/data/dataset/dataset.py
@@ -122,7 +122,7 @@ def split(self, predicate_fn: Callable[[Any], bool]):
         dataset_1 = deepcopy(self)
         dataset_2 = deepcopy(self)
         for metadata in self.metadata_list:
-            data_list = data_list2 if predicate_fn(metadata) else data_list1
+            data_list = data_list1 if predicate_fn(metadata) else data_list2
             data_list.append(metadata)
         dataset_1.metadata_list = data_list1
         dataset_2.metadata_list = data_list2
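The one-line fix inverts the split so that metadata matching the predicate now lands in the first dataset rather than the second. Illustrative usage, assuming split returns the two deep-copied datasets as the surrounding code suggests, and that metadata carries a path attribute:

    # After the fix: predicate True -> first dataset, False -> second.
    train_ds, dev_ds = dataset.split(lambda metadata: "train" in str(metadata.path))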
110 changes: 85 additions & 25 deletions howl/data/transform/transform.py
@@ -9,7 +9,8 @@
 import torch.nn as nn
 from torchaudio.transforms import ComputeDeltas, MelSpectrogram
 
-from howl.data.common.example import EmplacableExample, WakeWordClipExample
+from howl.data.common.example import EmplacableExample
+from howl.data.common.sample import Sample
 from howl.data.dataset.dataset import AudioClipDataset
 from howl.data.transform.meyda import MeydaMelSpectrogram
 from howl.settings import SETTINGS
@@ -24,11 +25,16 @@
     "DatasetMixer",
     "StandardAudioTransform",
     "SpecAugmentTransform",
-    "NegativeSampleTransform",
 ]
 
 
 # pylint: disable=invalid-name
+# pylint: disable=unused-argument
+
+# TODO: this file needs to be separated into three
+# 1) audio augmentation
+# 2) spectrogram augmentation
+# 3) standard audio to spectrogram transform
 
 
 @dataclass
@@ -84,39 +90,25 @@ def augment(self, param: AugmentationParameter, examples, **kwargs):
 
     def passthrough(self, examples, **kwargs):
         """Skips the augmentation"""
-        # pylint: disable=unused-argument
         return examples
 
     def forward(self, x, **kwargs):
         """Apply augmentation in training model, otherwise skips the augmentation"""
         for param in self.augment_params:
             if param.enabled and self.rand.random() < param.prob and self.training:
-                x = self.augment(param, x, **kwargs)
+                if isinstance(x[0], Sample):
+                    for sample in x:
+                        self.augment_sample(param, sample, **kwargs)
+                        # sample.audio_data = self.augment_audio_data(param, sample.audio_data, **kwargs)
+                        # if sample.label is not None:
+                        #     sample.label = self.augment_label(param, sample.label, **kwargs)
+                else:  # Example augmentation, to be deprecated
+                    x = self.augment(param, x, **kwargs)
             else:
                 x = self.passthrough(x, **kwargs)
         return x
 
 
-class NegativeSampleTransform(AugmentModule):
-    """NegativeSampleTransform"""
-
-    @property
-    def default_params(self):
-        """default_params"""
-        return (AugmentationParameter([0.2, 0.3, 0.4, 0.5], "chunk_size", 1, prob=0.3),)
-
-    @torch.no_grad()
-    def augment(self, param: AugmentationParameter, examples: Sequence[WakeWordClipExample], **kwargs):
-        """augment"""
-        new_examples = []
-        for example in examples:
-            audio_data = example.audio_data[..., : int(example.audio_data.size(-1) * param.magnitude)]
-            example = example.update_audio_data(audio_data)
-            example.contains_wake_word = False
-            new_examples.append(example)
-        return new_examples
-
-
 class TimeshiftTransform(AugmentModule):
     """Time-shift the audio data"""
 
@@ -142,6 +134,30 @@ def augment(self, param: AugmentationParameter, examples: Sequence[EmplacableExample], **kwargs):
             new_examples.append(example.update_audio_data(audio_data))
         return new_examples
 
+    def augment_sample(self, param: AugmentationParameter, sample: Sample, **kwargs):
+        """Apply time shift (roll audio data)"""
+
+        audio_data_size = sample.audio_data.size(-1)
+
+        time_shift_mag = int(self.rand.random() * param.magnitude * self.sr)
+        if sample.audio_data.size(-1) < 2 * time_shift_mag:
+            time_shift_mag = int(0.5 * audio_data_size)
+
+        # direction
+        time_shift = time_shift_mag
+        if self.rand.random() < 0.5:
+            time_shift *= -1
+
+        sample.audio_data = torch.roll(sample.audio_data, time_shift)
+
+        # TODO: update labels if necessary
+        # new_timestamp_label_map = {}
+        # for timestamp, label in self.label_data.timestamp_label_map.items():
+        #     new_timestamp = max(0, min(timestamp + time_shift, audio_data_size-1))
+        #     new_timestamp_label_map[new_timestamp] = label
+        #
+        # sample.label.timestamp_label_map = new_timestamp_label_map
+
 
 class TimestretchTransform(AugmentModule):
     """Time-stretch the audio data"""
@@ -164,6 +180,25 @@ def augment(self, param, examples: Sequence[EmplacableExample], **kwargs):
             new_examples.append(example.update_audio_data(audio, scale=1 / rate))
         return new_examples
 
+    @torch.no_grad()
+    def augment_sample(self, param: AugmentationParameter, sample: Sample, **kwargs):
+        """Apply time stretch"""
+
+        # Stretch factor. If rate > 1, then the signal is sped up. If rate < 1, then the signal is slowed down.
+        stretch_rate = np.clip(np.random.normal(1.0, param.magnitude), 0.3, 1.7)
+
+        sample.audio_data = torch.from_numpy(
+            librosa.effects.time_stretch(sample.audio_data.squeeze().cpu().numpy(), rate=stretch_rate)
+        )
+
+        # TODO: update labels if necessary
+        # audio_data_size = sample.audio_data.size(-1)
+        # scale = 1 / stretch_rate
+        # new_timestamp_label_map = {}
+        # for timestamp, label in self.label_data.timestamp_label_map.items():
+        #     new_timestamp = min(scale*timestamp, audio_data_size-1)
+        #     new_timestamp_label_map[new_timestamp] = label
+
 
 class NoiseTransform(AugmentModule):
     """Add synthetic noise to the audio data"""
@@ -195,6 +230,20 @@ def augment(self, param, examples: Sequence[EmplacableExample], **kwargs):
             new_examples.append(example.update_audio_data(waveform))
         return new_examples
 
+    @torch.no_grad()
+    def augment_sample(self, param: AugmentationParameter, sample: Sample, **kwargs):
+        """Apply noise"""
+        if param.name == "white":
+            strength = param.magnitude * self.rand.random()
+            noise_mask = torch.empty_like(sample.audio_data).normal_(0, strength)
+        else:
+            prob = param.magnitude * self.rand.random()
+            noise_mask = torch.empty_like(sample.audio_data).bernoulli_(prob / 2) - torch.empty_like(
+                sample.audio_data
+            ).bernoulli_(prob / 2)
+            noise_mask.clamp_(-1, 1)
+        sample.audio_data = (sample.audio_data + noise_mask).clamp_(-1, 1)
+
 
 class DatasetMixer(AugmentModule):
     """Augmentation by adding background noise"""
@@ -230,6 +279,18 @@ def augment(self, param, examples: Sequence[EmplacableExample], **kwargs):
             new_examples.append(ex)
         return new_examples
 
+    @torch.no_grad()
+    def augment_sample(self, param: AugmentationParameter, sample: Sample, **kwargs):
+        """Add background noise"""
+        bg_ex = self.rand.choice(self.dataset).audio_data.to(sample.audio_data.device)
+        while bg_ex.size(-1) < sample.audio_data.size(-1):
+            bg_ex = self.rand.choice(self.dataset).audio_data.to(sample.audio_data.device)
+        b = self.rand.randint(sample.audio_data.size(-1), bg_ex.size(-1))
+        a = b - sample.audio_data.size(-1)
+        bg_audio = bg_ex[..., a:b]
+        alpha = 1 if param.name == "replace" else self.rand.random() * param.magnitude
+        sample.audio_data = sample.audio_data * (1 - alpha) + bg_audio * alpha
+
 
 class StandardAudioTransform(AugmentModule):
     """Transformation to apply on the audio data"""
@@ -417,7 +478,6 @@ class VtlpMelScale(nn.Module):
 
     def __init__(self, n_mels=128, sample_rate=16000, f_min=0.0, f_max=None, n_stft=None):
         """__init__"""
-        # pylint: disable=unused-argument
         super().__init__()
         self.n_mels = n_mels
         self.sample_rate = sample_rate
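Taken together, AugmentModule.forward now dispatches on the input type: a batch of Sample objects goes through the new in-place augment_sample hooks, while anything else falls back to the legacy Example-based augment path. A hedged sketch of the Sample path (constructor arguments are assumptions; the diff only shows forward and the augment_sample implementations):

    from howl.data.common.sample import Sample
    from howl.data.transform.transform import TimeshiftTransform

    transform = TimeshiftTransform()  # assumed default-constructible
    transform.train()  # forward() only augments in training mode, gated by param.prob

    # samples: List[Sample]; forward() checks isinstance(samples[0], Sample)
    # and mutates each sample.audio_data in place (torch.roll for time shift).
    samples = transform(samples)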
8 changes: 3 additions & 5 deletions howl/model/inference.py
@@ -214,10 +214,10 @@ def infer(self, audio_data: torch.Tensor) -> bool:
 class FrameInferenceEngine(InferenceEngine):
     """InferenceEngine that evaluates the given audio data by generating predictions frame by frame"""
 
-    def __init__(self, max_window_size_ms: int, eval_stride_size_ms: int, *args):
+    def __init__(self, window_ms: int, eval_stride_size_ms: int, *args):
         """Initialize FrameInferenceEngine"""
         super().__init__(*args)
-        self.max_window_size_ms, self.eval_stride_size_ms = max_window_size_ms, eval_stride_size_ms
+        self.window_ms, self.eval_stride_size_ms = window_ms, eval_stride_size_ms
 
     @torch.no_grad()
     def infer(self, audio_data: torch.Tensor) -> bool:
@@ -231,9 +231,7 @@ def infer(self, audio_data: torch.Tensor) -> bool:
             return True if wake word presents in the last window
         """
         sequence_present = False
-        for window in audio_utils.stride(
-            audio_data, self.max_window_size_ms, self.eval_stride_size_ms, self.sample_rate
-        ):
+        for window in audio_utils.stride(audio_data, self.window_ms, self.eval_stride_size_ms, self.sample_rate):
             if window.size(-1) < 1000:
                 break
             self.ingest_frame(window.squeeze(0), self.curr_time)
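The rename keeps the engine aligned with the new InferenceEngineConfig.window_ms field. A sketch of the wiring (the remaining positional *args forwarded to InferenceEngine.__init__ are not visible in this diff and are elided; stride_ms is assumed to feed eval_stride_size_ms):

    from howl.config import InferenceEngineConfig
    from howl.model.inference import FrameInferenceEngine

    config = InferenceEngineConfig()  # window_ms=500, stride_ms=50 defaults
    engine = FrameInferenceEngine(config.window_ms, config.stride_ms)  # plus *args for InferenceEngine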