
[fused_rmsnorm] Register as a custom operator for tracing #303

Open
wants to merge 2 commits into base: gh/wconstab/10/base
Conversation

wconstab (Contributor) commented May 3, 2024

Stack from ghstack (oldest at bottom):

This just refactors the fused_rmsnorm kernel into torch_library functions so export tracing can avoid tracing inside the kernel, which has several tracing-unfriendly things, including dynamic stride usage.

[ghstack-poisoned]
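For context, a minimal sketch of the torch.library registration pattern being described, assuming the op name and schema that appear later in the diff; the pure-PyTorch body below is only a stand-in for the Triton kernel launch and is not the PR's implementation:

import torch

FUSED_RMSNORM_FORWARD = "torchtitan::fused_rmsnorm_forward"

# Declaring a schema makes the op opaque to export/tracing: the tracer records a
# call to torchtitan::fused_rmsnorm_forward instead of tracing the kernel body,
# which sidesteps the dynamic-stride issues mentioned above.
torch.library.define(
    FUSED_RMSNORM_FORWARD, "(Tensor x, Tensor weight, float eps) -> (Tensor, Tensor)"
)

@torch.library.impl(FUSED_RMSNORM_FORWARD, "cuda")
def fused_rmsnorm_forward(x, weight, eps):
    # Stand-in body; the real implementation launches the fused Triton kernel.
    # Two tensors are returned to match the schema (here: output and an
    # rstd-like statistic, as an assumed example).
    rstd = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)
    y = (x.float() * rstd).type_as(x) * weight
    return y, rstd.squeeze(-1)

# The op is then invoked through the dispatcher, e.g.:
#   y, rstd = torch.ops.torchtitan.fused_rmsnorm_forward(x, weight, 1e-6)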
facebook-github-bot added the "CLA Signed" label (managed by the Meta Open Source bot) on May 3, 2024
wconstab added a commit that referenced this pull request May 3, 2024
This just refactors the fused_rmsnorm kernel into torch_library functions so export tracing can avoid tracing inside the kernel, which has several tracing-unfriendly things, including dynamic stride usage.

ghstack-source-id: 17480f673cea1938b485d6d9f736ef4f845d8d98
Pull Request resolved: #303
x_shape_start = x.shape
# Make fused_rmsnorm a custom op, to work around tracing issues for pp tracer/export
FUSED_RMSNORM_FORWARD = "torchtitan::fused_rmsnorm_forward"
FUSED_RMSNORM_BACKWARD = "torchtitan::fused_rmsnorm_forward"

Should this be torchtitan::fused_rmsnorm_backward?

wconstab added a commit that referenced this pull request May 3, 2024
This just refactors the fused_rmsnorm kernel into torch_library functions so export tracing can avoid tracing inside the kernel, which has several tracing-unfriendly things, including dynamic stride usage.

ghstack-source-id: ebb1f1a95fe041f8eae05f5d00f644249aa126aa
Pull Request resolved: #303
kwen2501 (Contributor) left a comment


Looks good. Thanks.
If this works, I'd love to make a copy of it to put in pippy as a unit test.
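A hypothetical sketch of what such a unit test could look like (the op name and two-tensor return follow the schema in this PR; the reference function, shapes, and tolerances are made up for illustration):

import torch

def reference_rmsnorm(x, weight, eps):
    # Plain-PyTorch RMSNorm used only as a comparison baseline.
    rstd = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)
    return (x.float() * rstd).type_as(x) * weight

def test_fused_rmsnorm_forward_matches_reference():
    x = torch.randn(4, 2048, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(2048, device="cuda", dtype=torch.bfloat16)
    y, _ = torch.ops.torchtitan.fused_rmsnorm_forward(x, w, 1e-6)  # second output unused here
    torch.testing.assert_close(y, reference_rmsnorm(x, w, 1e-6), rtol=1e-2, atol=1e-2)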

wconstab (Contributor, Author) commented May 4, 2024

It IMA's.
@tianyu-l also had a PR for this, but I didn't know about it :/ hopefully we'll figure it out soon.

kwen2501 (Contributor) commented May 4, 2024

What is "IMA" short for?

lessw2020 (Contributor)

> What is "IMA" short for?

Illegal Memory Access - the generic CUDA error indicating that something has exceeded its memory index.

kwen2501 (Contributor) commented May 7, 2024

Thanks @lessw2020.
Do you think the IMA relates to the Triton kernel? Can you help fix it? PP needs this fix to land. Would appreciate your help.

lessw2020 (Contributor)

> Thanks @lessw2020. Do you think the IMA relates to the Triton kernel? Can you help fix it? PP needs this fix to land. Would appreciate your help.

Hi @kwen2501 - I'm debugging into this. It's not an issue with the kernel per se. Rather, for some reason, when the kernel is run as a registered op, the Triton masking is being randomly polluted with values that exceed the CUDA addressable memory space, and this causes the IMA.

[rank0]:pid (18, 0, 0) idx (150) x_hat: 0.000000
[rank0]:pid (18, 0, 0) idx (151) x_hat: 0.000000
[rank0]:pid (18, 0, 0) idx (152) x_hat: -174770156674672237865863089087886393344.000000
[rank0]:pid (18, 0, 0) idx (153) x_hat: 0.000000
[rank0]:pid (18, 0, 0) idx (154) x_hat: 0.000000

The above should be all 0's because it's where we have no input data... but somehow this ginormous number, and others like it, are randomly being added into the masked-off input values. This is what causes the IMA.

For reference, a normal value for inputs (where we have values):

[rank0]:pid (2, 0, 0) idx ( 36) x: 0.673390
[rank0]:pid (2, 0, 0) idx ( 37) x: 0.314899
[rank0]:pid (2, 0, 0) idx ( 38) x: 0.522899
[rank0]:pid (2, 0, 0) idx ( 39) x: -0.126250
[rank0]:pid (2, 0, 0) idx ( 40) x: -0.819483

Will continue investigating, but this issue is likely some kind of bug between Triton load masking and what is going awry when run as a custom op, and not a kernel-specific issue.

kwen2501 (Contributor) commented May 7, 2024

Thanks @lessw2020 for the demonstration.

> some kind of bug between Triton load masking and what is going awry when run as a custom op

Can you point me to the code where the Triton load masking is done?

Also cc @tugsbayasgalan @zou3519

lessw2020 (Contributor)

Hi @kwen2501 - sure, here's the specific line that has the issue:

x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32)

That line loads the inputs and masks off any values past the known column length, and should set those to zero. However, some of those zeros are being randomly polluted.
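For readers less familiar with Triton masking, here is a self-contained toy kernel (not the torchtitan kernel) showing the intended semantics of that load: lanes where the mask is False must not touch memory and should read back as other=0.0.

import torch
import triton
import triton.language as tl

@triton.jit
def masked_load_demo(X, Y, n_cols, BLOCK_SIZE: tl.constexpr):
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols
    # Masked lanes (cols >= n_cols) should not be dereferenced and should yield 0.0;
    # the symptom being debugged above is masked lanes coming back with garbage values.
    x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
    tl.store(Y + cols, x, mask=mask)

x = torch.randn(100, device="cuda")
y = torch.zeros_like(x)
masked_load_demo[(1,)](x, y, x.numel(), BLOCK_SIZE=128)
assert torch.equal(x, y)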

Comment on lines +219 to +221
torch.library.define(
    FUSED_RMSNORM_FORWARD, "(Tensor x, Tensor weight, float eps) -> (Tensor, Tensor)"
)
zou3519

Is the IMA still a problem? If so, can you try the following:

torch.library.define(
    FUSED_RMSNORM_FORWARD,
    "(Tensor x, Tensor weight, float eps) -> (Tensor, Tensor)",
    tags=[torch.Tag.needs_fixed_stride_order],
)

Do you know where the IMA is happening? That is, does it happen with torch.compile(backend="eager"), torch.compile(backend="aot_eager"), and/or torch.compile(backend="inductor")?

lessw2020 (Contributor)

Hi @zou3519 - thanks for the update.
I ran with that modified registered op as above, but it still hits an IMA:

[rank0]:2024-05-20 12:51:50,341 - root - INFO - Training starts at step 1
[rank0]:[rank0]: Traceback (most recent call last):
[rank0]:[rank0]:   File "/data/users/less/local/torchtitan_oss/train.py", line 420, in <module>
[rank0]:[rank0]:     main(config)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/gocuda/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 347, in wrapper
[rank0]:[rank0]:     return f(*args, **kwargs)
[rank0]:[rank0]:   File "/data/users/less/local/torchtitan_oss/train.py", line 309, in main
[rank0]:[rank0]:     loss.backward()
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/gocuda/lib/python3.10/site-packages/torch/_tensor.py", line 523, in backward
[rank0]:[rank0]:     torch.autograd.backward(
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/gocuda/lib/python3.10/site-packages/torch/autograd/__init__.py", line 267, in backward
[rank0]:[rank0]:     _engine_run_backward(
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/gocuda/lib/python3.10/site-packages/torch/autograd/graph.py", line 1201, in _engine_run_backward
[rank0]:[rank0]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/gocuda/lib/python3.10/site-packages/torch/autograd/function.py", line 302, in apply
[rank0]:[rank0]:     return user_fn(self, *args)
[rank0]:[rank0]:   File "/data/users/less/local/torchtitan_oss/torchtitan/models/norms.py", line 360, in backward
[rank0]:[rank0]:     return torch.ops.torchtitan.fused_rmsnorm_backward.default(
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/gocuda/lib/python3.10/site-packages/torch/_ops.py", line 630, in __call__
[rank0]:[rank0]:     return self_._op(*args, **kwargs)
[rank0]:[rank0]:   File "/data/users/less/local/torchtitan_oss/torchtitan/models/norms.py", line 328, in fused_rmsnorm_backward
[rank0]:[rank0]:     dw = _dw.sum(0).to(weight.dtype)
[rank0]:[rank0]: RuntimeError: CUDA error: an illegal memory access was encountered
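One caveat when reading a trace like this: CUDA reports illegal memory accesses asynchronously, so the line shown (dw = _dw.sum(0)) is not necessarily the kernel that actually faulted. Forcing synchronous launches, e.g. as below, usually makes the traceback land on the offending launch (a general CUDA debugging step, not something specific to this PR):

import os
# Must be set before CUDA is initialized in the process (or exported in the
# shell before launching torchrun) so every kernel launch is synchronized.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"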

lessw2020 (Contributor)

Also, I'm not sure what you mean regarding torch.compile - this is already a fused kernel, so we are running it directly without torch.compile.
Running regular RMSNorm with torch.compile results in completely different issues (creation of multiple weights or FX graph tracing failures).
I have made a CUDA RMSNorm that I will test out now as a possible replacement for this issue.

zou3519

Sorry, I thought we were encountering the IMA when using this custom op with torch.compile. Is it the case that the custom op IMAs in eager-mode PyTorch?

Labels: CLA Signed (managed by the Meta Open Source bot)
Projects: None yet
6 participants