Fused RMSNorm incompatible with PP tracing (dynamic stride) #217

Open
wconstab opened this issue Apr 10, 2024 · 2 comments
@wconstab
Contributor

The incompatibility is in the backward pass: fused_rmsnorm branches on the stride of dy (Python control flow that depends on tensor layout), which isn't safe for the export tracing used by PP.

        dy = dy.view(-1, dy.shape[-1])
        if dy.stride(-1) != 1:
            dy = dy.contiguous()

Which leads to a stacktrace ending in

    File "/data/users/whc/pytorch/torch/_dynamo/variables/tensor.py", line 326, in var_getattr
      unimplemented(f"Illegal getattr invocation {name} in strict mode")
    File "/data/users/whc/pytorch/torch/_dynamo/exc.py", line 204, in unimplemented
      raise Unsupported(msg)
    torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode

Would it be possible to refactor this in a more export friendly way, or is that difficult?
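
For context, here is a small eager-mode sketch of what the stride check is looking at. It does not reproduce the export failure; it only shows that the value the `if` branches on depends on the layout of the incoming gradient. The helper name `last_dim_stride_after_flatten` and the tensor shapes are made up for illustration.

```python
import torch


def last_dim_stride_after_flatten(dy: torch.Tensor) -> int:
    # Hypothetical helper: same flattening as the snippet above, returning
    # the last-dim stride that the `if` statement inspects.
    return dy.view(-1, dy.shape[-1]).stride(-1)


# A freshly allocated tensor is row-major, so its last-dim stride is 1.
dense = torch.randn(4, 6)

# Taking every other column keeps the same storage but skips elements,
# so the last-dim stride becomes 2.
strided = torch.randn(4, 12)[:, ::2]

print(last_dim_stride_after_flatten(dense))    # 1
print(last_dim_stride_after_flatten(strided))  # 2
```

Only the second case would take the `dy.contiguous()` branch.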

cc @lessw2020, @kwen2501

@lessw2020 lessw2020 self-assigned this Apr 10, 2024
@lessw2020
Contributor

Short term, the stride check can be removed to explore tracing (the check is rarely needed; confirmed on llama_7b).

Longer term, this will need either a refactor to support dynamic strides (harder) or, given how rarely it triggers, just a simple assert that non-contiguous inputs are not supported.
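
A rough sketch of the second longer-term option (reject non-contiguous input instead of branching on it), in plain eager PyTorch. The function name `flatten_grad_strict` is hypothetical, and an explicit `raise` stands in for the assert so the check is not stripped under `python -O`. Whether a check like this is itself accepted by strict-mode export tracing is not established in this thread, so it may need to live outside the traced region.

```python
import torch


def flatten_grad_strict(dy: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper, not the actual fused_rmsnorm code.
    # Flatten all leading dims, as in the original backward snippet.
    dy = dy.view(-1, dy.shape[-1])
    # Fail loudly on unsupported layouts rather than silently copying.
    if not dy.is_contiguous():
        raise ValueError("fused_rmsnorm backward: non-contiguous dy is not supported")
    return dy


# Contiguous gradients pass through unchanged (no copy is made).
ok = flatten_grad_strict(torch.randn(2, 3, 8))
print(ok.shape)  # torch.Size([6, 8])
```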

@awgu
Contributor

awgu commented Apr 11, 2024

I did not look into this closely, but could we rely on .contiguous() being a no-op if already contiguous and remove the stride check? (There might be ever-so-slightly more CPU overhead if there is a Python <> C++ switch from .contiguous(), but I think this should be okay for our purpose.)
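
A minimal sketch of that idea, assuming the only purpose of the stride check is to skip a redundant copy. The function name `flatten_grad_always_contiguous` is made up for illustration. The comparison of `data_ptr()` values shows that the already-contiguous case reuses the same memory, while the strided case gets a fresh, densely packed copy.

```python
import torch


def flatten_grad_always_contiguous(dy: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper, not the actual fused_rmsnorm code.
    dy = dy.view(-1, dy.shape[-1])
    # No Python-level branch on strides: .contiguous() returns the input
    # tensor itself when it is already contiguous, and copies otherwise.
    return dy.contiguous()


dense_dy = torch.randn(4, 6)
out_dense = flatten_grad_always_contiguous(dense_dy)
# Same underlying memory: no copy happened.
print(out_dense.data_ptr() == dense_dy.data_ptr())  # True

strided_dy = torch.randn(4, 12)[:, ::2]
out_strided = flatten_grad_always_contiguous(strided_dy)
# New memory with unit last-dim stride, same values.
print(out_strided.stride(-1))                # 1
print(torch.equal(out_strided, strided_dy))  # True
```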

@tianyu-l tianyu-l added the bug Something isn't working label May 3, 2024