[inductor][layout optimization] Extra recompilations because of a guard failure on weight #126241
Comments
for log:
How do I print the whole expected and actual strides? |
OK, I printed the strides from the model directly, and here are the stride/shape in eager mode:
You can see that the mismatched dimension has size == 1. Can we skip the stride check when size == 1, since any stride is equivalent in that case? @anijain2305 |
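A minimal sketch of that point (the shapes below are made up, not taken from the report): for a dim of size 1 the only valid index is 0, so its stride never contributes to an element's offset, and any stride value describes the same memory layout.

```python
import torch

base = torch.randn(8, 1, 4)                    # contiguous, stride (4, 4, 1)
a = base.as_strided(base.size(), (4, 4, 1))    # stride 4 on the size-1 dim
b = base.as_strided(base.size(), (4, 1, 1))    # stride 1 on the size-1 dim

print(a.stride(), b.stride())  # (4, 4, 1) vs (4, 1, 1): different strides...
print(torch.equal(a, b))       # True: ...but element-wise identical tensors
```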
Hmm, I wonder if this is the same reason for the recompilation in #125641. |
You can use TORCH_LOGS="guards, recompiles" to print the guards as well. You will have to search to find the relevant TENSOR_MATCH guards. These are the relevant lines
At the beginning, we have … So, the main question is: is this expected? Does the layout transformation mutate the weight layout in place? cc @shunting314 |
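As an aside, a minimal sketch of that logging workflow (the toy module below is hypothetical; only the TORCH_LOGS setting comes from the comment above):

```python
# Run as: TORCH_LOGS="guards,recompiles" python repro.py
# then search the output for TENSOR_MATCH to find the stride guards on the weight.
import torch

class Toy(torch.nn.Module):  # hypothetical stand-in model
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)

    def forward(self, x):
        return self.conv(x)

m = torch.compile(Toy())
m(torch.randn(2, 3, 16, 16))
m(torch.randn(2, 3, 16, 16))  # a second call reveals whether a recompile was triggered
```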
No; the layout of the weight turns out to be changed by torch.convolution instead. Check:
In this simple model with 2 convolutions, I observe the same symptom as in moco. Then check the following PyTorch code:
The output is
The torch.convolution call changes the layout of the weight. I guess it does so since it does not hurt. But I think dynamo needs to make sure we don't fail the guard when this happens. |
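The snippet and output referenced above were not preserved here; what follows is a hedged reconstruction of that kind of check (the shapes and channels-last setup are guesses, and the in-place stride change is only reported on the cudnn/CUDA path on builds without the fix discussed below):

```python
# Record weight.stride() before and after the convolution and check whether the
# call itself rewrote the stride in place. Requires a CUDA build with cudnn; on
# affected builds, before and after can differ, which is exactly what trips
# dynamo's TENSOR_MATCH guard on the weight.
import torch

conv = torch.nn.Conv2d(4, 8, kernel_size=1, bias=False).cuda()
conv = conv.to(memory_format=torch.channels_last)   # weight shape (8, 4, 1, 1)
x = torch.randn(2, 4, 16, 16, device="cuda").to(memory_format=torch.channels_last)

before = conv.weight.stride()
conv(x)
after = conv.weight.stride()
print(before, after)
```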
#126585 is possibly a related issue |
We should fix eager convolution not to do this. It should not be hard to find. |
It's coming from the fixSizeOneDimStride call inside Conv_v8. |
Fix for #126241. Within the cudnn convolution, we were updating the strides of the tensor in place to disambiguate size-1 dims between contiguous and channels-last tensors. Instead of mutating the tensor's strides, just use a temporary. Inside cudnn the strides are then copied: https://github.com/NVIDIA/cudnn-frontend/blob/d7ccb5b3c47b4de709604cce463ad66b775b7812/include/cudnn_frontend_Tensor.h#L201-L203.
Pull Request resolved: #126786. Approved by: https://github.com/ezyang, https://github.com/shunting314, https://github.com/eqy
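The actual change is in the C++ cudnn path, but the idea can be sketched in Python terms (the helper name and the stride-rewrite rule below are illustrative placeholders, not the real ATen logic):

```python
import torch

def strides_for_cudnn(t: torch.Tensor) -> list[int]:
    """Compute the strides handed to the cudnn descriptor in a temporary list.

    The rewrite of size-1 dims below is a placeholder -- the real rule lives in
    fixSizeOneDimStride in ATen's Conv_v8 code. The point of the fix is only that
    the rewrite lands in this local list, not in the tensor itself.
    """
    fixed = list(t.stride())
    for d, s in enumerate(t.shape):
        if s == 1:
            fixed[d] = 1  # placeholder canonical value, for illustration only
    return fixed

w = torch.randn(8, 4, 1, 1).contiguous(memory_format=torch.channels_last)
print(w.stride())            # (4, 1, 4, 4) -- left untouched by the call below
print(strides_for_cudnn(w))  # (4, 1, 1, 1) -- only the temporary differs
```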
🐛 Describe the bug
The recompilation causing the problem is:
How do we know this recompilation is a problem? There are a couple of reasons (see the sketch below this list):
TORCHINDUCTOR_LAYOUT_OPTIMIZATION=0
Similar recompilations happen for the model detectron2_fcos_r_50_fpn.
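A low-effort way to double-check the layout-optimization angle (a sketch, assuming the same benchmark can simply be rerun; the environment variables are the ones mentioned in this thread):

```python
# Rerun the same workload with inductor's layout optimization disabled and see
# whether the TENSOR_MATCH recompilation on the weight goes away. Set the
# environment variables before importing torch.
import os
os.environ["TORCHINDUCTOR_LAYOUT_OPTIMIZATION"] = "0"
os.environ["TORCH_LOGS"] = "recompiles"

import torch
# ... run the same torch.compile'd model here and compare the recompile logs
```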
Error logs
No response
Minified repro
No response
Versions
N/A
cc @ezyang @msaroufim @bdhirsh @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire