diff --git a/howl/config.py b/howl/config.py index 8c428ad2..f2ff7975 100644 --- a/howl/config.py +++ b/howl/config.py @@ -15,7 +15,7 @@ class AudioConfig(BaseModel): """Base config for loading audio file""" sample_rate: int = 16000 - use_mono: bool = True + mono: bool = True class ContextConfig(BaseModel): @@ -30,6 +30,8 @@ class ContextConfig(BaseModel): token_type: str = "word" # phone dictionary file path phone_dictionary_path: str = None + # if True, [BLANK] token will be added to vocab (used for CTC loss) + use_blank: bool = False class InferenceEngineConfig(BaseModel): @@ -39,15 +41,19 @@ class InferenceEngineConfig(BaseModel): per_frame: bool = False # weighting on prediction (model output) inference_weights: List[float] = None + # window size for a single prediction + window_ms: int = 500 + # stride size + stride_ms: int = 50 # InferenceEngine says wake word is present # if a sequence of predictions from the last INFERENCE_WINDOW_MS audio data matches the target sequence - inference_window_ms: float = 2000 + inference_window_ms: int = 2000 # predictions are smoothed over SMOOTHING_WINDOW_MS before the final labels are computed - smoothing_window_ms: float = 50 + smoothing_window_ms: int = 200 # negative labels are ignored as long as they don't last for TOLERANCE_WINDOW_MS - tolerance_window_ms: float = 500 + tolerance_window_ms: int = 500 # prediction probability for positive labels must be above this threshold - inference_threshold: float = 0 + inference_threshold: float = 0.5 class AudioTransformConfig(BaseModel): @@ -79,9 +85,11 @@ class TrainingConfig(BaseModel): batch_size: int = 16 learning_rate: float = 0.01 num_epochs: int = 10 + eval_frequency: int = 5 lr_decay: float = 0.955 weight_decay: float = 0.00001 use_noise_dataset: bool = False + objective: str = "frame" # frame or ctc noise_datasets: List[DatasetConfig] = [] train_datasets: List[DatasetConfig] = [] val_datasets: List[DatasetConfig] = [] @@ -91,6 +99,7 @@ class TrainingConfig(BaseModel): model_config: ModelConfig = ModelConfig() context_config: ContextConfig = ContextConfig() workspace_path: str = None + device: str = "cpu" class InferenceConfig(BaseModel): diff --git a/howl/context.py b/howl/context.py index 69c60486..0be14aab 100644 --- a/howl/context.py +++ b/howl/context.py @@ -122,4 +122,5 @@ def load_from_config(config: ContextConfig): token_type=config.token_type, phone_dictionary_path=config.phone_dictionary_path, seed=config.seed, + use_blank=config.use_blank, ) diff --git a/howl/data/common/sample.py b/howl/data/common/sample.py index e704daa9..eb8f6862 100644 --- a/howl/data/common/sample.py +++ b/howl/data/common/sample.py @@ -31,13 +31,3 @@ def labelled(self) -> bool: def pin_memory(self): """Pin audio data in memory""" self.audio_data.pin_memory() - - def update_data(self, audio_data: torch.Tensor, label: FrameLabelData = None): - """Update audio data and label - - Args: - audio_data: new audio data - label: new label - """ - self.audio_data = audio_data - self.label = label diff --git a/howl/data/dataset/dataset.py b/howl/data/dataset/dataset.py index 0f9f0be4..06033e45 100644 --- a/howl/data/dataset/dataset.py +++ b/howl/data/dataset/dataset.py @@ -122,7 +122,7 @@ def split(self, predicate_fn: Callable[[Any], bool]): dataset_1 = deepcopy(self) dataset_2 = deepcopy(self) for metadata in self.metadata_list: - data_list = data_list2 if predicate_fn(metadata) else data_list1 + data_list = data_list1 if predicate_fn(metadata) else data_list2 data_list.append(metadata) dataset_1.metadata_list = data_list1 dataset_2.metadata_list = data_list2 diff --git a/howl/data/transform/transform.py b/howl/data/transform/transform.py index 502b9150..f24550b2 100644 --- a/howl/data/transform/transform.py +++ b/howl/data/transform/transform.py @@ -9,7 +9,8 @@ import torch.nn as nn from torchaudio.transforms import ComputeDeltas, MelSpectrogram -from howl.data.common.example import EmplacableExample, WakeWordClipExample +from howl.data.common.example import EmplacableExample +from howl.data.common.sample import Sample from howl.data.dataset.dataset import AudioClipDataset from howl.data.transform.meyda import MeydaMelSpectrogram from howl.settings import SETTINGS @@ -24,11 +25,16 @@ "DatasetMixer", "StandardAudioTransform", "SpecAugmentTransform", - "NegativeSampleTransform", ] # pylint: disable=invalid-name +# pylint: disable=unused-argument + +# TODO: this file needs to be separated into three +# 1) audio augmentation +# 2) spectrogram augmentation +# 3) standard audio to spectrogram transform @dataclass @@ -84,39 +90,25 @@ def augment(self, param: AugmentationParameter, examples, **kwargs): def passthrough(self, examples, **kwargs): """Skips the augmentation""" - # pylint: disable=unused-argument return examples def forward(self, x, **kwargs): """Apply augmentation in training model, otherwise skips the augmentation""" for param in self.augment_params: if param.enabled and self.rand.random() < param.prob and self.training: - x = self.augment(param, x, **kwargs) + if isinstance(x[0], Sample): + for sample in x: + self.augment_sample(param, sample, **kwargs) + # sample.audio_data = self.augment_audio_data(param, sample.audio_data, **kwargs) + # if sample.label is not None: + # sample.label = self.augment_label(param, sample.label, **kwargs) + else: # Example augmentation, to be deprecated + x = self.augment(param, x, **kwargs) else: x = self.passthrough(x, **kwargs) return x -class NegativeSampleTransform(AugmentModule): - """NegativeSampleTransform""" - - @property - def default_params(self): - """default_params""" - return (AugmentationParameter([0.2, 0.3, 0.4, 0.5], "chunk_size", 1, prob=0.3),) - - @torch.no_grad() - def augment(self, param: AugmentationParameter, examples: Sequence[WakeWordClipExample], **kwargs): - """augment""" - new_examples = [] - for example in examples: - audio_data = example.audio_data[..., : int(example.audio_data.size(-1) * param.magnitude)] - example = example.update_audio_data(audio_data) - example.contains_wake_word = False - new_examples.append(example) - return new_examples - - class TimeshiftTransform(AugmentModule): """Time-shift the audio data""" @@ -142,6 +134,30 @@ def augment(self, param: AugmentationParameter, examples: Sequence[EmplacableExa new_examples.append(example.update_audio_data(audio_data)) return new_examples + def augment_sample(self, param: AugmentationParameter, sample: Sample, **kwargs): + """Apply time shift (roll audio data)""" + + audio_data_size = sample.audio_data.size(-1) + + time_shift_mag = int(self.rand.random() * param.magnitude * self.sr) + if sample.audio_data.size(-1) < 2 * time_shift_mag: + time_shift_mag = int(0.5 * audio_data_size) + + # direction + time_shift = time_shift_mag + if self.rand.random() < 0.5: + time_shift *= -1 + + sample.audio_data = torch.roll(sample.audio_data, time_shift) + + # TODO: update labels if necessary + # new_timestamp_label_map = {} + # for timestamp, label in self.label_data.timestamp_label_map.items(): + # new_timestamp = max(0, min(timestamp + time_shift, audio_data_size-1)) + # new_timestamp_label_map[new_timestamp] = label + # + # sample.label.timestamp_label_map = new_timestamp_label_map + class TimestretchTransform(AugmentModule): """Time-stretch the audio data""" @@ -164,6 +180,25 @@ def augment(self, param, examples: Sequence[EmplacableExample], **kwargs): new_examples.append(example.update_audio_data(audio, scale=1 / rate)) return new_examples + @torch.no_grad() + def augment_sample(self, param: AugmentationParameter, sample: Sample, **kwargs): + """Apply time stretch""" + + # Stretch factor. If rate > 1, then the signal is sped up. If rate < 1, then the signal is slowed down. + stretch_rate = np.clip(np.random.normal(1.0, param.magnitude), 0.3, 1.7) + + sample.audio_data = torch.from_numpy( + librosa.effects.time_stretch(sample.audio_data.squeeze().cpu().numpy(), rate=stretch_rate) + ) + + # # TODO: update labels if necessary + # audio_data_size = sample.audio_data.size(-1) + # scale = 1 / stretch_rate + # new_timestamp_label_map = {} + # for timestamp, label in self.label_data.timestamp_label_map.items(): + # new_timestamp = min(scale*timestamp, audio_data_size-1) + # new_timestamp_label_map[new_timestamp] = label + class NoiseTransform(AugmentModule): """Add synthetic noise to the audio data""" @@ -195,6 +230,20 @@ def augment(self, param, examples: Sequence[EmplacableExample], **kwargs): new_examples.append(example.update_audio_data(waveform)) return new_examples + @torch.no_grad() + def augment_sample(self, param: AugmentationParameter, sample: Sample, **kwargs): + """Apply noise""" + if param.name == "white": + strength = param.magnitude * self.rand.random() + noise_mask = torch.empty_like(sample.audio_data).normal_(0, strength) + else: + prob = param.magnitude * self.rand.random() + noise_mask = torch.empty_like(sample.audio_data).bernoulli_(prob / 2) - torch.empty_like( + sample.audio_data + ).bernoulli_(prob / 2) + noise_mask.clamp_(-1, 1) + sample.audio_data = (sample.audio_data + noise_mask).clamp_(-1, 1) + class DatasetMixer(AugmentModule): """Augmentation by adding background noise""" @@ -230,6 +279,18 @@ def augment(self, param, examples: Sequence[EmplacableExample], **kwargs): new_examples.append(ex) return new_examples + @torch.no_grad() + def augment_sample(self, param: AugmentationParameter, sample: Sample, **kwargs): + """Add background noise""" + bg_ex = self.rand.choice(self.dataset).audio_data.to(sample.audio_data.device) + while bg_ex.size(-1) < sample.audio_data.size(-1): + bg_ex = self.rand.choice(self.dataset).audio_data.to(sample.audio_data.device) + b = self.rand.randint(sample.audio_data.size(-1), bg_ex.size(-1)) + a = b - sample.audio_data.size(-1) + bg_audio = bg_ex[..., a:b] + alpha = 1 if param.name == "replace" else self.rand.random() * param.magnitude + sample.audio_data = sample.audio_data * (1 - alpha) + bg_audio * alpha + class StandardAudioTransform(AugmentModule): """Transformation to apply on the audio data""" @@ -417,7 +478,6 @@ class VtlpMelScale(nn.Module): def __init__(self, n_mels=128, sample_rate=16000, f_min=0.0, f_max=None, n_stft=None): """__init__""" - # pylint: disable=unused-argument super().__init__() self.n_mels = n_mels self.sample_rate = sample_rate diff --git a/howl/model/inference.py b/howl/model/inference.py index 91960c8e..7290373d 100644 --- a/howl/model/inference.py +++ b/howl/model/inference.py @@ -214,10 +214,10 @@ def infer(self, audio_data: torch.Tensor) -> bool: class FrameInferenceEngine(InferenceEngine): """InferenceEngine that evaluates the given audio data by generating predictions frame by frame""" - def __init__(self, max_window_size_ms: int, eval_stride_size_ms: int, *args): + def __init__(self, window_ms: int, eval_stride_size_ms: int, *args): """Initialize FrameInferenceEngine""" super().__init__(*args) - self.max_window_size_ms, self.eval_stride_size_ms = max_window_size_ms, eval_stride_size_ms + self.window_ms, self.eval_stride_size_ms = window_ms, eval_stride_size_ms @torch.no_grad() def infer(self, audio_data: torch.Tensor) -> bool: @@ -231,9 +231,7 @@ def infer(self, audio_data: torch.Tensor) -> bool: return True if wake word presents in the last window """ sequence_present = False - for window in audio_utils.stride( - audio_data, self.max_window_size_ms, self.eval_stride_size_ms, self.sample_rate - ): + for window in audio_utils.stride(audio_data, self.window_ms, self.eval_stride_size_ms, self.sample_rate): if window.size(-1) < 1000: break self.ingest_frame(window.squeeze(0), self.curr_time) diff --git a/howl/trainer.py b/howl/trainer.py index 0dd0e7a0..25cf6c0b 100644 --- a/howl/trainer.py +++ b/howl/trainer.py @@ -1,9 +1,40 @@ -import logging +from datetime import datetime +from pathlib import Path +from typing import Dict, List + +import torch +import torch.nn as nn +import torch.nn.functional as F +from devtools import debug as print_debug +from torch.optim.adamw import AdamW +from torch.optim.lr_scheduler import ExponentialLR +from tqdm import tqdm, trange import howl -from howl.config import TrainingConfig +from howl.config import DatasetConfig, TrainingConfig from howl.context import InferenceContext -from howl.utils import logging_utils +from howl.data.common.tokenizer import WakeWordTokenizer +from howl.data.dataloader import StandardAudioDataLoaderBuilder +from howl.data.dataset.dataset import AudioClipDataset, DatasetSplit, DatasetType, WakeWordDataset +from howl.data.dataset.dataset_loader import RecursiveNoiseDatasetLoader +from howl.data.transform.batchifier import AudioSequenceBatchifier, WakeWordFrameBatchifier +from howl.data.transform.operator import ZmuvTransform, batchify, compose +from howl.data.transform.transform import ( + AugmentModule, + DatasetMixer, + NoiseTransform, + SpecAugmentTransform, + StandardAudioTransform, + TimeshiftTransform, + TimestretchTransform, +) +from howl.dataset.audio_dataset_constants import AudioDatasetType, SampleType +from howl.dataset_loader.howl_audio_dataset_loader import HowlAudioDatasetLoader +from howl.model import ConfusionMatrix, RegisteredModel +from howl.model.inference import FrameInferenceEngine, InferenceEngine +from howl.utils import hash_utils +from howl.utils.logger import Logger +from howl.workspace import Workspace # WIP; please use train.py @@ -12,32 +43,414 @@ class Trainer: """Class which defines training logics""" def __init__( - self, training_cfg: TrainingConfig, logger: logging.Logger = None, + self, training_cfg: TrainingConfig, ): """Initialize trainer Args: training_cfg (TrainingConfig): training config that defines how to load datasets and train the model - logger (logging.Logger): logger """ + Logger.info(training_cfg) + # TODO: Config should be printed out in a cleaner format on terminal + print_debug(training_cfg) self.training_cfg = training_cfg self.context_cfg = training_cfg.context_config + self.inference_engine_cfg = training_cfg.inference_engine_config + self.device = torch.device(self.training_cfg.device) + + self.inference_engine_cfg.per_frame = self.training_cfg.objective == "frame" + self.context_cfg.use_blank = self.training_cfg.objective == "ctc" self.context = InferenceContext.load_from_config(self.context_cfg) - if logger is None: - self.logger = logging_utils.setup_logger(self.__class__.__name__) + # TODO: Ideally, WakeWordDataset needs to be deprecated + self.train_dataset: WakeWordDataset = WakeWordDataset( + metadata_list=[], + set_type=DatasetType.TRAINING, + dataset_split=DatasetSplit.TRAINING, + frame_labeler=self.context.labeler, + ) + self.dev_pos_dataset: WakeWordDataset = WakeWordDataset( + metadata_list=[], + set_type=DatasetType.DEV, + dataset_split=DatasetSplit.TRAINING, + frame_labeler=self.context.labeler, + ) + self.dev_neg_dataset: WakeWordDataset = WakeWordDataset( + metadata_list=[], + set_type=DatasetType.DEV, + dataset_split=DatasetSplit.TRAINING, + frame_labeler=self.context.labeler, + ) + self.test_pos_dataset: WakeWordDataset = WakeWordDataset( + metadata_list=[], + set_type=DatasetType.TEST, + dataset_split=DatasetSplit.TEST, + frame_labeler=self.context.labeler, + ) + self.test_neg_dataset: WakeWordDataset = WakeWordDataset( + metadata_list=[], + set_type=DatasetType.TEST, + dataset_split=DatasetSplit.TEST, + frame_labeler=self.context.labeler, + ) + + self.noise_datasets: Dict[str, AudioClipDataset] = { + DatasetSplit.TRAINING: None, + DatasetSplit.DEV: None, + DatasetSplit.TEST: None, + } + + self.inference_engine: InferenceEngine = None + + self.audio_transform: StandardAudioTransform = None + self.zmuv_transform: ZmuvTransform = None + self.audio_augmentations: List[AugmentModule] = [] + self.spectrogram_augmentations: List[AugmentModule] = [] + + self.model: nn.Module = None + + def _load_dataset(self, dataset_split: DatasetSplit, dataset_cfg: DatasetConfig): + """Load a dataset given dataset config""" + dataset_loader = HowlAudioDatasetLoader(AudioDatasetType.ALIGNED, Path(dataset_cfg.path).expanduser()) + ds_kwargs = dataset_cfg.audio_config.dict() + ds_kwargs["labeler"] = self.context.labeler + dataset = dataset_loader.load_split(dataset_split, **ds_kwargs) + # dataset.print_stats(header=dataset_cfg.path, word_searcher=self.context.searcher, compute_length=True) + + return dataset + + def _prepare_train_dataset(self): + """Load train datasets""" + for dataset_cfg in self.training_cfg.train_datasets: + dataset = self._load_dataset(DatasetSplit.TRAINING, dataset_cfg) + self.train_dataset.extend(dataset) + + def _prepare_dev_dataset(self): + """Load dev datasets and store them into appropriate variable (positive, negative)""" + for dataset_cfg in self.training_cfg.train_datasets: + dataset = self._load_dataset(DatasetSplit.DEV, dataset_cfg) + + if SampleType.POSITIVE in dataset_cfg.path: + self.dev_pos_dataset.extend(dataset) + elif SampleType.NEGATIVE in dataset_cfg.path: + self.dev_neg_dataset.extend(dataset) + else: + dev_pos_dataset = dataset.filter(lambda x: self.context.searcher.search(x.transcription), clone=True) + self.dev_pos_dataset.extend(dev_pos_dataset) + dev_neg_dataset = dataset.filter( + lambda x: not self.context.searcher.search(x.transcription), clone=True + ) + self.dev_neg_dataset.extend(dev_neg_dataset) + + def _prepare_test_dataset(self): + """Load test datasets and store them into appropriate variable (positive, negative)""" + for dataset_cfg in self.training_cfg.train_datasets: + dataset = self._load_dataset(DatasetSplit.DEV, dataset_cfg) + + if SampleType.POSITIVE in dataset_cfg.path: + self.test_pos_dataset.extend(dataset) + elif SampleType.NEGATIVE in dataset_cfg.path: + self.test_neg_dataset.extend(dataset) + else: + test_pos_dataset = dataset.filter(lambda x: self.context.searcher.search(x.transcription), clone=True) + self.test_pos_dataset.extend(test_pos_dataset) + test_neg_dataset = dataset.filter( + lambda x: not self.context.searcher.search(x.transcription), clone=True + ) + self.test_neg_dataset.extend(test_neg_dataset) + + def _prepare_noise_dataset(self): + """Load noise dataset for audio augmentation""" + + for idx, noise_dataset_cfg in enumerate(self.training_cfg.noise_datasets): + noise_ds = RecursiveNoiseDatasetLoader().load( + Path(noise_dataset_cfg.path).expanduser(), + sample_rate=noise_dataset_cfg.audio_config.sample_rate, + mono=noise_dataset_cfg.audio_config.mono, + ) + # 80, 10, 10 split + noise_ds_train, noise_ds_dev_test = noise_ds.split(hash_utils.Sha256Splitter(80)) + noise_ds_dev, noise_ds_test = noise_ds_dev_test.split(hash_utils.Sha256Splitter(90)) + + if idx == 0: + self.noise_datasets[DatasetSplit.TRAINING] = noise_ds_train + self.noise_datasets[DatasetSplit.DEV] = noise_ds_dev + self.noise_datasets[DatasetSplit.TEST] = noise_ds_test + else: + self.noise_datasets[DatasetSplit.TRAINING].extend(noise_ds_train) + self.noise_datasets[DatasetSplit.DEV].extend(noise_ds_dev) + self.noise_datasets[DatasetSplit.TEST].extend(noise_ds_test) + + for dataset_split in [DatasetSplit.TRAINING, DatasetSplit.DEV, DatasetSplit.TEST]: + Logger.info( + f"Loaded {len(self.noise_datasets[dataset_split].metadata_list)} noise files for {dataset_split}" + ) + + def _prepare_audio_augmentations(self): + """Instantiate a set of audio augmentations""" + self.audio_transform = StandardAudioTransform().to(self.device).eval() + self.zmuv_transform = ZmuvTransform().to(self.device) + + if self.training_cfg.objective == "frame": + batchifier = WakeWordFrameBatchifier( + self.context.negative_label, window_size_ms=self.inference_engine_cfg.window_ms + ) + else: + tokenizer = WakeWordTokenizer(self.context.vocab, ignore_oov=False) + batchifier = AudioSequenceBatchifier(self.context.negative_label, tokenizer) + + if self.training_cfg.use_noise_dataset: + self.audio_augmentations = [DatasetMixer(self.noise_datasets[DatasetSplit.TRAINING]).train()] + + self.audio_augmentations.extend( + [TimestretchTransform().train(), TimeshiftTransform().train(), NoiseTransform().train(), batchifier] + ) + + def _prepare_spectrogram_augmentations(self): + """Instantiate a set of spectrogram augmentations""" + self.spectrogram_augmentations = [SpecAugmentTransform().train()] + + def _train_zmuv_model(self, workspace: Workspace, num_batch_to_consider: int = 2000): + """Train or load ZMUV model""" + zmuv_dl = StandardAudioDataLoaderBuilder(self.train_dataset, collate_fn=batchify).build(1) + zmuv_dl.shuffle = True + + load_pretrained_model = Path(workspace.zmuv_model_path()).exists() + + if load_pretrained_model: + self.zmuv_transform.load_state_dict(torch.load(workspace.zmuv_model_path())) + else: + for idx, batch in enumerate(tqdm(zmuv_dl, desc="Constructing ZMUV model", total=num_batch_to_consider)): + batch.to(self.device) + self.zmuv_transform.update(self.audio_transform(batch.audio_data)) + + # We just need to approximate mean and variance + if idx == num_batch_to_consider: + break + + zmuv_mean = self.zmuv_transform.mean.item() + workspace.summary_writer.add_scalar("Meta/ZMUV_mean", zmuv_mean) + + zmuv_std = self.zmuv_transform.std.item() + workspace.summary_writer.add_scalar("Meta/ZMUV_std", zmuv_std) + + Logger.info(f"zmuv_mean: {zmuv_mean}, zmuv_std: {zmuv_std}") + + if not load_pretrained_model: + torch.save(self.zmuv_transform.state_dict(), workspace.zmuv_model_path()) + + def _prepare_models(self, workspace: Workspace, load_pretrained_model: bool = False): + # model for normalization + self._train_zmuv_model(workspace) + + # model for kws + self.model = ( + RegisteredModel.find_registered_class(self.training_cfg.model_config.architecture)(self.context.num_labels) + .to(self.device) + .streaming() + ) + + if load_pretrained_model: + workspace.load_model(self.model, best=False) + + def train(self, load_dataset: bool = True, continue_training: bool = False, debug: bool = False): + """ + Train the model on train datasets. + """ + # pylint: disable=too-many-statements + # pylint: disable=too-many-branches + + if debug: + self.training_cfg.workspace_path = ( + f"{howl.workspaces_path()}/{self.context.wake_word.replace(' ', '_')}-debug" + ) if self.training_cfg.workspace_path is None: - print(howl.workspaces_path() / self.context.wake_word.replace(" ", "_")) - - # def train(self): - # """ - # Train the model on train datasets. - # """ - # raise NotImplementedError() - # - # def validation(self): - # """ - # Validate the model on validation datasets. - # """ - # raise NotImplementedError() + if continue_training: + raise RuntimeError("workspace_path should be specified when continue_training flag enabled") + curr_date_time = datetime.now() + self.training_cfg.workspace_path = ( + f"{howl.workspaces_path()}/{self.context.wake_word.replace(' ', '_')}-" + f"{curr_date_time.strftime('%m_%d_%H_%M')}" + ) + + Logger.info(f"Workspace: {self.training_cfg.workspace_path}") + + workspace = Workspace(Path(self.training_cfg.workspace_path), delete_existing=(not debug)) + writer = workspace.summary_writer + + # Prepare datasets + Logger.heading("Dataset preparation") + if load_dataset: + self._prepare_train_dataset() + self._prepare_dev_dataset() + self._prepare_test_dataset() + + # TODO: print dataset stats + # ww_dev_pos_ds.print_stats(header="dev_pos", word_searcher=ctx.searcher, compute_length=True) + + if self.training_cfg.use_noise_dataset: + self._prepare_noise_dataset() + + # Audio data augmentation + self._prepare_audio_augmentations() + audio_aug_comp = compose(*self.audio_augmentations) + + # Spectrogram augmentation + self._prepare_spectrogram_augmentations() + spec_aug_comp = compose(*self.spectrogram_augmentations) + + # prepare_models + Logger.heading("Model preparation") + self._prepare_models(workspace, load_pretrained_model=continue_training) + + # prepare inference engine + if self.inference_engine_cfg.per_frame: + self.inference_engine = FrameInferenceEngine( + self.inference_engine_cfg.window_ms, + self.inference_engine_cfg.stride_ms, + self.model, + self.zmuv_transform, + self.context, + ) + else: + self.inference_engine = InferenceEngine(self.model, self.zmuv_transform, self.context) + + # Training kws model + Logger.heading("Model training") + + if self.training_cfg.objective == "frame": + criterion = nn.CrossEntropyLoss() + else: + criterion = nn.CTCLoss(self.context.blank_label) + + params = list(filter(lambda x: x.requires_grad, self.model.parameters())) + optimizer = AdamW(params, self.training_cfg.learning_rate, weight_decay=self.training_cfg.weight_decay) + lr_scheduler = ExponentialLR(optimizer, gamma=self.training_cfg.lr_decay) + + Logger.info(f"Total number of parameters: {sum(p.numel() for p in params)}") + + train_dl = StandardAudioDataLoaderBuilder(self.train_dataset, collate_fn=audio_aug_comp).build( + self.training_cfg.batch_size + ) + + workspace.save_config(self.training_cfg) + writer.add_scalar("Meta/Parameters", sum(p.numel() for p in params)) + + pbar = trange(self.training_cfg.num_epochs, position=0, desc="Training", leave=True) + for epoch_idx in pbar: + self.model.train() + self.audio_transform.train() + self.model.streaming_state = None + total_loss = torch.Tensor([0.0]).to(self.device) + for batch in train_dl: + batch.to(self.device) + audio_length = self.audio_transform.compute_lengths(batch.lengths) + zmuv_audio_data = self.zmuv_transform(self.audio_transform(batch.audio_data)) + augmented_audio_data = spec_aug_comp(zmuv_audio_data) + if self.training_cfg.objective == "frame": + scores = self.model(augmented_audio_data, audio_length) + loss = criterion(scores, batch.labels) + else: + scores = self.model(augmented_audio_data, audio_length) + scores = F.log_softmax(scores, -1) # [num_frames x batch_size x num_labels] + audio_length = torch.tensor([self.model.compute_length(x.item()) for x in audio_length]).to( + self.device + ) + loss = criterion(scores, batch.labels, audio_length, batch.label_lengths) + optimizer.zero_grad() + self.model.zero_grad() + loss.backward() + optimizer.step() + with torch.no_grad(): + total_loss += loss + + lr_scheduler.step() + writer.add_scalar("Training/lr", lr_scheduler.get_last_lr()[0], epoch_idx) + + mean_loss = total_loss / len(train_dl) + pbar.set_postfix(dict(loss=f"{mean_loss.item():.3}")) + writer.add_scalar("Training/Loss", mean_loss.item(), epoch_idx) + + if epoch_idx % self.training_cfg.eval_frequency == 0 and epoch_idx != 0: + prefix = "Dev positive" + conf_matrix = self.evaluate_on_dataset(self.dev_pos_dataset, workspace, prefix, positive_set=True) + + writer.add_scalar(f"{prefix}/Metric/tp_rate", conf_matrix.tp / len(self.dev_pos_dataset), epoch_idx) + workspace.increment_model(self.model, conf_matrix.tp) + + Logger.heading("Model evaluation") + self.evaluate(workspace, evaluate_on_noisy_dataset=True) + + def evaluate(self, workspace: Workspace, evaluate_on_noisy_dataset: bool = False): + """Evaluate the model on every dev/test dataset""" + self.evaluate_on_dataset( + self.dev_pos_dataset, workspace, "Dev positive", positive_set=True, record_false_detections=True + ) + self.evaluate_on_dataset(self.dev_neg_dataset, workspace, "Dev negative", positive_set=False) + if evaluate_on_noisy_dataset: + dev_mixer = DatasetMixer(self.noise_datasets[DatasetSplit.DEV], seed=0, do_replace=False) + self.evaluate_on_dataset( + self.dev_pos_dataset, + workspace, + "Dev noisy positive", + positive_set=True, + mixer=dev_mixer, + record_false_detections=True, + ) + self.evaluate_on_dataset( + self.dev_neg_dataset, workspace, "Dev noisy negative", positive_set=False, mixer=dev_mixer + ) + self.evaluate_on_dataset( + self.test_pos_dataset, workspace, "Test positive", positive_set=True, record_false_detections=True + ) + self.evaluate_on_dataset(self.test_neg_dataset, workspace, "Test negative", positive_set=False) + if evaluate_on_noisy_dataset: + test_mixer = DatasetMixer(self.noise_datasets[DatasetSplit.TEST], seed=0, do_replace=False) + self.evaluate_on_dataset( + self.test_pos_dataset, + workspace, + "Test noisy positive", + positive_set=True, + mixer=test_mixer, + record_false_detections=True, + ) + self.evaluate_on_dataset( + self.test_neg_dataset, workspace, "Test noisy negative", positive_set=False, mixer=test_mixer, + ) + + def evaluate_on_dataset( + self, + dataset, + workspace: Workspace, + prefix: str, + positive_set: bool = False, + mixer: DatasetMixer = None, + record_false_detections: bool = False, + ): + """Evaluate the model on the given dataset""" + self.audio_transform.eval() + self.model.eval() + + conf_matrix = ConfusionMatrix() + pbar = tqdm(dataset, desc=prefix) + + for _, sample in enumerate(pbar): + if mixer is not None: + (sample,) = mixer([sample]) + audio_data = sample.audio_data.to(self.device) + self.inference_engine.reset() + seq_present = self.inference_engine.infer(audio_data) + if seq_present != positive_set and record_false_detections: + with (workspace.path / f"{prefix}_errors.tsv").open("a") as error_file: + error_file.write( + f"{sample.metadata.transcription}" + f"\t{int(seq_present)}" + f"\t{int(positive_set)}" + f"\t{sample.metadata.path}\n" + ) + conf_matrix.increment(seq_present, positive_set) + pbar.set_postfix(dict(mcc=f"{conf_matrix.mcc}", c=f"{conf_matrix}")) + + Logger.info(f"{conf_matrix}") + return conf_matrix diff --git a/howl/trainer_test.py b/howl/trainer_test.py index 8d874ea0..e2287f35 100644 --- a/howl/trainer_test.py +++ b/howl/trainer_test.py @@ -12,4 +12,5 @@ def test_trainer_instantiation(self): """Test instantiation of Trainer""" training_config_path = test_utils.test_data_path() / "test_training_config.json" training_cfg = TrainingConfig.parse_file(training_config_path) - Trainer(training_cfg) + trainer = Trainer(training_cfg) + trainer.train(debug=True) diff --git a/howl/workspace.py b/howl/workspace.py index e40d2849..97ca0fce 100644 --- a/howl/workspace.py +++ b/howl/workspace.py @@ -10,6 +10,7 @@ from howl.config import TrainingConfig from howl.settings import KEY_TO_SETTINGS_CLASS, SETTINGS, HowlSettings from howl.utils.dataclass import gather_dict +from howl.utils.logger import Logger @dataclass @@ -22,12 +23,19 @@ class Workspace: def __post_init__(self): """Initialize Workspace by creating the directory and summary writer""" + if self.path.exists(): + if self.delete_existing: + shutil.move(str(self.path), f"/tmp/{self.path.name}") + else: + Logger.warning(f"Workspace already exists: {self.path}") self.path.mkdir(parents=True, exist_ok=True) log_path = self.path / "logs" - if self.delete_existing: - shutil.rmtree(str(log_path), ignore_errors=True) self.summary_writer = SummaryWriter(str(log_path)) + def zmuv_model_path(self): + """Path of the trained zmuv .pt file""" + return str(self.path / "zmuv.pt.bin") + def model_path(self, best=False): """Path of the trained .pt file""" return str(self.path / f'model{"-best" if best else ""}.pt.bin') diff --git a/requirements.txt b/requirements.txt index 6cb7d1a8..df608972 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ coloredlogs dataclasses;python_version<"3.7" +devtools librosa>=0.8.0 numpy>=1.18.3 pandas>=1.0.3 diff --git a/test/test_data/test_training_config.json b/test/test_data/test_training_config.json index d93d798c..c3ecb4c6 100644 --- a/test/test_data/test_training_config.json +++ b/test/test_data/test_training_config.json @@ -4,28 +4,33 @@ }, "batch_size": 16, "learning_rate": 0.01, - "num_epochs": 10, + "num_epochs": 4, + "eval_frequency": 2, "lr_decay": 0.955, "weight_decay": 0.00001, - "use_noise_dataset": false, + "use_noise_dataset": true, "noise_datasets": [ { - "path": "/data/MS-SNSD" + "path": "~/data/kws/hey-ff/hey-ff-noise/MS-SNSD" } ], "train_datasets": [ { - "path": "/data/speaker-id-split-medium" + "path": "~/data/kws/hey-ff/hey-ff-data/speaker-id-split-medium" + }, + { + "path": "~/personal/howl/datasets/fire_fox/positive" } ], "val_datasets": [ { - "path": "/data/speaker-id-split-medium" + "path": "~/data/kws/hey-ff/hey-ff-data/speaker-id-split-medium" } ], "test_datasets": [ { - "path": "/data/speaker-id-split-medium" + "path": "~/data/kws/hey-ff/hey-ff-data/speaker-id-split-medium" } - ] + ], + "device": "cuda" } diff --git a/training/run/train_v2.py b/training/run/train_v2.py new file mode 100644 index 00000000..11adbed0 --- /dev/null +++ b/training/run/train_v2.py @@ -0,0 +1,10 @@ +from howl.config import TrainingConfig +from howl.trainer import Trainer +from howl.utils import test_utils + +training_config_path = test_utils.test_data_path() / "test_training_config.json" +training_cfg = TrainingConfig.parse_file(training_config_path) +trainer = Trainer(training_cfg) + +# trainer._prepare_dataset(DatasetSplit.TRAINING) +trainer.train(debug=True)