Thank you very much for your amazing work!
While using melody-conditioned generation, I ran into an error in the `DataCollatorMusicGenWithPadding` class: `input_values` here is actually a dictionary. I resolved the error by changing the code to `batch.update(input_values)`. Could you kindly confirm whether this approach is correct?
```python
from dataclasses import dataclass
from typing import Any, Dict, List, Union

import torch


@dataclass
class DataCollatorMusicGenWithPadding:
    # Fields inferred from how __call__ uses them
    processor: Any
    feature_extractor_input_name: str

    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        labels = [
            torch.tensor(feature["labels"]).transpose(0, 1) for feature in features
        ]
        # (bsz, seq_len, num_codebooks)
        labels = torch.nn.utils.rnn.pad_sequence(
            labels, batch_first=True, padding_value=-100
        )

        input_ids = [{"input_ids": feature["input_ids"]} for feature in features]
        input_ids = self.processor.tokenizer.pad(input_ids, return_tensors="pt")

        batch = {"labels": labels, **input_ids}

        if self.feature_extractor_input_name in features[0]:
            input_values = [
                {
                    self.feature_extractor_input_name: feature[
                        self.feature_extractor_input_name
                    ]
                }
                for feature in features
            ]
            # pad() returns a dict-like object, not a tensor
            input_values = self.processor.feature_extractor.pad(
                input_values, return_tensors="pt"
            )
            # Merge every padded entry into the batch
            batch.update(input_values)
            # Original line: indexes `batch` with a slice instead of assigning to it
            # batch[self.feature_extractor_input_name : input_values]

        return batch
```
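One thing that may be worth checking with this fix: the padded result is a mapping that can hold more keys than `self.feature_extractor_input_name` (for instance an `attention_mask`, depending on the feature extractor's `return_attention_mask` setting), and `update` copies all of them into `batch`, where an `attention_mask` key would replace the one produced by the tokenizer. The framework-free sketch below illustrates this with plain dicts and lists standing in for tensors and for the feature extractor's output; the helper `merge_padded_features` and all sample data are hypothetical, and whether the extractor actually returns a mask depends on its configuration.

```python
from typing import Any, Dict


def merge_padded_features(
    batch: Dict[str, Any], padded: Dict[str, Any], input_name: str
) -> Dict[str, Any]:
    """Return a copy of `batch` with only `padded[input_name]` added.

    Unlike `batch.update(padded)`, other keys in `padded` (for example an
    `attention_mask`) cannot overwrite entries already in `batch`.
    """
    merged = dict(batch)
    merged[input_name] = padded[input_name]
    return merged


# Hypothetical stand-ins for the collator's intermediate values.
batch = {
    "labels": [[1, 2], [3, -100]],
    "input_ids": [[5, 6, 7], [8, 9, 0]],
    "attention_mask": [[1, 1, 1], [1, 1, 0]],  # from the tokenizer
}
padded = {
    "input_features": [[0.1, 0.2], [0.3, 0.0]],
    "attention_mask": [[1, 1], [1, 0]],  # only if the extractor returns one
}

# The original line uses slice syntax on a dict, which raises TypeError.
try:
    batch["input_features" : padded]
except TypeError as exc:
    print("original line fails:", exc)

updated = dict(batch)
updated.update(padded)  # the change proposed in this issue
print(updated["attention_mask"])  # [[1, 1], [1, 0]]: tokenizer mask replaced

safe = merge_padded_features(batch, padded, "input_features")
print(safe["attention_mask"])  # [[1, 1, 1], [1, 1, 0]]: tokenizer mask kept
```

If the feature extractor only returns the input features, `batch.update(input_values)` and the single-key assignment behave the same; the single-key form simply rules out the overwrite.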