72 changes: 58 additions & 14 deletions nemo/collections/speechlm2/models/salm_asr_decoder.py
@@ -493,21 +493,65 @@ def generate(
        # due to accuracy issues at bs>1
        # audio_embeds, audio_embed_lens = self.perception(audios, audio_lens)
        # audio_embeds = [audio_embeds[i, :elen] for i, elen in enumerate(audio_embed_lens)]
        # audios is (B, T)
        # audio_lens is (B,)

        # chunk audio into 10-minute segments
        chunk_size = 10 * 60 * self.sampling_rate
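        # e.g., with a 16 kHz sampling rate (illustrative only; the actual value
        # comes from self.sampling_rate) this is 10 * 60 * 16000 = 9,600,000 samples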

        # Process each batch item separately to handle variable-length audios correctly
        audio_embeds = []
        for batch_idx in range(audios.shape[0]):
            single_audio = audios[batch_idx : batch_idx + 1]  # Keep batch dimension
            single_audio_len = audio_lens[batch_idx : batch_idx + 1]

            num_chunks = (single_audio_len.item() + chunk_size - 1) // chunk_size

            # Accumulate chunks for this batch item
            audio_chunks_embeds = []
            transcript_chunks = []

            for i in range(num_chunks):
                chunk_start = i * chunk_size
                chunk_end = min(chunk_start + chunk_size, single_audio.shape[1])

                # Skip if this chunk is beyond the audio length
                if chunk_start >= single_audio_len.item():
                    break

                chunk_audio = single_audio[:, chunk_start:chunk_end]
                chunk_audio_len = torch.clamp(single_audio_len - chunk_start, min=0, max=chunk_end - chunk_start)

                # Process this chunk
                encoded_chunk, encoded_len_chunk = self.perception.forward_encoder(
                    input_signal=chunk_audio, input_signal_length=chunk_audio_len
                )
                asr_hyps_chunk = self.perception.transcribe_encoded(
                    encoded=encoded_chunk, encoded_len=encoded_len_chunk
                )
                asr_tokens_chunk = [
                    torch.as_tensor(self.tokenizer.text_to_ids(f">> {hyp.text} <<" if hyp.text else ">> <<"))
                    for hyp in asr_hyps_chunk
                ]
Comment on lines +532 to +535 (Collaborator):
">>" and "<<" are applied at the very beginning and end of the ASR transcript.
For short audio, the final "audio" embedding looks like:

{audio_features}>> {asr_hyp} <<

I guess we can do the same for long audio:

{audio_features1}{audio_features2}>> {asr_hyp1} {asr_hyp2} <<

But here it seems ">>" and "<<" are added for each chunk instead.
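
A minimal sketch of that suggestion, reusing the variables from the diff above; chunk_texts is a hypothetical accumulator, and since each batch item is processed with batch size 1, asr_hyps_chunk[0] is assumed to be the only hypothesis per chunk:

    # Inside the per-chunk loop: collect the raw hypothesis text instead of
    # wrapping every chunk in ">> ... <<".
    chunk_texts.append(asr_hyps_chunk[0].text or "")

    # After the loop: wrap the concatenated transcript exactly once, yielding
    # {audio_features1}{audio_features2}>> {asr_hyp1} {asr_hyp2} <<
    joined = " ".join(t for t in chunk_texts if t)
    asr_tokens = torch.as_tensor(self.tokenizer.text_to_ids(f">> {joined} <<" if joined else ">> <<"))
    transcript_emb = self.embed_tokens(asr_tokens.unsqueeze(0).to(self.device)).squeeze(0)
    combined_emb = torch.cat([torch.cat(audio_chunks_embeds, dim=0), transcript_emb], dim=0)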

                asr_tokens_len_chunk = [at.shape[0] for at in asr_tokens_chunk]
                asr_tokens_chunk = torch.cat(asr_tokens_chunk, dim=0).unsqueeze(0).to(self.device)
                transcript_embs_chunk = torch.split(
                    self.embed_tokens(asr_tokens_chunk).squeeze(0), asr_tokens_len_chunk, dim=0
                )
                audio_embeds_chunk, audio_embed_lens_chunk = self.perception(
                    encoded=encoded_chunk, encoded_len=encoded_len_chunk
                )

                # Accumulate audio embeddings (without transcript yet)
                audio_chunks_embeds.append(audio_embeds_chunk[0][: audio_embed_lens_chunk[0]])
                transcript_chunks.extend(transcript_embs_chunk)

            # Concatenate all audio chunks, then append all transcripts at the end
            full_audio_emb = torch.cat(audio_chunks_embeds, dim=0)
            full_transcript_emb = torch.cat(transcript_chunks, dim=0)
            combined_emb = torch.cat([full_audio_emb, full_transcript_emb], dim=0)
            audio_embeds.append(combined_emb)

Removed lines (the previous single-pass implementation that the chunked loop above replaces):

        encoded, encoded_len = self.perception.forward_encoder(input_signal=audios, input_signal_length=audio_lens)
        asr_hyps = self.perception.transcribe_encoded(encoded=encoded, encoded_len=encoded_len)
        asr_tokens = [
            torch.as_tensor(self.tokenizer.text_to_ids(f">> {hyp.text} <<" if hyp.text else ">> <<"))
            for hyp in asr_hyps
        ]
        asr_tokens_len = [at.shape[0] for at in asr_tokens]
        asr_tokens = torch.cat(asr_tokens, dim=0).unsqueeze(0).to(self.device)
        transcript_embs = torch.split(self.embed_tokens(asr_tokens).squeeze(0), asr_tokens_len, dim=0)
        audio_embeds, audio_embed_lens = self.perception(encoded=encoded, encoded_len=encoded_len)
        audio_embeds = [
            torch.cat([aemb[:aemblen], temb], dim=0)
            for aemb, aemblen, temb in zip(audio_embeds, audio_embed_lens, transcript_embs)
        ]

        # Insert audio embeddings into relevant positions in text embeddings.
        input_embeds, _, attention_mask = replace_placeholders_and_build_targets(
            input_ids=tokens,
5 changes: 5 additions & 0 deletions nemo/collections/speechlm2/parts/hf_hub.py
@@ -69,6 +69,11 @@ def _from_pretrained(
        # this setting skips loading the original pretrained ASR and LLM weights, and loads the
        # final trained model weights directly.
        model_kwargs['cfg']['pretrained_weights'] = False
        # When loading from HF checkpoint, set init_from_path to the HF checkpoint directory
        # so it can load the model weights from the correct location.
        if 'init_from_path' in model_kwargs['cfg']:
Comment on the line above (Collaborator):
init_from_path is used to load a checkpoint from a previous stage.
In Stage 1, this argument is not set.
In Stage 2, it points to the final HF checkpoint from Stage 1.
When loading the final checkpoint after Stage 2, perhaps we do not need init_from_path, assuming the HF checkpoint saves everything.
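
To illustrate the staging described in this comment, a hypothetical sketch of the cfg values involved; the keys come from this diff, but the paths and the Stage-2 value of pretrained_weights are assumptions:

    # Hypothetical per-stage configuration, following the comment above:
    stage1_cfg = {'pretrained_weights': True}  # Stage 1: init_from_path not set
    stage2_cfg = {'pretrained_weights': True, 'init_from_path': '/ckpts/stage1_hf'}  # Stage 2: Stage 1 output
    # When restoring the final checkpoint, the hook in this diff overrides both:
    # pretrained_weights becomes False and init_from_path becomes model_id
    # (the downloaded HF checkpoint directory).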

            model_kwargs['cfg']['init_from_path'] = model_id

        return super()._from_pretrained(
            model_id=model_id,
            revision=revision,
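
For context, a minimal usage sketch; the import path and repo id are placeholders, and from_pretrained is assumed to be the public huggingface_hub mixin entry point that invokes this _from_pretrained hook:

    from nemo.collections.speechlm2.models import SALM  # hypothetical import path

    # Restoring a final HF checkpoint: the hook above forces
    # pretrained_weights=False and redirects init_from_path to the downloaded
    # checkpoint directory, so weights load from the checkpoint itself rather
    # than from a previous training stage.
    model = SALM.from_pretrained("org/salm-final-checkpoint")  # placeholder repo id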