
Teach FakeTensor to fill in item_memo when converting scalar CPU tensor #126245

Closed
wants to merge 10 commits

Conversation

ezyang
Contributor

@ezyang ezyang commented May 15, 2024

Stack from ghstack (oldest at bottom):

This PR requires a little justification, but let's start with what it does first:

  1. When you have a 0d CPU scalar int64/float64 tensor input to a graph, we will preallocate a backed SymInt/SymFloat corresponding to what you would get if you called item() on this tensor. This means you can freely change your input between a Python int/float and a Tensor with an item() call and end up with exactly the same level of expressivity (specifically, you can guard on the internal SymInt/SymFloat either way; a sketch of this user-facing behavior follows this list). By default, the source of the backed SymInt/SymFloat is L['tensor'].item(), but if you have promoted a float input into a Tensor, we will cancel out torch.as_tensor(L['float']).item() into just L['float'].
  2. We switch wrap_symfloat to use this, instead of hand-crafting the new SymNodeVariable. Everything works out, except that we carefully pass the item() result to tracked fakes (and not the fake Tensor argument).
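A minimal sketch of the user-facing behavior from item 1 (the function and variable names here are illustrative, not from this PR): with this change, passing either a plain Python float or a 0d CPU float64 tensor yields the same backed SymFloat, so the item() call below can be guarded on instead of producing an unbacked symbol.

    import torch

    @torch.compile(dynamic=True)
    def f(x, scale):
        # scale may be a Python float or a 0d CPU float64 tensor; with this PR
        # both spellings bottom out in the same backed SymFloat.
        s = scale.item() if isinstance(scale, torch.Tensor) else scale
        if s > 1.0:  # guarding on the backed symbol is now allowed
            return x * s
        return x + s

    x = torch.randn(4)
    f(x, torch.tensor(2.0, dtype=torch.float64))  # 0d CPU scalar tensor input
    f(x, 2.0)                                     # equivalent plain float input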

OK, so why do this at all? There is some marginal benefit in that some item() calls on scalar inputs can now be guarded on, but IMO this is a pretty marginal benefit, and if it were the only reason, I wouldn't do this. The real reason is that I need to be able to propagate fake tensors through the graphs that Dynamo produces, and with the old custom wrap_symfloat logic there's no way to do this, because an item() call would ordinarily allocate a fresh unbacked SymInt when I re-run it.

The other obvious way to solve the problem above is to make a HOP alternative to item() that "bakes in" the backed SymInt it's supposed to return. But this strategy seems more parsimonious, and it does have the marginal benefit I mentioned above. The main downside is that what I have to do next is make it so that when I run tensor computation, I also apply the equivalent operations to the SymInt/SymFloat as well. That's the next PR.
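On the FakeTensor side, here is a minimal sketch for poking at the new field; whether item_memo actually gets populated on this bare from_tensor() path (with no Dynamo source attached) is my assumption, since the PR wires it up through Dynamo's input wrapping.

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    mode = FakeTensorMode(shape_env=ShapeEnv())
    t = torch.tensor(3.0, dtype=torch.float64)  # 0d CPU scalar tensor
    fake = mode.from_tensor(t)
    # If the conversion memoized the item() value, this is a backed SymFloat;
    # otherwise it is None and a later item() call allocates an unbacked symbol.
    print(getattr(fake, "item_memo", None))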

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang

[ghstack-poisoned]

pytorch-bot bot commented May 15, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/126245

Note: Links to docs will display an error until the docs builds have been completed.

❌ 4 New Failures, 1 Unrelated Failure

As of commit 1b408fe with merge base 5ea956a:

NEW FAILURES - The following jobs have failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

ezyang added a commit that referenced this pull request May 15, 2024
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 8bd40f9d382dec545c363fd9bd03fa8700d085ff
Pull Request resolved: #126245
[ghstack-poisoned]
ezyang added a commit that referenced this pull request May 15, 2024
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 4ab08d48328b4b53d950ee8c99840b91c43e372b
Pull Request resolved: #126245
[ghstack-poisoned]
ezyang added a commit that referenced this pull request May 15, 2024
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 4e1b6b22c0ad34d6ef2b1dce4434a2c1c359ec2e
Pull Request resolved: #126245
[ghstack-poisoned]
ezyang added a commit that referenced this pull request May 15, 2024
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 3e4ae98b6d32d18055eca4c03a2636d72eeb7ad9
Pull Request resolved: #126245
[ghstack-poisoned]
@ezyang ezyang added the suppress-bc-linter Suppresses the failures of API backward-compatibility linter (Lint/bc_linter) label May 15, 2024
ezyang added a commit that referenced this pull request May 15, 2024
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: e4af100e8076f0668643b9a60427dae97b1a2459
Pull Request resolved: #126245
[ghstack-poisoned]
ezyang added a commit that referenced this pull request May 15, 2024
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 67506c1d21f013509e3c98ec3f87ad9e5ed35561
Pull Request resolved: #126245
@albanD albanD removed their request for review May 15, 2024 19:15
@ezyang ezyang requested a review from shazqadeer May 16, 2024 23:40
@@ -853,6 +928,7 @@ def __init__(
        allow_non_fake_inputs=False,
        shape_env=None,
        static_shapes=None,
        export=False,
Contributor

why do we need the export kwarg? would be great to avoid if possible.

Contributor Author

@ezyang ezyang May 17, 2024

The reason why I need to avoid doing this transformation during export is that when item_memo gets filled with a backed SymInt, it transforms deferred runtime asserts into guards. But while deferred runtime asserts are incorporated into the exported graph product, typically guards are ignored entirely.

Perhaps this is an export problem: maybe we should always generate runtime assertions for guards in the graph, which I believe would eliminate the need for this. cc @angelayi @BoyuanFeng
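As a hypothetical illustration of that suggestion (not code in this PR): a guard on the backed symbol produced by item() could instead be materialized as a torch._check call, which export can keep as a runtime assertion in the graph.

    import torch

    def forward(x):
        s = x.item()         # backed SymFloat under this PR (unbacked before it)
        torch._check(s > 0)  # the guard, expressed as a runtime assertion
        return torch.ones(3) * s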

Contributor

yeah, transforming the guards here into runtime asserts makes sense to me

Contributor Author

This has some knock-on effects: specifically, we will start generating runtime asserts for size guards that we didn't previously generate. I could split the difference in this PR by only generating runtime asserts for tensor inputs 🤔

Contributor Author

I made an honest attempt at this. I think it... could still possibly work? But it would be fairly invasive to export, and I'm not keen on shoving it all into this patch. Here is my attempt: https://gist.github.com/ezyang/ad0f6da568b2c2b5c9427e0a5de53bde

The mechanical problem I'd gotten to before I gave up is that non-strict export drops the torch._check calls I'm inserting to enforce guards. Strict export doesn't have this problem, because it manually calls insert_deferred_runtime_asserts late enough that the checks don't get removed. (AOTAutograd removes the torch._check calls in strict export, so we need to make sure we reinsert the deferred runtime asserts afterwards.)

There's also some business going on in export where it still generates asserts as assert_async on tensors (which obviously isn't going to propagate facts to ShapeEnv). I'm not entirely sure why the code is still doing it this way; this may need to get refactored.

At minimum these tests need fixing: https://gist.github.com/ezyang/d567607b415a9e76b3df038c447c6be1

Contributor Author

We discussed it at the export weekly and it looks like the export team has ongoing refactoring related to deferred runtime asserts that should potentially make it easier. So I propose that we land this PR as is, and then revisit export flag removal when those refactors land.

cc @pianpwk @avikchaudhuri @angelayi @tugsbayasgalan

torch/_subclasses/fake_tensor.py — three additional review threads (resolved; two outdated)
Comment on lines +375 to +378
if isinstance(source, FloatTensorSource):
    item_source = source.base
else:
    item_source = CallMethodItemSource(source)
Contributor

this is getting into the weeds of Dynamo in a way that feels out of place for fake tensor... not sure if there's a better solution or not

Contributor Author

Yeah, there's just no convenient way to ask Dynamo to swizzle the source before it gets to FakeTensor. There are already some references to sources in fake tensor mode, so it doesn't seem too much worse.
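For readers unfamiliar with the sources involved, here is a rough sketch of what the swizzling in the snippet above composes; the import path and the exact name() output are my assumptions (CallMethodItemSource and FloatTensorSource are introduced by this PR).

    from torch._dynamo.source import CallMethodItemSource, LocalSource

    base = LocalSource("tensor")              # names the graph input L['tensor']
    item_source = CallMethodItemSource(base)  # names the derived scalar
    print(item_source.name())                 # expected: "L['tensor'].item()"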

OnlyFor pushed a commit to OnlyFor/pytorch that referenced this pull request May 17, 2024
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 67506c1d21f013509e3c98ec3f87ad9e5ed35561
Pull Request resolved: pytorch#126245
[ghstack-poisoned]
ezyang added a commit that referenced this pull request May 21, 2024
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: e89018ceaf9508ba42bb9949228f5e43eac8a1de
Pull Request resolved: #126245
@ezyang
Contributor Author

ezyang commented May 21, 2024

I addressed all of the CR comments; can we agree to land this now, or do we want to ice this until the export refactors are done?

@ezyang
Contributor Author

ezyang commented May 21, 2024

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label May 21, 2024
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@pytorchmergebot
Collaborator

@ezyang
Contributor Author

ezyang commented May 22, 2024

@pytorchbot merge -f "unrelated to this pr problesm"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

Labels: ciflow/inductor, ciflow/trunk, Merged, module: dynamo, release notes: fx, suppress-bc-linter