Test failures with new configs in test_grad_scaling_autocast in test_torch.py #126638

Open

gambiTarun (Contributor) opened this issue on May 19, 2024
Labels: module: optimizer (Related to torch.optim), triaged

This issue tracks my observations when updating test_grad_scaling_autocast in test_torch.py with the new OptimizerInfo infrastructure (#123451).

While I was able to combine the tests that call _grad_scaling_autocast_test into one test (#125538), I observed test failures when I tried to use _get_optim_inputs_including_global_cliquey_kwargs to avoid hardcoding the configs.
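
For reference, the generated configs can be inspected outside the test harness. A minimal sketch, assuming the helper and optim_db are importable from torch.testing._internal.common_optimizers (as they are in test_optim.py) and run on a CUDA machine so that the fused/foreach variants show up:

import torch
# Assumption: both names live in common_optimizers, as used by test_optim.py.
from torch.testing._internal.common_optimizers import (
    optim_db,
    _get_optim_inputs_including_global_cliquey_kwargs,
)

device, dtype = "cuda", torch.float32  # assumption: CUDA, matching dtypes=[torch.float32] below
adam_info = next(o for o in optim_db if o.optim_cls is torch.optim.Adam)
# Enumerate the same inputs the test iterates over, skipping "differentiable"
# just as the test does, and print each config's kwargs.
for optim_input in _get_optim_inputs_including_global_cliquey_kwargs(
    device, dtype, adam_info, skip=("differentiable",)
):
    print(optim_input.kwargs)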

The following is the test case:

@onlyNativeDeviceTypes
@optims(
    [optim for optim in optim_db if optim.optim_cls in [torch.optim.AdamW, torch.optim.Adam, torch.optim.SGD]],
    dtypes=[torch.float32]
)
def test_grad_scaling_autocast(self, device, dtype, optim_info):
    try_pickle = False

    def run(device, data, model, optimizer, scaler, loss_fn, skip_iter, try_scaling_api):
        for i, (input, target) in enumerate(data):
            optimizer.zero_grad()
            with torch.autocast(device_type=device, dtype=torch.half, enabled=try_scaling_api):
                output = model(input)
                loss = loss_fn(output, target)
            if try_scaling_api:
                scaler.scale(loss).backward()
                if i == skip_iter and scaler.is_enabled():
                    with torch.no_grad():
                        model[1].weight.grad.fill_(float('inf'))
                scaler.step(optimizer)
                scaler.update()
                if try_pickle:
                    scaler = pickle.loads(pickle.dumps(scaler))
            else:
                loss.backward()
                if (not scaler.is_enabled()) or (i != skip_iter):
                    optimizer.step()
        return scaler

    optimizer_ctor = optim_info.optim_cls
    all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
        device, dtype, optim_info, skip=("differentiable",))
    # Compares no scaling + no autocasting against scaling + autocasting.
    for optim_input in all_optim_inputs:
        
        # NOTE(mkozuki): With current way of testing, `torch.optim.Adam` is failing in spite of `foreach` and `fused`.
        #   Giving some flexibility to this test might help.
        context = contextlib.nullcontext
        if optimizer_ctor in (torch.optim.Adam, torch.optim.AdamW):
            from functools import partial
            context = partial(self.assertRaises, AssertionError)
        with context():
            # sets atol=1e-3 because we're comparing pure fp32 arithmetic vs a mixture of fp16 and fp32
            self._run_scaling_case(
                device, run, unskipped=3, skipped=1, atol=1e-3,
                optimizer_ctor=optimizer_ctor, optimizer_kwargs=optim_input.kwargs,
            )
            # this will be picked up by try_pickle within run():
            try_pickle = True
            self._run_scaling_case(
                device, run, unskipped=3, skipped=1, atol=1e-3,
                optimizer_ctor=optimizer_ctor, optimizer_kwargs=optim_input.kwargs,
            )
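
As a side note on the NOTE(mkozuki) comment above: one way to give the test the suggested flexibility might be to choose the expected-failure context per (optimizer, config) pair rather than per optimizer class, since contextlib.nullcontext and partial(self.assertRaises, AssertionError) are interchangeable zero-argument context-manager factories. The sketch below only illustrates that pattern; ContextSelectionDemo and KNOWN_MISMATCHES are hypothetical names and the listed configs are placeholders, not measured results.

import contextlib
import unittest
from functools import partial

class ContextSelectionDemo(unittest.TestCase):
    # Hypothetical placeholder: (optimizer name, config description) pairs
    # known to produce "Tensor-likes are not close!" mismatches.
    KNOWN_MISMATCHES = {("Adam", "capturable"), ("AdamW", "capturable")}

    def _context_for(self, optim_name, config_desc):
        # Expect an AssertionError only for configs on the known-mismatch list;
        # everything else runs under a no-op context.
        if (optim_name, config_desc) in self.KNOWN_MISMATCHES:
            return partial(self.assertRaises, AssertionError)
        return contextlib.nullcontext

    def test_pattern(self):
        # Stand-in comparison that "fails" only for capturable configs, just to
        # exercise the selection logic end to end.
        def fake_comparison(config_desc):
            if config_desc == "capturable":
                raise AssertionError("Tensor-likes are not close!")

        for optim_name, config_desc in [("Adam", "default"), ("Adam", "capturable")]:
            with self._context_for(optim_name, config_desc)():
                fake_comparison(config_desc)

if __name__ == "__main__":
    unittest.main()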

Here are the observations I made about the failing configs generated by _get_optim_inputs_including_global_cliquey_kwargs:

  1. When optimizer_ctor is torch.optim.SGD, the test fails for the config {'weight_decay': 0.1, 'maximize': True, 'fused': True} (a standalone repro sketch for this config follows the list).
  2. When the context is partial(self.assertRaises, AssertionError) for Adam and AdamW, the tests fail for the configs {'lr': 0.01, 'fused': False} and {'lr': 0.01, 'fused': True} with the error AssertionError: AssertionError not raised.
  3. When I change the context to contextlib.nullcontext for Adam and AdamW (since the AssertionError is not raised in observation 2), the tests fail for all configs with the error AssertionError: Tensor-likes are not close!. I am confused about why this error is thrown even for the configs that failed in observation 2: the mismatched-elements percentage is around 3.1% for {'lr': 0.01, 'fused': False} and {'lr': 0.01, 'fused': True}, but either 39.1% or 100% for the other configs.
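
For observation 1, here is the standalone repro sketch referenced above. It mirrors the structure of run() but bypasses _run_scaling_case: a plain fp32 run is compared against an autocast + GradScaler run for the failing SGD config. The model, data, iteration count, and tolerances are my own choices, and the inf-injection/skipped-step logic is omitted, so it only approximates what the test checks:

import torch

def train(device, use_scaling, optimizer_kwargs, iters=3, seed=0):
    # Reseed so both runs see identical weights and data.
    torch.manual_seed(seed)
    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, **optimizer_kwargs)
    scaler = torch.cuda.amp.GradScaler(enabled=use_scaling)
    loss_fn = torch.nn.MSELoss()
    data = [(torch.randn(4, 8, device=device), torch.randn(4, 8, device=device))
            for _ in range(iters)]
    for input, target in data:
        optimizer.zero_grad()
        with torch.autocast(device_type=device, dtype=torch.half, enabled=use_scaling):
            loss = loss_fn(model(input), target)
        if use_scaling:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
    return [p.detach().clone() for p in model.parameters()]

device = "cuda"  # assumption: fused SGD is exercised on a CUDA device here
kwargs = {"weight_decay": 0.1, "maximize": True, "fused": True}
reference = train(device, use_scaling=False, optimizer_kwargs=kwargs)
scaled = train(device, use_scaling=True, optimizer_kwargs=kwargs)
for p_ref, p_amp in zip(reference, scaled):
    # Same atol as the test; a mismatch here would reproduce observation 1.
    torch.testing.assert_close(p_amp, p_ref, atol=1e-3, rtol=1e-3)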

Please let me know if I can provide any additional information or perform any other tests. I would be happy to work on this.

cc @vincentqb @jbschlosser @albanD @janeyx99 @crcrpar

@janeyx99 added the module: optimizer and triaged labels on May 20, 2024