SlimIPL - Iterative pseudo labeling #9193
base: main
Conversation
Signed-off-by: nune-tadevosyan <ntadevosyan@nvidia.com>
Signed-off-by: nune-tadevosyan <152167970+nune-tadevosyan@users.noreply.github.com>
Signed-off-by: nune-tadevosyan <nune-tadevosyan@users.noreply.github.com>
@@ -36,6 +41,7 @@
     from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecoding, CTCDecodingConfig
     from nemo.collections.asr.parts.utils.asr_batching import get_semi_sorted_batch_sampler
     from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType
+    from nemo.collections.asr.parts.utils.ipl_utils import *

Check notice (Code scanning / CodeQL): `'import *' may pollute namespace` (nemo.collections.asr.parts.utils.ipl_utils)
     from nemo.collections.asr.losses.ctc import CTCLoss
     from nemo.collections.asr.metrics.wer import WER
     from nemo.collections.asr.models.rnnt_models import EncDecRNNTModel
     from nemo.collections.asr.parts.mixins import ASRBPEMixin, InterCTCMixin, TranscribeConfig
     from nemo.collections.asr.parts.mixins.transcription import TranscriptionReturnType
     from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecoding, CTCDecodingConfig
     from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType
+    from nemo.collections.asr.parts.utils.ipl_utils import *

Check notice (Code scanning / CodeQL): `'import *' may pollute namespace` (nemo.collections.asr.parts.utils.ipl_utils)
My initial review relates to the lhotse dataloading part. Cool work!
@@ -138,6 +138,7 @@ def get_lhotse_dataloader_from_config(
     global_rank: int,
     world_size: int,
     dataset: torch.utils.data.Dataset,
+    pseudo_label_gen: bool = False,
After reading the entirety of lhotse dataloading changes, I suggest the following:
- let's remove the `pseudo_label_gen` flag, rename `random_access` to `tarred_random_access`, and use `tarred_random_access` as the only flag to enable the new behavior
- remove the `pseudo_label_gen` argument from every function signature it was added to, and instead read this option from `config.tarred_random_access`

Basically, I don't want the dataloader code to "know" about pseudo-labeling; instead I'd like it to have an option to run dataloading in the random access mode that is required by pseudo-labeling.
I will follow up with more detailed changes in the next comments.
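The suggested refactor can be sketched with a tiny standalone helper. This is a hypothetical illustration, not NeMo code; the helper name `tarred_random_access_enabled` is invented for the example:

```python
# Hypothetical sketch of the suggestion above: the dataloader reads the
# random-access option from its config instead of taking a
# pseudo-labeling-specific `pseudo_label_gen` argument.
def tarred_random_access_enabled(config: dict) -> bool:
    # Defaults to False, keeping the existing sequential behavior unchanged.
    return bool(config.get("tarred_random_access", False))


print(tarred_random_access_enabled({}))                               # False
print(tarred_random_access_enabled({"tarred_random_access": True}))   # True
```

With this shape, pseudo-labeling code only needs to set `tarred_random_access: true` in the dataloader config and nothing in the dataloader mentions pseudo-labeling.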
Done
@@ -331,6 +332,17 @@ def get_lhotse_dataloader_from_config(
     # We use lhotse's own worker_init_fn which leverages information such as rank, world_size,
     # worker_id, etc. to set a different random seed for each (node, worker) combination.
     # This together with infinite datasets removes the need to split data across nodes/workers.
+    if pseudo_label_gen:
remove this whole change and instead change the condition above to `if is_tarred and not config.tarred_random_access`
Done
         f"Mismatched entry between JSON manifest ('{manifest_path}') and tar file ('{tar_path}'). "
         f"Conflicting audio file names are JSON='{data['audio_filepath']}' and TAR='{tar_info.name}'"
     )
+    if self.random_access:
First, let's wait for PR #9187 to be merged; then after resolving conflicts, I suggest we factor the sequential/random iteration into separate generator functions like this:

```python
def _iter_random_read(self, shard_ids: Sequence[int]) -> Generator[tuple[dict, bytes], None, None]:
    for sid in shard_ids:
        shard_manifest = self.shard_id_to_manifest[sid]
        tar_path = self.shard_id_to_tar_path[sid]
        with tarfile.open(fileobj=open_best(tar_path, mode="rb"), mode="r") as tar:
            for data in shard_manifest:
                try:
                    tar_info = tar.getmember(data["audio_filepath"])
                    raw_audio = tar.extractfile(tar_info).read()
                    yield data, raw_audio
                except KeyError as e:
                    manifest_path = self.paths[sid] if len(self.paths) > 1 else self.paths[0]
                    raise RuntimeError(
                        f"Mismatched entry between JSON manifest ('{manifest_path}') and tar file ('{tar_path}'). "
                        f"The following audio_filepath='{data['audio_filepath']}' was not found in the tar file."
                    ) from e


def _iter_sequential(self, shard_ids: Sequence[int]) -> Generator[tuple[dict, bytes], None, None]:
    for sid in shard_ids:
        shard_manifest = self.shard_id_to_manifest[sid]
        tar_path = self.shard_id_to_tar_path[sid]
        with tarfile.open(fileobj=open_best(tar_path, mode="rb"), mode="r|*") as tar:
            for data, tar_info in zip(shard_manifest, tar):
                manifest_path = self.paths[sid] if len(self.paths) > 1 else self.paths[0]
                assert data["audio_filepath"] == tar_info.name, (
                    f"Mismatched entry between JSON manifest ('{manifest_path}') and tar file ('{tar_path}'). "
                    f"Conflicting audio file names are JSON='{data['audio_filepath']}' and TAR='{tar_info.name}'"
                )
                raw_audio = tar.extractfile(tar_info).read()
                yield data, raw_audio
```

(we'll need to update this bit of code to have the changes from #9187, as I had written it before that PR was submitted)
Done
     if is_tarred and not metadata_only:
-        if pseudo_label_gen:
+        if not config.tarred_random_access:
can we flip the `if` here? Generally it's better to put the more likely case in `if` and the less likely case in `else` (except for early returns)
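The flipped branch order can be illustrated with a small standalone helper. The function itself is hypothetical; only the flag names come from the surrounding diff:

```python
def iteration_mode(is_tarred: bool, metadata_only: bool, tarred_random_access: bool) -> str:
    # Hypothetical helper illustrating the flipped condition: the more likely
    # case (sequential tar iteration) sits in the `if` branch, while the rarer
    # random-access case used for pseudo-labeling falls through to `else`.
    if is_tarred and not metadata_only:
        if not tarred_random_access:
            return "sequential"
        return "random_access"
    return "non_tarred"


print(iteration_mode(True, False, False))   # sequential
print(iteration_mode(True, False, True))    # random_access
```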
Flipped the condition.
@@ -21,6 +21,8 @@
 from typing import Generator, Iterable, List, Literal

+import soundfile
+
+# import Sequence
remove commented code
Done
         )
         cut.custom = _to_custom_attr_dict(data)
         yield cut
+    if self.tarred_random_access:
I'd suggest

```python
iter_fn = self._iter_random_read if self.tarred_random_access else self._iter_sequential
for data, raw_audio, tar_info in iter_fn(shard_ids):
    meta = soundfile.info(BytesIO(raw_audio))
    recording = Recording(
        id=tar_info.path,
        sources=[AudioSource(type="memory", channels=list(range(meta.channels)), source=raw_audio)],
        sampling_rate=int(meta.samplerate),
        num_samples=meta.frames,
        duration=meta.duration,
    )
    cut = recording.to_cut()
    cut.supervisions.append(
        SupervisionSegment(
            id=cut.id,
            recording_id=cut.recording_id,
            start=0,
            duration=cut.duration,
            text=data.get(self.text_field),
            language=data.get(self.lang_field),
        )
    )
    cut.custom = _to_custom_attr_dict(data)
    yield cut
```

(for this to work, the `_iter_*` generators would need to yield `(data, raw_audio, tar_info)` triples; also, since `create_recording` is unlikely to be re-used, if you still want to keep it I'd at least rename it to `create_cut`).
Changed to suggested code.
@@ -57,7 +57,7 @@ class LhotseDataLoadingConfig:
     # b. Lhotse CutSet manifest / Lhotse Shar tar dir paths.
     cuts_path: str | None = None
     shar_path: Any = None  # str | list[str | tuple[str, float | int]] | None = None
+    tarred_random_access: bool = False
Can we add some documentation for this option? E.g. a comment here saying sth like `# Enable this to support dataloading from JSON manifests that reference subsets of audio tar files`
Added comment. Will also document this in the NeMo documentation.
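As a sketch, the documented option could look like the following. The class is renamed here to avoid clashing with the real `LhotseDataLoadingConfig`; the surrounding fields are taken from the diff above:

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class LhotseDataLoadingConfigSketch:
    # Illustrative stand-in for the real config class.
    cuts_path: Optional[str] = None
    shar_path: Any = None
    # Enable this to support dataloading from JSON manifests that reference
    # subsets of audio tar files (random-access tar reads).
    tarred_random_access: bool = False


cfg = LhotseDataLoadingConfigSketch()
print(cfg.tarred_random_access)  # False
```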
I'll review the rest in some time, but the core change cannot be done. Please think of another way to implement this.
     optimizer=self._optimizer,
     scheduler_config=scheduler_config,
     train_dataloader=self._train_dl,
+    ipl_config=self.cfg.get("ipl", None),
Sorry, but this cannot be added to the core. The core does not make any assumptions about its domains. You'll need to think of a different way to implement this part.
You can do it in a subclass of ASRModel, so that this is specific to that subclass.
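A minimal sketch of that suggestion, with invented class and method names (the real ASRModel API may differ):

```python
class CoreModelSketch:
    # Stand-in for the core model class: optimization setup stays
    # domain-agnostic and never sees pseudo-labeling options.
    def setup_optimization(self, cfg: dict) -> None:
        self.cfg = cfg


class IPLCapableASRModelSketch(CoreModelSketch):
    # ASR-specific subclass: the optional "ipl" config section is read here,
    # keeping the core free of domain assumptions.
    def setup_optimization(self, cfg: dict) -> None:
        super().setup_optimization(cfg)
        self.ipl_config = cfg.get("ipl", None)


model = IPLCapableASRModelSketch()
model.setup_optimization({"ipl": {"m_epochs": -1, "n_l_updates": 0}})
print(model.ipl_config)  # {'m_epochs': -1, 'n_l_updates': 0}
```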
What does this PR do ?
Implementation of the iterative pseudo-labeling algorithm SlimIPL. The algorithm is available for CTC and hybrid models. It supports tarred, non-tarred, and lhotse datasets, and also works with data from AIStore.
Collection: [Note which collection this PR will affect]
Changelog
Usage
To run iterative pseudo-labeling, a new field `ipl` should be added to the config file with the following parameters.
Models for which this will work.
For usage, our standard training scripts can be used with a proper config file.
Note
If the IPL stage has started and training stopped for some reason, then to continue from the same place, set `m_epochs = -1` and `n_l_updates = 0` in the config; otherwise adjust them accordingly.
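For illustration only, a resume-time `ipl` section might look like this; only `m_epochs` and `n_l_updates` are named in this PR, and the surrounding structure is an assumption:

```yaml
ipl:
  m_epochs: -1     # IPL stage had already started before the interruption
  n_l_updates: 0   # no further labeled-only updates before pseudo-labeling resumes
```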
GitHub Actions CI
The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.
The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI remove and add the label again.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".
Before your PR is "Ready for review"
Pre checks:
PR Type:
If you haven't finished some of the above items you can still open a "Draft" PR.
Who can review?
Anyone in the community is free to review the PR once the checks have passed.
The Contributor guidelines contain specific people who can review PRs to various areas.
Additional Information