
Poor performance of sequence parallel fused kernels in real model training #1036

Closed
Fanoid opened this issue Apr 26, 2024 · 6 comments

Comments


Fanoid commented Apr 26, 2024

❓ Questions and Help

I just found that the fused kernels for sequence parallel perform poorly in real model training.

Here is a snapshot of the nsys timeline of a ColumnParallelLinear forward call.
The total sequence length is 8k, the hidden size is 8k, and the tensor parallel size is 8. The Triton kernel is disabled, so the multi-stream implementation is used.

[screenshot: nsys timeline of the ColumnParallelLinear forward call]

In the CPU execution of the Python code, there is a large gap (about 600 us) at the beginning and also many bubbles (about 100 us each) between CUDA API calls.
This makes the total elapsed time of the CPU code, which reaches 2.9 ms, even longer than that of the corresponding CUDA kernels/copies.

I think the expected behavior should be: all CUDA API calls in the CPU code are made as fast as possible, so the device can schedule kernels and mem-copies from different streams concurrently.
But because the CPU code executes slowly, all p2p mem-copies finish before the computation kernels are even submitted to the device, which leads to no overlap between p2p mem-copies and computation.
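
For reference, below is a minimal sketch of the overlap I would expect (made-up tensor names and shapes, not the xFormers code): the copies are queued on a side stream and the matmul on the default stream, so the device can run them concurrently as long as the host submits both quickly.

```python
import torch

# Hypothetical standalone example, not the xFormers implementation.
comm_stream = torch.cuda.Stream()

x = torch.randn(8192, 8192, device="cuda")              # local activation shard
w = torch.randn(8192, 1024, device="cuda")              # column-parallel weight shard
peer_shards = [torch.randn_like(x) for _ in range(7)]   # stand-ins for the other ranks
recv_bufs = [torch.empty_like(x) for _ in range(7)]

comm_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(comm_stream):
    for dst, src in zip(recv_bufs, peer_shards):
        dst.copy_(src, non_blocking=True)                # p2p-like async copies

y = x @ w                                                # compute on the default stream

torch.cuda.current_stream().wait_stream(comm_stream)     # join before using recv_bufs
```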

Can anybody help me?


Fanoid commented Apr 26, 2024

I have two conjectures now:

  1. The CPU code is just slow, because there are several calls to tensor split and view (a rough timing sketch is below).
  2. There is a global lock in xFormers, especially around the WriteValues, WaitValues, and Memset32bAsync operators. Since our tensor parallel size is 8, could there be contention between them?
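
For the first conjecture, this is roughly how I plan to measure the pure CPU cost of the split/view bookkeeping in isolation (shapes and iteration count are made up):

```python
import time
import torch

x = torch.randn(8192, 8192, device="cuda")
torch.cuda.synchronize()

t0 = time.perf_counter()
for _ in range(100):
    chunks = x.split(1024, dim=0)               # metadata-only ops, but not free on the CPU
    views = [c.view(-1, 8192) for c in chunks]
t1 = time.perf_counter()
print(f"avg CPU time per iteration: {(t1 - t0) / 100 * 1e6:.1f} us")
```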

@danthe3rd

Hi,
That's a valid point - fused sequence parallel has a higher CPU cost than the non-fused operators, and this can impact performance. In general, in xFormers we don't want to focus on optimizing CPU time, as we can usually get rid of it entirely with CUDA Graphs - although that might not be trivial to implement.
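
As a rough illustration (a toy matmul, not the fused sequence-parallel op), capturing the work in a CUDA graph once and replaying it collapses all the per-kernel CPU launches into a single call:

```python
import torch

static_x = torch.randn(8192, 8192, device="cuda")
static_w = torch.randn(8192, 1024, device="cuda")

# Warm up on a side stream before capture, as PyTorch recommends.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        static_x @ static_w
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_y = static_x @ static_w

# Refill the static input, then replay: one CUDA API call
# regardless of how many kernels the graph contains.
static_x.copy_(torch.randn_like(static_x))
g.replay()
```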
cc @lw


lw commented Apr 26, 2024

Hi, thanks for your report! Indeed we are aware of the substantial CPU overhead of fused sequence parallel, although we've found that for large enough tensors this can usually be hidden. The actual "minimum" sizes depend on your hardware.

Note that CUDA graphs could indeed help here; however, as of now fused sequence parallel isn't graphable yet (we pass some dynamic values to our kernels). I have some code that should make it graphable, but there are still some bugs.

If you need to get unblocked quickly, I would either look at increasing your tensor sizes (e.g., by increasing the batch size) or at using Triton (as it replaces 8 kernel launches with a single one). Note that Triton has recently further improved its launch overhead, so you might benefit from using a nightly version.


Fanoid commented Apr 27, 2024

Thanks for your kind replies, @danthe3rd and @lw .

@lw I've tried the Triton kernel. But in my case, the sequence length is not fixed, so for almost every batch of data Triton autotuning is triggered, which is very expensive (see my next issue #1037). I'm testing with a larger batch size, but it requires more device memory.

After checking the implementation, I'm still confused about why such simple Python code needs several milliseconds to execute.

BTW, I find the result of the fused_ag benchmark deceptive for llama_7b_MHA. The CPU overhead of fused_ag is actually overlapped with the device execution of the previous tests. If a CUDA synchronize is inserted before fused_ag, it becomes slower than the baseline.
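
Roughly, the check looks like this (a simplified sketch, not the actual benchmark script; `run_fused_ag` is a placeholder for the benchmarked call):

```python
import time
import torch

def time_op(fn, iters=20):
    # Drain device work queued by earlier tests so their execution
    # cannot hide this op's CPU launch overhead behind the GPU.
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters

# elapsed = time_op(run_fused_ag)   # placeholder for the fused_ag step
```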


lw commented May 21, 2024

Hey, sorry for the late answer, I was on vacation these past weeks.

If your sequence length is variable then the main workaround I can suggest is that you try to "pad" it in order to always have a consistent value (or a small set of possible values, e.g., the powers of 2). Padding could be done just before invoking the operator and, if done well, should be basically free, since the padding can just be left as uninitialized memory.
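
For completeness, here is a minimal sketch of that idea (a made-up helper, padding up to a multiple of a fixed block size; the padded rows are left uninitialized, as suggested above):

```python
import torch

def pad_seq_len(x: torch.Tensor, multiple: int) -> torch.Tensor:
    # x: (seq_len, hidden); round seq_len up to the next multiple.
    seq_len, hidden = x.shape
    padded_len = ((seq_len + multiple - 1) // multiple) * multiple
    if padded_len == seq_len:
        return x
    out = torch.empty(padded_len, hidden, device=x.device, dtype=x.dtype)
    out[:seq_len].copy_(x)          # the tail rows stay uninitialized
    return out
```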

And I agree that the CPU overhead of these ops is very large, but in our design we deliberately chose to ignore it, as in real training workloads we always observed that the bottleneck was the GPU.


Fanoid commented May 22, 2024

Yeah, I've already tried padding and it is acceptable in our training.

Fanoid closed this as completed May 22, 2024