
[inductor][layout optimization] Extra recompilations because of a guard failure on weight #126241

Closed
anijain2305 opened this issue May 14, 2024 · 8 comments
Assignees
Labels
module: inductor oncall: pt2 triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@anijain2305
Contributor

anijain2305 commented May 14, 2024

🐛 Describe the bug

TORCH_LOGS="recompiles" TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1 python benchmarks/dynamo/torchbench.py --ci --accuracy --timing --backend=inductor --device cuda --inference --bfloat16 --only=moco

The recompilation causing problem is

[rank0]:V0514 16:27:34.740000 140652138796864 torch/_dynamo/guards.py:2594] [2/1] [__recompiles] Recompiling function forward in /data/users/anijain/torchbenchmark/torchbenchmark/models/moco/moco/builder.py:115
[rank0]:V0514 16:27:34.740000 140652138796864 torch/_dynamo/guards.py:2594] [2/1] [__recompiles]     triggered by the following guard failure(s):
[rank0]:V0514 16:27:34.740000 140652138796864 torch/_dynamo/guards.py:2594] [2/1] [__recompiles]     - tensor 'L['self'].encoder_q.layer1[0].conv1.weight' stride mismatch at index 2. expected 1, actual 64

How do we know this recompilation is a problem? There are a couple of reasons:

  1. The above recompilation does not happen with backend=eager
  2. The above recompilation does not happen with TORCHINDUCTOR_LAYOUT_OPTIMIZATION=0

Similar recompilations happen for the detectron2_fcos_r_50_fpn model.
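
For reference, the same knobs can also be enabled from Python instead of environment variables (a minimal sketch; the benchmark invocation itself stays as in the command above):

import torch
import torch._dynamo
import torch._logging

# In-process equivalents of TORCH_LOGS="recompiles" and
# TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1 from the repro command.
torch._logging.set_logs(recompiles=True)
torch._dynamo.config.inline_inbuilt_nn_modules = True
# ...then run the moco model under torch.compile with backend="inductor".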

Error logs

No response

Minified repro

No response

Versions

N/A

cc @ezyang @msaroufim @bdhirsh @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire

@anijain2305 anijain2305 added triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module oncall: pt2 module: inductor labels May 14, 2024
@shunting314
Contributor

@anijain2305

For this log:

[rank0]:V0517 17:42:09.269000 140448941286400 torch/_dynamo/guards.py:2598] [2/1] [__recompiles] Recompiling function forward in /home/shunting/ws/pytorch/torchbenchmark/torchbenchmark/models/moco/moco/builder.py:115
[rank0]:V0517 17:42:09.269000 140448941286400 torch/_dynamo/guards.py:2598] [2/1] [__recompiles]     triggered by the following guard failure(s):
[rank0]:V0517 17:42:09.269000 140448941286400 torch/_dynamo/guards.py:2598] [2/1] [__recompiles]     - tensor 'L['self'].encoder_q.layer1[0].conv1.weight' stride mismatch at index 2. expected 1, actual 64

How can I print the whole expected and actual strides?

@shunting314
Contributor

shunting314 commented May 18, 2024

OK, I printed the stride from the model directly, and this is the stride/shape in eager mode:

(Pdb) self.encoder_q.layer1[0].conv1.weight.shape
torch.Size([64, 64, 1, 1])
(Pdb) self.encoder_q.layer1[0].conv1.weight.stride()
(64, 1, 1, 1)

You can see that the mismatched dimension has size == 1. Can we skip the stride check when size == 1, since any stride is equivalent in that case? @anijain2305
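
As a minimal illustration of that equivalence (my sketch, not from the original report): two tensors over the same storage that differ only in the strides of their size-1 dims address exactly the same elements and compare equal.

import torch

a = torch.arange(64 * 64, dtype=torch.float32).reshape(64, 64, 1, 1)
# Same storage, but give the size-1 dims the "channels-last-like" strides seen above.
b = torch.as_strided(a, size=(64, 64, 1, 1), stride=(64, 1, 64, 64))

print(a.stride())         # (64, 1, 1, 1)
print(b.stride())         # (64, 1, 64, 64)
print(torch.equal(a, b))  # True: strides on size-1 dims never change which element is read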

@eellison
Contributor

Hmm, I wonder if this is the same reason for the recompilation in #125641.

@anijain2305
Contributor Author

anijain2305 commented May 19, 2024

You can use TORCH_LOGS="guards, recompiles" to print the guards as well. You will have to search to find the relevant TENSOR_MATCH guards.

These are the relevant lines:

 +- TENSOR_MATCH: check_tensor(L['self'].encoder_q.layer1._modules['0'].conv1.weight, Parameter, DispatchKeySet(CUDA, BackendSelect, ADInplaceOrView, AutogradCUDA), torch.bfloat16, device=0, requires_grad=True, size=[64, 64, 1, 1], stride=[64, 1, 1, 1])  # return self._conv_forward(input, self.weight, self.bias)  # nn/modules/conv.py:460 in forward
 - tensor 'L['self'].encoder_q.layer1._modules['0'].conv1.weight' stride mismatch at index 2. expected 1, actual 64
 +- TENSOR_MATCH: check_tensor(L['self'].encoder_q.layer1._modules['0'].conv1.weight, Parameter, DispatchKeySet(CUDA, BackendSelect, ADInplaceOrView, AutogradCUDA), torch.bfloat16, device=0, requires_grad=True, size=[64, 64, 1, 1], stride=[64, 1, 64, 64])  # return self._conv_forward(input, self.weight, self.bias)  # nn/modules/conv.py:460 in forward

At the beginning, we have size=[64, 64, 1, 1], stride=[64, 1, 1, 1].
After the first compilation, we have size=[64, 64, 1, 1], stride=[64, 1, 64, 64].

So, the main question is: is this expected? Does the layout transformation mutate the weight layout in place?

cc @shunting314

@shunting314
Contributor

So, the main question is: is this expected? Does the layout transformation mutate the weight layout in place?

No.

And it turns out the layout of the weight is changed by torch.convolution.

Check:

import torch
from torch import nn
import torch._inductor.config as inductor_config
import torch._dynamo.config as dynamo_config

inductor_config.force_layout_optimization = True
dynamo_config.inline_inbuilt_nn_modules = True

torch.set_default_device("cuda")
torch.set_default_dtype(torch.bfloat16)

conv1 = nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
conv2 = nn.Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
x = torch.randn(4, 3, 224, 224)

@torch.compile
def f(x):
    x = conv1(x)
    x = conv2(x)
    return x

for i in range(2):
    f(x)
print("bye")

In this simple model with two convolutions, I observe the same symptom as with moco. Then check the following PyTorch code:

import torch

torch.set_default_device("cuda")

weight = torch.randn(64, 64, 1, 1)
x = torch.randn(2, 64, 10, 10).to(memory_format=torch.channels_last)
print(f"x stride {x.stride()=}")
print(f"before: {weight.stride()=}")
torch.convolution(x, weight, stride=(1, 1), padding=(0, 0), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
print(f"after: {weight.stride()=}")

The output is:

x stride x.stride()=(6400, 1, 640, 64)
before: weight.stride()=(64, 1, 1, 1)
after: weight.stride()=(64, 1, 64, 64)

The torch.convolution call changes the layout of the weight. I guess it does so because it does not hurt. But I think dynamo needs to make sure we don't fail the guard when this happens.
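
As a sketch of that idea (not dynamo's actual guard code), a stride comparison that ignores size-1 dimensions would accept both layouts seen here:

def strides_match(size, expected, actual):
    # Ignore dims of size 1: any stride addresses the same single element there.
    return all(s == 1 or e == a for s, e, a in zip(size, expected, actual))

# The moco weight before vs. after the convolution call:
print(strides_match((64, 64, 1, 1), (64, 1, 1, 1), (64, 1, 64, 64)))  # True
print(strides_match((64, 64, 1, 1), (64, 1, 1, 1), (1, 64, 1, 1)))    # False: a real layout change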

@mlazos
Contributor

mlazos commented May 21, 2024

#126585 is possibly a related issue

@ezyang
Contributor

ezyang commented May 21, 2024

We should fix eager convolution not to do this. It should not be hard to find.

@eellison
Contributor

It's coming from the fixSizeOneDimStride call inside Conv_v8.

pytorchmergebot pushed a commit that referenced this issue May 22, 2024
Fix for #126241.

Within the cudnn convolution, we were updating the strides of the tensor in place to disambiguate size-1 dims between contiguous and channels-last tensors. Instead of mutating the tensor's strides, just use a temporary. Inside cudnn it is then copied: https://github.com/NVIDIA/cudnn-frontend/blob/d7ccb5b3c47b4de709604cce463ad66b775b7812/include/cudnn_frontend_Tensor.h#L201-L203.

Pull Request resolved: #126786
Approved by: https://github.com/ezyang, https://github.com/shunting314, https://github.com/eqy
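
With a build that includes #126786, the standalone torch.convolution snippet from earlier in the thread should leave the weight strides untouched; a quick sanity check (a sketch reusing that repro):

import torch

torch.set_default_device("cuda")

weight = torch.randn(64, 64, 1, 1)
x = torch.randn(2, 64, 10, 10).to(memory_format=torch.channels_last)

before = weight.stride()
torch.convolution(x, weight, stride=(1, 1), padding=(0, 0), dilation=(1, 1),
                  transposed=False, output_padding=(0, 0), groups=1, bias=None)
# After the fix, the convolution should no longer mutate the weight's strides in place.
assert weight.stride() == before, f"weight layout changed: {before} -> {weight.stride()}"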