
SlimIPL - Iterative pseudo labeling #9193

Open
wants to merge 20 commits into main
Conversation

@nune-tadevosyan (Collaborator) commented May 14, 2024

What does this PR do?

This PR implements the iterative pseudo-labeling algorithm slimIPL. The algorithm is available for CTC and hybrid models, and it supports tarred, non-tarred, and Lhotse datasets. It also works with data from AIStore.

Collection: ASR (also touches NeMo core and common)

Changelog

  1. Changed hybrid and CTC models to support iterative pseudo-label generation.
  2. Changed dataset classes to cache audio files from AIStore only at the beginning of training.
  3. Added support for datasets whose audio files are in AIStore while the manifests are local (on clusters).
  4. Added the ability to change the dropout of the Conformer encoder (see the sketch after this list).
  5. Added random access for tarred Lhotse datasets.
  6. Added a new learning rate scheduler for slimIPL.
  7. Changed some functions in NeMo core to integrate the newly pseudo-labeled dataset during training.
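
As an illustration of item 4, changing dropout on an already-built encoder can be done by walking its modules. This is a hedged sketch with a hypothetical helper name, not necessarily the API this PR adds:

    # Hypothetical helper (not this PR's actual API): update every dropout
    # probability inside the Conformer encoder in place.
    import torch.nn as nn

    def set_encoder_dropout(encoder: nn.Module, p: float) -> None:
        for module in encoder.modules():
            if isinstance(module, nn.Dropout):
                module.p = p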

Usage

To enable iterative pseudo-labeling, a new "ipl" field should be added to the config file with the following parameters.

  ipl:
    m_epochs: 0                 # how many epochs to train the model before the first PL generation
    restore_pc: false           # restore PC by comparing with already existing transcriptions, if there are any
    manifest_filepath: /path/to/manifest
    tarred_audio_filepaths: /path/to/tarred/
    is_tarred: false            # whether the datasets are tarred
    dataset_weights: 1          # what fraction of the dataset to use (applicable to non-tarred datasets)
    limit_train_batches: 170    # value to set limit_train_batches to after PLs are added to the train set (Lhotse only)
    cache_manifest: /path/for/cache/manifest  # optional
    dropout: 0.1                # dropout value to switch to after the first PL generation
    n_l_epochs: 0               # how many epochs to train with the changed dropout before adding PLs to the train set
    p_cache: 0.2                # probability with which the cache will be updated
    cache_prefix:               # prefix for cache manifest files (optional for non-tarred datasets)
    batch_size: 128             # batch size used when generating PLs
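
Taken together, these parameters imply roughly the following schedule. This is a hedged pseudocode sketch of the slimIPL flow described by the comments above; all callables are placeholders supplied by the caller, not this PR's actual functions:

    import random

    def slimipl_schedule(cfg, total_epochs, train_epochs, generate_pls, set_dropout,
                         add_to_train_set, update_cache):
        train_epochs(cfg.m_epochs)                       # supervised warm-up before the first PL generation
        cache = generate_pls(batch_size=cfg.batch_size)  # first pseudo-labels fill the cache
        set_dropout(cfg.dropout)                         # dropout changes after the first PL generation
        train_epochs(cfg.n_l_epochs)                     # train with the new dropout before adding PLs
        add_to_train_set(cache)                          # pseudo-labeled data joins the train set
        for _ in range(total_epochs):                    # remaining training
            if random.random() < cfg.p_cache:            # refresh the cache with probability p_cache
                update_cache(cache)
            train_epochs(1)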

Models for which this will work:

  • EncDecHybridRNNTCTCModel
  • EncDecHybridRNNTCTCBPEModel
  • EncDecCTCModel
  • EncDecCTCModelBPE

Our training scripts can be used with a proper config file:

python NeMo/examples/asr/asr_hybrid_transducer_ctc/speech_to_text_hybrid_rnnt_ctc_bpe.py --config-path=/path/to/config/ --config-name='config.yaml'

Note
If the IPL stage has already started and training stopped for some reason, set m_epochs = -1 and n_l_epochs = 0 in the config to continue from the same place; otherwise, change them accordingly.
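
For instance, a resume configuration might override just those two fields (a hedged illustration; all other ipl fields stay as in the original run):

    ipl:
      m_epochs: -1     # IPL stage already started; skip the supervised warm-up
      n_l_epochs: 0    # no extra epochs with changed dropout before re-adding PLs
      # ...remaining ipl fields unchanged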

GitHub Actions CI

The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.

The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI, remove and add the label again.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

If you haven't finished some of the above items, you can still open a "Draft" PR.

Who can review?

Anyone in the community is free to review the PR once the checks have passed.
The Contributor guidelines contain specific people who can review PRs to various areas.

Additional Information

  • Related to # (issue)

@github-actions bot added the core (Changes to NeMo Core), ASR, and common labels May 14, 2024
@github-actions bot added the CI label May 14, 2024
@github-actions bot removed the CI label May 14, 2024
nune-tadevosyan and others added 2 commits May 14, 2024 18:21
nemo/collections/asr/models/ctc_models.py: 3 fixed code-scanning alerts
@@ -36,6 +41,7 @@
from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecoding, CTCDecodingConfig
from nemo.collections.asr.parts.utils.asr_batching import get_semi_sorted_batch_sampler
from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType
from nemo.collections.asr.parts.utils.ipl_utils import *

Check notice — Code scanning / CodeQL: 'import *' may pollute namespace. The import pollutes the enclosing namespace, as the imported module nemo.collections.asr.parts.utils.ipl_utils does not define '__all__'.
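
A conventional way to address this notice, hedged since it may not be what this PR ends up doing, is to declare '__all__' in ipl_utils or to replace the wildcard with explicit imports; the names below are hypothetical placeholders:

    # In nemo/collections/asr/parts/utils/ipl_utils.py — limit what
    # `from ... import *` exports (names are hypothetical):
    __all__ = ["process_manifest", "write_cache_manifest"]

    # Or, at the call site, import explicitly instead of using a wildcard:
    # from nemo.collections.asr.parts.utils.ipl_utils import process_manifest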
from nemo.collections.asr.losses.ctc import CTCLoss
from nemo.collections.asr.metrics.wer import WER
from nemo.collections.asr.models.rnnt_models import EncDecRNNTModel
from nemo.collections.asr.parts.mixins import ASRBPEMixin, InterCTCMixin, TranscribeConfig
from nemo.collections.asr.parts.mixins.transcription import TranscriptionReturnType
from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecoding, CTCDecodingConfig
from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType
from nemo.collections.asr.parts.utils.ipl_utils import *

Check notice — Code scanning / CodeQL: 'import *' may pollute namespace. The import pollutes the enclosing namespace, as the imported module nemo.collections.asr.parts.utils.ipl_utils does not define '__all__'.
nune-tadevosyan and others added 5 commits May 14, 2024 19:30
@pzelasko (Collaborator) left a comment:
My initial review relates to the lhotse dataloading part. Cool work!

@@ -138,6 +138,7 @@ def get_lhotse_dataloader_from_config(
global_rank: int,
world_size: int,
dataset: torch.utils.data.Dataset,
pseudo_label_gen: bool = False,
pzelasko (Collaborator):
After reading the entirety of the lhotse dataloading changes, I suggest the following:

  • let's remove the pseudo_label_gen flag, rename random_access to tarred_random_access and use tarred_random_access as the only flag to enable the new behavior
  • remove pseudo_label_gen argument from every function signature it is added, and instead read this option from config.tarred_random_access

Basically, I don't want the dataloader code to "know" about pseudo-labeling; instead, I'd like it to have an option to run dataloading in the random access mode that is required by pseudo-labeling.

I will follow up with more detailed changes in the next comments.
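
As a hedged sketch of that direction (signature simplified; the real function takes more arguments):

    # Read the flag once from the dataloader config instead of threading a
    # pseudo_label_gen argument through every signature (sketch only):
    def get_lhotse_dataloader_from_config(config, global_rank, world_size, dataset):
        tarred_random_access = config.get("tarred_random_access", False)
        ...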

nune-tadevosyan (Collaborator, Author):
Done

@@ -331,6 +332,17 @@ def get_lhotse_dataloader_from_config(
# We use lhotse's own worker_init_fn which leverages information such as rank, world_size,
# worker_id, etc. to set a different random seed for each (node, worker) combination.
# This together with infinite datasets removes the need to split data across nodes/workers.
if pseudo_label_gen:
pzelasko (Collaborator):
remove this whole change and instead change the condition above to if is_tarred and not config.tarred_random_access

nune-tadevosyan (Collaborator, Author):
Done

f"Mismatched entry between JSON manifest ('{manifest_path}') and tar file ('{tar_path}'). "
f"Conflicting audio file names are JSON='{data['audio_filepath']}' and TAR='{tar_info.name}'"
)
if self.random_access:
pzelasko (Collaborator):
First, let's wait for PR #9187 to be merged; then after resolving conflicts, I suggest we factor the sequential/random iteration into separate generator functions like this:

    def _iter_random_read(self, shard_ids: Sequence[int]) -> Generator[tuple[dict, bytes], None, None]:
        for sid in shard_ids:
            shard_manifest = self.shard_id_to_manifest[sid]
            tar_path = self.shard_id_to_tar_path[sid]
            with tarfile.open(fileobj=open_best(tar_path, mode="rb"), mode="r") as tar:
                for data in shard_manifest:
                    try:
                        tar_info = tar.getmember(data["audio_filepath"])
                        raw_audio = tar.extractfile(tar_info).read()
                        yield data, raw_audio
                    except KeyError as e:
                        manifest_path = self.paths[sid] if len(self.paths) > 1 else self.paths[0]
                        raise RuntimeError(
                            f"Mismatched entry between JSON manifest ('{manifest_path}') and tar file ('{tar_path}'). "
                            f"The following audio_filepath='{data['audio_filepath']}' was not found in the tar file."
                        ) from e

    def _iter_sequential(self, shard_ids: Sequence[int]) -> Generator[tuple[dict, bytes], None, None]:
        for sid in shard_ids:
            shard_manifest = self.shard_id_to_manifest[sid]
            tar_path = self.shard_id_to_tar_path[sid]
            with tarfile.open(fileobj=open_best(tar_path, mode="rb"), mode="r|*") as tar:
                for data, tar_info in zip(shard_manifest, tar):
                    manifest_path = self.paths[sid] if len(self.paths) > 1 else self.paths[0]
                    assert data["audio_filepath"] == tar_info.name, (
                        f"Mismatched entry between JSON manifest ('{manifest_path}') and tar file ('{tar_path}'). "
                        f"Conflicting audio file names are JSON='{data['audio_filepath']}' and TAR='{tar_info.name}'"
                    )
                    raw_audio = tar.extractfile(tar_info).read()
                    yield data, raw_audio

(we'll need to update this bit of code to have the changes from #9187, as I had written it before that PR was submitted)

nune-tadevosyan (Collaborator, Author):
Done

nune-tadevosyan and others added 4 commits May 20, 2024 17:08
 if is_tarred and not metadata_only:
-    if pseudo_label_gen:
+    if not config.tarred_random_access:
pzelasko (Collaborator):
Can we flip the if here? Generally it's better to put the more likely case in the if branch and the less likely case in the else branch (except for early returns).

nune-tadevosyan (Collaborator, Author):
Flipped the condition.
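
The guideline being applied, sketched generically (hedged; the concrete branch bodies in the PR differ):

    # Put the common case in the `if` branch and the rare case in `else`:
    if is_tarred and not metadata_only:
        if not config.tarred_random_access:
            ...  # common case: sequential tar iteration
        else:
            ...  # rare case: random access, used for pseudo-label generation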

@@ -21,6 +21,8 @@
from typing import Generator, Iterable, List, Literal

import soundfile

# import Sequence
pzelasko (Collaborator):
remove commented code

nune-tadevosyan (Collaborator, Author):
Done

)
cut.custom = _to_custom_attr_dict(data)
yield cut
if self.tarred_random_access:
pzelasko (Collaborator):
I'd suggest

iter_fn = self._iter_random_read if self.tarred_random_access else self._iter_sequential

for data, raw_audio, tar_info in iter_fn(shard_ids):  # the _iter_* generators would need to also yield tar_info
    meta = soundfile.info(BytesIO(raw_audio))
    recording = Recording(
        id=tar_info.path,
        sources=[AudioSource(type="memory", channels=list(range(meta.channels)), source=raw_audio)],
        sampling_rate=int(meta.samplerate),
        num_samples=meta.frames,
        duration=meta.duration,
    )
    cut = recording.to_cut()
    cut.supervisions.append(
        SupervisionSegment(
            id=cut.id,
            recording_id=cut.recording_id,
            start=0,
            duration=cut.duration,
            text=data.get(self.text_field),
            language=data.get(self.lang_field),
        )
    )
    cut.custom = _to_custom_attr_dict(data)
    yield cut  # yield (not return), matching the surrounding generator code

(since create_recording is unlikely to be re-used; but if you still want to leave it I'd at least rename to create_cut).

nune-tadevosyan (Collaborator, Author):
Changed to suggested code.

@@ -57,7 +57,7 @@ class LhotseDataLoadingConfig:
# b. Lhotse CutSet manifest / Lhotse Shar tar dir paths.
cuts_path: str | None = None
shar_path: Any = None # str | list[str | tuple[str, float | int]] | None = None

tarred_random_access: bool = False
pzelasko (Collaborator):
Can we add some documentation for this option? E.g. a comment here saying something like # Enable this to support dataloading from JSON manifests that reference subsets of audio tar files

nune-tadevosyan (Collaborator, Author):
Added a comment. Will also document this in the NeMo documentation.
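
The resulting config entry presumably looks something like this (a hedged reconstruction from the thread, not necessarily the merged code):

    # Enable this to support dataloading from JSON manifests that reference
    # subsets of audio tar files (required for pseudo-label generation).
    tarred_random_access: bool = False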

@nune-tadevosyan nune-tadevosyan marked this pull request as ready for review May 27, 2024 08:41
@titu1994 (Collaborator) left a comment:
I'll review the rest in some time, but the core change cannot be done. Please think of another way to implement this.

optimizer=self._optimizer,
scheduler_config=scheduler_config,
train_dataloader=self._train_dl,
ipl_config=self.cfg.get("ipl", None),
titu1994 (Collaborator):
Sorry, but this cannot be added to the core. The core does not make any assumptions about its domains. You'll need to think of a different way to implement this part.

Collaborator:
You can do it in a subclass of ASRModel, so that this is specific to that subclass.
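
A hedged sketch of that direction (class name and override are illustrative, not this PR's final design):

    # Illustrative only: keep IPL wiring out of NeMo core by handling the
    # 'ipl' config section in an ASR-specific subclass.
    from nemo.collections.asr.models import EncDecCTCModel

    class IPLEncDecCTCModel(EncDecCTCModel):
        """CTC model variant that knows about the 'ipl' config section."""

        def setup_optimization(self, optim_config=None):
            # Read the domain-specific IPL options here, inside the ASR
            # collection, instead of passing ipl_config into core's
            # prepare_lr_scheduler().
            self._ipl_config = self.cfg.get("ipl", None)
            return super().setup_optimization(optim_config=optim_config)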

Labels: ASR, common, core (Changes to NeMo Core)
3 participants