[fused_rmsnorm] Register as a custom operator for tracing #303
base: gh/wconstab/10/base
Conversation
This refactors the fused_rmsnorm kernel into torch_library functions so that export tracing can avoid tracing inside the kernel, which has several tracing-unfriendly properties, including dynamic stride usage.
ghstack-source-id: 17480f673cea1938b485d6d9f736ef4f845d8d98
Pull Request resolved: #303
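For readers unfamiliar with the pattern, here is a minimal sketch of the approach the PR describes, with illustrative names rather than the exact torchtitan code: the op is declared with torch.library, the CUDA impl launches the kernel (a plain-PyTorch stand-in is shown so the sketch runs), and a fake/meta impl gives tracing only shapes and dtypes so it never enters the kernel.

```python
# Sketch (assumed names) of wrapping a fused kernel behind a torch.library
# custom op so export tracing stays outside the kernel body.
import torch

torch.library.define(
    "torchtitan::fused_rmsnorm_forward",
    "(Tensor x, Tensor weight, float eps) -> (Tensor, Tensor)",
)

@torch.library.impl("torchtitan::fused_rmsnorm_forward", "CUDA")
def _fused_rmsnorm_forward_cuda(x, weight, eps):
    # In torchtitan this would launch the Triton kernel; a plain-PyTorch
    # stand-in is used here so the sketch is runnable.
    rstd = torch.rsqrt(x.float().pow(2).mean(-1) + eps)
    y = (x * rstd.unsqueeze(-1)).type_as(x) * weight
    return y, rstd

@torch.library.impl_abstract("torchtitan::fused_rmsnorm_forward")
def _fused_rmsnorm_forward_meta(x, weight, eps):
    # Fake/meta impl: tracing sees only output shapes/dtypes, avoiding the
    # tracing-unfriendly kernel internals such as dynamic stride usage.
    rstd = torch.empty(x.shape[:-1], dtype=torch.float32, device=x.device)
    return torch.empty_like(x), rstd
```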
torchtitan/models/norms.py (outdated)

```python
x_shape_start = x.shape

# Make fused_rmsnorm a custom op, to work around tracing issues for pp tracer/export
FUSED_RMSNORM_FORWARD = "torchtitan::fused_rmsnorm_forward"
FUSED_RMSNORM_BACKWARD = "torchtitan::fused_rmsnorm_forward"
```
Should this be torchtitan::fused_rmsnorm_backward?
Looks good. Thanks.
If this works, I'd love to make a copy of it to put in pippy as a unit test.
It IMAs.
What is "IMA" short for?
Illegal Memory Access - the generic CUDA error that something has exceeded its memory index.
Thanks @lessw2020.
Hi @kwen2501 - I'm debugging into this. It's not an issue with the kernel per se. Rather, for some reason, when the kernel is run as a registered op, the triton masking is being randomly polluted with values that exceed the CUDA-addressable memory space, and this causes the IMA.
The values above should be all 0's, because they are where we have no input data... but somehow this ginormous number and others like it are randomly being added into the masked-off input values. This is what causes the IMA. For reference, a normal value for inputs (where we have values):
Will continue investigating, but this is likely some kind of bug between triton load masking and what goes awry when run as a custom op, rather than a kernel-specific issue.
Thanks @lessw2020 for the demonstration.
Can you point me to the code where triton load masking is done? Also cc @tugsbayasgalan @zou3519 |
Hi @kwen2501 - sure, here's the specific line that has the issue: torchtitan/torchtitan/models/norms.py, line 198 at f72a2a0.
That line loads the inputs and masks off any values past the known column length, which should set them to zero. However, some of those zeros are being randomly polluted.
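For reference, the masked-load pattern in question looks roughly like this (a sketch with illustrative names, not the exact torchtitan kernel); `tl.load` with `mask=` and `other=0.0` is what should be zero-filling the out-of-bounds lanes:

```python
# Sketch of the Triton masked-load pattern under discussion. Lanes where
# mask == False should come back as `other` (0.0); the report above says
# these zeros are being polluted when run as a custom op.
import triton
import triton.language as tl

@triton.jit
def _rmsnorm_fwd_sketch(x_ptr, y_ptr, n_cols, BLOCK_SIZE: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols  # mask off lanes past the known column length
    # Assumes contiguous rows for simplicity; the real kernel uses dynamic strides.
    x = tl.load(x_ptr + row * n_cols + cols, mask=mask, other=0.0)
    tl.store(y_ptr + row * n_cols + cols, x, mask=mask)
```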
```python
torch.library.define(
    FUSED_RMSNORM_FORWARD, "(Tensor x, Tensor weight, float eps) -> (Tensor, Tensor)"
)
```
Is the IMA still a problem? If so, can you try the following:

```python
torch.library.define(
    FUSED_RMSNORM_FORWARD,
    "(Tensor x, Tensor weight, float eps) -> (Tensor, Tensor)",
    tags=[torch.Tag.needs_fixed_stride_order],
)
```

Do you know where the IMA is happening? That is, does it happen with torch.compile(backend="eager"), torch.compile(backend="aot_eager"), and/or torch.compile(backend="inductor")?
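(For anyone following along, a hypothetical harness for that kind of backend bisection might look like the following; the model/input names are placeholders, not from this thread, and as noted below the kernel here is normally run without torch.compile:)

```python
# Hypothetical bisection harness: run the same forward/backward under each
# torch.compile backend to see which stage first triggers the IMA.
import torch

def bisect_backends(model, x):
    for backend in ("eager", "aot_eager", "inductor"):
        compiled = torch.compile(model, backend=backend)
        out = compiled(x)
        out.sum().backward()
        torch.cuda.synchronize()  # surface async CUDA errors here
        print(f"backend={backend}: ok")
```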
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @zou3519 - thanks for the update.
I ran with that modified op registration as above, but it still hits an IMA:
```
[rank0]:2024-05-20 12:51:50,341 - root - INFO - Training starts at step 1
[rank0]:[rank0]: Traceback (most recent call last):
[rank0]:[rank0]:   File "/data/users/less/local/torchtitan_oss/train.py", line 420, in <module>
[rank0]:[rank0]:     main(config)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/gocuda/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 347, in wrapper
[rank0]:[rank0]:     return f(*args, **kwargs)
[rank0]:[rank0]:   File "/data/users/less/local/torchtitan_oss/train.py", line 309, in main
[rank0]:[rank0]:     loss.backward()
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/gocuda/lib/python3.10/site-packages/torch/_tensor.py", line 523, in backward
[rank0]:[rank0]:     torch.autograd.backward(
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/gocuda/lib/python3.10/site-packages/torch/autograd/__init__.py", line 267, in backward
[rank0]:[rank0]:     _engine_run_backward(
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/gocuda/lib/python3.10/site-packages/torch/autograd/graph.py", line 1201, in _engine_run_backward
[rank0]:[rank0]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/gocuda/lib/python3.10/site-packages/torch/autograd/function.py", line 302, in apply
[rank0]:[rank0]:     return user_fn(self, *args)
[rank0]:[rank0]:   File "/data/users/less/local/torchtitan_oss/torchtitan/models/norms.py", line 360, in backward
[rank0]:[rank0]:     return torch.ops.torchtitan.fused_rmsnorm_backward.default(
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/gocuda/lib/python3.10/site-packages/torch/_ops.py", line 630, in __call__
[rank0]:[rank0]:     return self_._op(*args, **kwargs)
[rank0]:[rank0]:   File "/data/users/less/local/torchtitan_oss/torchtitan/models/norms.py", line 328, in fused_rmsnorm_backward
[rank0]:[rank0]:     dw = _dw.sum(0).to(weight.dtype)
[rank0]:[rank0]: RuntimeError: CUDA error: an illegal memory access was encountered
```
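For context on the frames above, the usual way the two custom ops are wired together is through an autograd.Function, roughly like this (a sketch with assumed op signatures, not the exact torchtitan code):

```python
# Sketch (assumed names/signatures) of how forward/backward custom ops are
# typically tied together via an autograd.Function, matching the traceback.
import torch

class _FusedRMSNorm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, eps):
        y, rstd = torch.ops.torchtitan.fused_rmsnorm_forward(x, weight, eps)
        ctx.save_for_backward(x, weight, rstd)
        ctx.eps = eps
        return y

    @staticmethod
    def backward(ctx, dy):
        x, weight, rstd = ctx.saved_tensors
        # norms.py:360 in the traceback: dispatch to the registered backward
        # op, whose impl does the _dw.sum(0) reduction where the error surfaced.
        dx, dw = torch.ops.torchtitan.fused_rmsnorm_backward(
            dy, x, weight, rstd, ctx.eps
        )
        return dx, dw, None
```

One general caveat when reading such traces: CUDA errors are reported asynchronously, so the frame where the RuntimeError surfaces (here the `_dw.sum(0)` reduction) is not necessarily where the illegal access actually occurred.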
Also, I'm not sure what you mean regarding torch.compile - this is already a fused kernel, so we are running it directly without torch.compile.
Running regular RMSNorm with torch.compile results in completely different issues (creation of multiple weights, or FX graph tracing failures).
I have made a CUDA RMSNorm that I will test out now as a possible replacement for this issue.
Sorry, I thought we were encountering the IMA when using this custom op with torch.compile. Is it the case that the custom op IMAs in eager-mode PyTorch?