[feat] Change TF spectral ops to torchaudio #7

Open · wants to merge 3 commits into base: master
@@ -9,29 +9,35 @@ path:
event_length: 1024
mel_length: 256
num_rows_per_batch: 12
+split_frame_length: 256
+dataset_is_deterministic: False
+dataset_is_randomize_tokens: True
+dataset_use_tf_spectral_ops: False

optim:
  lr: 2e-4
  warmup_steps: 64500
  num_epochs: ${num_epochs}
  num_steps_per_epoch: 1289 # TODO: this is not good practice; ideally we would get this from the dataloader.
  min_lr: 1e-4

grad_accum: 1
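The TODO above is straightforward to resolve at runtime. A minimal sketch, assuming a map-style PyTorch dataloader; `train_loader` and the helper name are illustrative, not part of this PR:

    def steps_per_epoch(train_loader, grad_accum: int = 1) -> int:
        # len() of a map-style DataLoader is the number of batches per epoch;
        # divide by the gradient-accumulation factor to get optimizer steps.
        return max(1, len(train_loader) // grad_accum)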

dataloader:
  train:
    batch_size: 1
-    num_workers: 12
+    num_workers: 2
  val:
    batch_size: 1
-    num_workers: 12
+    num_workers: 0

modelcheckpoint:
  monitor: 'val_loss'
  mode: 'min'
  save_last: True
  save_top_k: 5
  save_weights_only: False
  every_n_epochs: 50
  filename: '{epoch}-{step}-{val_loss:.4f}'
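For reference, this block maps one-to-one onto PyTorch Lightning's checkpoint callback (Lightning is assumed here from the `ddp_find_unused_parameters_false` strategy string below); a sketch:

    from pytorch_lightning.callbacks import ModelCheckpoint

    checkpoint_cb = ModelCheckpoint(
        monitor="val_loss",       # quantity logged during validation
        mode="min",               # lower val_loss is better
        save_last=True,           # always keep last.ckpt
        save_top_k=5,             # keep the 5 best checkpoints
        save_weights_only=False,  # include optimizer state
        every_n_epochs=50,
        filename="{epoch}-{step}-{val_loss:.4f}",
    )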

trainer:
@@ -43,6 +49,20 @@ trainer:
  log_every_n_steps: 100
  strategy: "ddp_find_unused_parameters_false"
  devices: ${devices}
+  check_val_every_n_epoch: 10

+eval:
+  is_sanity_check: False
+  eval_first_n_examples:
+  eval_after_num_epoch: 400
+  eval_per_epoch: 1
+  eval_dataset:
+    exp_tag_name:
+    audio_dir:
+    midi_dir:
+  contiguous_inference:
+  batch_size: 8
+  use_tf_spectral_ops: False # change this to True if using pretrained/mt3.pth
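The `use_tf_spectral_ops` flags above feed the `SpectrogramConfig` field added in contrib/spectrograms.py further down. A sketch of the assumed wiring (the function name is illustrative, not from this PR):

    from contrib.spectrograms import SpectrogramConfig

    def spectrogram_config_from_cfg(cfg):
        # Must be True when evaluating checkpoints trained with the original
        # TF pipeline (e.g. pretrained/mt3.pth), False otherwise.
        return SpectrogramConfig(use_tf_spectral_ops=cfg.eval.use_tf_spectral_ops)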

defaults:
  - model: MT3Net
8 changes: 8 additions & 0 deletions config/dataset/Slakh.yaml
@@ -7,6 +7,10 @@ train:
  inst_filename: inst_names.json
  audio_filename: mix_16k.wav
  num_rows_per_batch: ${num_rows_per_batch}
+  split_frame_length: ${split_frame_length}
+  is_deterministic: ${dataset_is_deterministic}
+  is_randomize_tokens: ${dataset_is_randomize_tokens}
+  use_tf_spectral_ops: ${dataset_use_tf_spectral_ops}
val:
  _target_: dataset.dataset_2_random.SlakhDataset # choosing which data class to use
  root_dir: "/data2/kinwai/slakh2100_flac_redux/validation/"
@@ -16,6 +20,10 @@ val:
  inst_filename: inst_names.json
  audio_filename: mix_16k.wav
  num_rows_per_batch: ${num_rows_per_batch}
+  split_frame_length: ${split_frame_length}
+  is_deterministic: ${dataset_is_deterministic}
+  is_randomize_tokens: ${dataset_is_randomize_tokens}
+  use_tf_spectral_ops: ${dataset_use_tf_spectral_ops}
test:
  root_dir: "/data/slakh2100_flac_redux/test"
  collate_fn: dataset.dataset_2_random.collate_fn
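The `${...}` entries added above are Hydra/OmegaConf interpolations that resolve against the root config (Hydra is assumed from the config layout); a self-contained check:

    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        """
        split_frame_length: 256
        dataset:
          train:
            split_frame_length: ${split_frame_length}
        """
    )
    # Absolute interpolation resolves from the config root.
    print(cfg.dataset.train.split_frame_length)  # -> 256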
57 changes: 1 addition & 56 deletions contrib/metrics_utils.py
@@ -21,9 +21,7 @@

from contrib import event_codec, note_sequences, run_length_encoding

-import note_seq
import numpy as np
-import pretty_midi

S = TypeVar('S')
T = TypeVar('T')
@@ -143,57 +141,4 @@ def event_predictions_to_ns(
      'est_ns': ns,
      'est_invalid_events': total_invalid_events,
      'est_dropped_events': total_dropped_events,
  }


-def get_prettymidi_pianoroll(ns: note_seq.NoteSequence, fps: float,
-                             is_drum: bool):
-  """Convert NoteSequence to pianoroll through pretty_midi."""
-  for note in ns.notes:
-    if is_drum or note.end_time - note.start_time < 0.05:
-      # Give all drum notes a fixed length, and all others a min length
-      note.end_time = note.start_time + 0.05
-
-  pm = note_seq.note_sequence_to_pretty_midi(ns)
-  end_time = pm.get_end_time()
-  cc = [
-      # all sound off
-      pretty_midi.ControlChange(number=120, value=0, time=end_time),
-      # all notes off
-      pretty_midi.ControlChange(number=123, value=0, time=end_time)
-  ]
-  pm.instruments[0].control_changes = cc
-  if is_drum:
-    # If inst.is_drum is set, pretty_midi will return an all-zero pianoroll.
-    for inst in pm.instruments:
-      inst.is_drum = False
-  pianoroll = pm.get_piano_roll(fs=fps)
-  return pianoroll
-
-
-def frame_metrics(ref_pianoroll: np.ndarray,
-                  est_pianoroll: np.ndarray,
-                  velocity_threshold: int) -> Tuple[float, float, float]:
-  """Frame Precision, Recall, and F1."""
-  import sklearn
-  # Pad to same length
-  if ref_pianoroll.shape[1] > est_pianoroll.shape[1]:
-    diff = ref_pianoroll.shape[1] - est_pianoroll.shape[1]
-    est_pianoroll = np.pad(
-        est_pianoroll, [(0, 0), (0, diff)], mode='constant')
-  elif est_pianoroll.shape[1] > ref_pianoroll.shape[1]:
-    diff = est_pianoroll.shape[1] - ref_pianoroll.shape[1]
-    ref_pianoroll = np.pad(
-        ref_pianoroll, [(0, 0), (0, diff)], mode='constant')
-
-  # For ref, remove any notes that are too quiet (consistent with Cerberus).
-  ref_frames_bool = ref_pianoroll > velocity_threshold
-  # For est, keep all predicted notes.
-  est_frames_bool = est_pianoroll > 0
-
-  precision, recall, f1, _ = sklearn.metrics.precision_recall_fscore_support(
-      ref_frames_bool.flatten(),
-      est_frames_bool.flatten(),
-      labels=[True, False])
-
-  return precision[0], recall[0], f1[0]
78 changes: 59 additions & 19 deletions contrib/spectrograms.py
@@ -15,9 +15,14 @@
"""Audio spectrogram functions."""

import dataclasses
+import torch
+from torchaudio.transforms import MelSpectrogram
+import librosa
import numpy as np

-from ddsp import spectral_ops
-import tensorflow as tf
+# this is to suppress a warning from torch melspectrogram
+import warnings
+warnings.filterwarnings("ignore")

# defaults for spectrogram config
DEFAULT_SAMPLE_RATE = 16000
@@ -35,6 +40,7 @@ class SpectrogramConfig:
  sample_rate: int = DEFAULT_SAMPLE_RATE
  hop_width: int = DEFAULT_HOP_WIDTH
  num_mel_bins: int = DEFAULT_NUM_MEL_BINS
+  use_tf_spectral_ops: bool = False

  @property
  def abbrev_str(self):
@@ -53,29 +59,63 @@ def frames_per_second(self):


def split_audio(samples, spectrogram_config):
-  """Split audio into frames."""
-  return tf.signal.frame(
-      samples,
-      frame_length=spectrogram_config.hop_width,
-      frame_step=spectrogram_config.hop_width,
-      pad_end=True)
+  """Split audio into frames using librosa."""
+  if samples.shape[0] % spectrogram_config.hop_width != 0:
+    samples = np.pad(
+        samples,
+        (0, spectrogram_config.hop_width - samples.shape[0] % spectrogram_config.hop_width),
+        'constant',
+        constant_values=0)
+  return librosa.util.frame(
+      samples,
+      frame_length=spectrogram_config.hop_width,
+      hop_length=spectrogram_config.hop_width,
+      axis=-1).T


-def compute_spectrogram(samples, spectrogram_config):
-  """Compute a mel spectrogram."""
-  overlap = 1 - (spectrogram_config.hop_width / FFT_SIZE)
-  return spectral_ops.compute_logmel(
-      samples,
-      bins=spectrogram_config.num_mel_bins,
-      lo_hz=MEL_LO_HZ,
-      overlap=overlap,
-      fft_size=FFT_SIZE,
-      sample_rate=spectrogram_config.sample_rate)


+def compute_spectrogram(
+    samples,
+    spectrogram_config,
+):
+  """Compute a mel spectrogram.
+
+  Due to multiprocessing issues when running TF and PyTorch together, we use
+  librosa/torchaudio here and keep `spectral_ops.compute_logmel` only for
+  evaluation purposes.
+  """
+  if spectrogram_config.use_tf_spectral_ops:
+    # NOTE: we only keep this path for evaluating existing models. Even with
+    # an equivalent PyTorch/librosa implementation that gives close-enough
+    # results (mel spectrogram MAE ~ 2e-3), I find the model output is still
+    # badly affected.
+    # Lazy import so ddsp/TF is only loaded when this path is used.
+    from ddsp import spectral_ops
+    overlap = 1 - (spectrogram_config.hop_width / FFT_SIZE)
+    return spectral_ops.compute_logmel(
+        samples,
+        bins=spectrogram_config.num_mel_bins,
+        lo_hz=MEL_LO_HZ,
+        overlap=overlap,
+        fft_size=FFT_SIZE,
+        sample_rate=spectrogram_config.sample_rate)
+  else:
+    transform = MelSpectrogram(
+        sample_rate=spectrogram_config.sample_rate,
+        n_fft=FFT_SIZE,
+        hop_length=spectrogram_config.hop_width,
+        n_mels=spectrogram_config.num_mel_bins,
+        f_min=MEL_LO_HZ,
+        power=1.0,
+    )
+    samples = torch.from_numpy(samples).float()
+    S = transform(samples)
+    S[S < 0] = 0  # guard against negative values before taking the log
+    S = torch.log(S + 1e-6)
+    return S.numpy().T
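The ~2e-3 MAE quoted in the comment is the PR author's observation. A sketch of how one might reproduce that comparison, assuming both ddsp/TF and torchaudio are installed (not verified here):

    import numpy as np
    from contrib.spectrograms import SpectrogramConfig, compute_spectrogram

    samples = np.random.randn(16000).astype(np.float32)  # 1 s at 16 kHz
    ref = np.asarray(compute_spectrogram(
        samples, SpectrogramConfig(use_tf_spectral_ops=True)))
    est = compute_spectrogram(
        samples, SpectrogramConfig(use_tf_spectral_ops=False))
    n = min(ref.shape[0], est.shape[0])  # frame counts can differ by one
    print(np.abs(ref[:n] - est[:n]).mean())  # mean absolute error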


def flatten_frames(frames):
  """Convert frames back into a flat array of samples."""
-  return tf.reshape(frames, [-1])
+  return np.reshape(frames, (-1,))
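A quick round-trip check of the librosa framing above and its inverse; a hop width of 128 samples is assumed here (this module's DEFAULT_HOP_WIDTH is not visible in the diff):

    import numpy as np
    import librosa

    hop_width = 128
    samples = np.random.randn(1000).astype(np.float32)
    # Pad so the length is a multiple of hop_width, as split_audio does.
    padded = np.pad(samples, (0, (-len(samples)) % hop_width))
    frames = librosa.util.frame(
        padded, frame_length=hop_width, hop_length=hop_width, axis=-1).T
    print(frames.shape)  # (8, 128): non-overlapping hop-width frames
    flat = np.reshape(frames, (-1,))
    assert np.allclose(flat, padded)  # flatten_frames inverts split_audio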


def input_depth(spectrogram_config):
85 changes: 43 additions & 42 deletions contrib/vocabularies.py
@@ -23,7 +23,6 @@
import note_seq
import seqio
import t5.data
-import tensorflow as tf


DECODED_EOS_ID = -1
@@ -220,7 +219,7 @@ def _decode_id(encoded_id):
    ids = [_decode_id(int(i)) for i in ids]
    return ids

-  def _encode_tf(self, token_ids: tf.Tensor) -> tf.Tensor:
+  def _encode_tf(self, token_ids):
"""Encode a list of tokens to a tf.Tensor.

Args:
Expand All @@ -229,46 +228,48 @@ def _encode_tf(self, token_ids: tf.Tensor) -> tf.Tensor:
Returns:
a 1d tf.Tensor with dtype tf.int32
"""
-    with tf.control_dependencies(
-        [tf.debugging.assert_less(
-            token_ids, tf.cast(self._num_regular_tokens, token_ids.dtype)),
-         tf.debugging.assert_greater_equal(
-             token_ids, tf.cast(0, token_ids.dtype))
-        ]):
-      tf_ids = token_ids + self._num_special_tokens
-      return tf_ids
+    return None
+    # with tf.control_dependencies(
+    #     [tf.debugging.assert_less(
+    #         token_ids, tf.cast(self._num_regular_tokens, token_ids.dtype)),
+    #      tf.debugging.assert_greater_equal(
+    #          token_ids, tf.cast(0, token_ids.dtype))
+    #     ]):
+    #   tf_ids = token_ids + self._num_special_tokens
+    #   return tf_ids

-  def _decode_tf(self, ids: tf.Tensor) -> tf.Tensor:
-    """Decode in TensorFlow.
-
-    The special tokens of PAD and UNK as well as extra_ids will be
-    replaced with DECODED_INVALID_ID in the output. If EOS is present, it and
-    all following tokens in the decoded output will be represented by
-    DECODED_EOS_ID.
-
-    Args:
-      ids: a 1d tf.Tensor with dtype tf.int32
-
-    Returns:
-      a 1d tf.Tensor with dtype tf.int32
-    """
-    # Create a mask that is true from the first EOS position onward.
-    # First, create an array that is True whenever there is an EOS, then cumsum
-    # that array so that every position after and including the first True is
-    # >1, then cast back to bool for the final mask.
-    eos_and_after = tf.cumsum(
-        tf.cast(tf.equal(ids, self.eos_id), tf.int32), exclusive=False, axis=-1)
-    eos_and_after = tf.cast(eos_and_after, tf.bool)
-
-    return tf.where(
-        eos_and_after,
-        DECODED_EOS_ID,
-        tf.where(
-            tf.logical_and(
-                tf.greater_equal(ids, self._num_special_tokens),
-                tf.less(ids, self._base_vocab_size)),
-            ids - self._num_special_tokens,
-            DECODED_INVALID_ID))
+  def _decode_tf(self, ids):
+    return None
+    # """Decode in TensorFlow.
+    #
+    # The special tokens of PAD and UNK as well as extra_ids will be
+    # replaced with DECODED_INVALID_ID in the output. If EOS is present, it and
+    # all following tokens in the decoded output will be represented by
+    # DECODED_EOS_ID.
+    #
+    # Args:
+    #   ids: a 1d tf.Tensor with dtype tf.int32
+    #
+    # Returns:
+    #   a 1d tf.Tensor with dtype tf.int32
+    # """
+    # # Create a mask that is true from the first EOS position onward.
+    # # First, create an array that is True whenever there is an EOS, then cumsum
+    # # that array so that every position after and including the first True is
+    # # >1, then cast back to bool for the final mask.
+    # eos_and_after = tf.cumsum(
+    #     tf.cast(tf.equal(ids, self.eos_id), tf.int32), exclusive=False, axis=-1)
+    # eos_and_after = tf.cast(eos_and_after, tf.bool)
+    #
+    # return tf.where(
+    #     eos_and_after,
+    #     DECODED_EOS_ID,
+    #     tf.where(
+    #         tf.logical_and(
+    #             tf.greater_equal(ids, self._num_special_tokens),
+    #             tf.less(ids, self._base_vocab_size)),
+    #         ids - self._num_special_tokens,
+    #         DECODED_INVALID_ID))
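Since `_decode_tf` now returns None, decoding presumably happens outside TF. For reference, a numpy sketch equivalent to the commented-out logic above (an illustration, not part of this PR; DECODED_EOS_ID = -1 is defined in this module, DECODED_INVALID_ID = -2 is assumed):

    import numpy as np

    DECODED_EOS_ID = -1
    DECODED_INVALID_ID = -2  # assumed value, mirroring the module constant

    def decode_ids_np(ids, eos_id, num_special_tokens, base_vocab_size):
        ids = np.asarray(ids)
        # True at the first EOS and every position after it.
        eos_and_after = np.cumsum(ids == eos_id) > 0
        # In-vocabulary ids are shifted down; everything else is invalid.
        valid = (ids >= num_special_tokens) & (ids < base_vocab_size)
        decoded = np.where(valid, ids - num_special_tokens, DECODED_INVALID_ID)
        return np.where(eos_and_after, DECODED_EOS_ID, decoded)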

  def num_special_tokens(self):
    return self._num_special_tokens