[BUG] Tensorclass.keys() doesn't list non-tensor data. #717

Open
maximilianigl opened this issue Mar 21, 2024 · 2 comments
Labels
bug Something isn't working

Comments

@maximilianigl

Describe the bug

Tensorclasses (unlike tensordicts) don't list non-tensor data when iterating through them with .keys() or .items(). This also affects other methods, e.g. apply().

To Reproduce

```python
from tensordict import TensorDict
from tensordict.prototype import tensorclass
import torch

tensordict = TensorDict({"tensor": torch.ones(3), "string": "string"})

print(sorted(tensordict.keys()))
# Output: ['string', 'tensor']

@tensorclass
class MyTensorClass:
    tensor: torch.Tensor
    string: str


my_tensor_class = MyTensorClass(tensor=torch.ones(3), string="string")

print(sorted(my_tensor_class.keys()))
# Output: ['tensor'], i.e. it ignores non-tensor data.

# The lambda prints every value apply() visits.
my_tensor_class.apply(lambda x: print(x))
# Output: tensor([1., 1., 1.]), i.e. the string field is never visited.
```

Expected behavior

I'd expect keys() and items() to include non-tensor data.
apply() probably should as well, although whether it visits non-tensor data could also depend on an input flag (one might explicitly want to apply a function to tensors only).
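A minimal, library-independent sketch of the behavior requested above. This is a toy model, not tensordict code: `ToyTensorClass` and its `tensors_only` flag are hypothetical names, and `array.array` stands in for `torch.Tensor` so the example runs without PyTorch.

```python
from array import array
from typing import Any, Callable, Dict, Iterator, Tuple


def is_tensor(value: Any) -> bool:
    # Assumption of this sketch: any array.array plays the role of a tensor.
    return isinstance(value, array)


class ToyTensorClass:
    """Hypothetical container whose keys()/items() report every field."""

    def __init__(self, **fields: Any) -> None:
        self._fields: Dict[str, Any] = dict(fields)

    def keys(self) -> Iterator[str]:
        return iter(self._fields)

    def items(self) -> Iterator[Tuple[str, Any]]:
        return iter(self._fields.items())

    def apply(self, fn: Callable[[Any], Any], tensors_only: bool = False) -> "ToyTensorClass":
        # By default fn sees every field; with tensors_only=True,
        # non-tensor fields are copied through unchanged.
        out = {}
        for key, value in self._fields.items():
            if tensors_only and not is_tensor(value):
                out[key] = value
            else:
                out[key] = fn(value)
        return ToyTensorClass(**out)


obj = ToyTensorClass(tensor=array("f", [1.0, 1.0, 1.0]), string="string")
print(sorted(obj.keys()))
# ['string', 'tensor']
print(dict(obj.apply(lambda v: type(v).__name__).items()))
# {'tensor': 'array', 'string': 'str'}
print(dict(obj.apply(lambda v: type(v).__name__, tensors_only=True).items()))
# {'tensor': 'array', 'string': 'string'}
```

Defaulting the flag to `False` matches the expectation that apply() visits everything, while still leaving an opt-out for functions that only make sense on tensors.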

System info

tensordict version: '0.4.0+b4c91e8'

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
@maximilianigl maximilianigl added the bug Something isn't working label Mar 21, 2024
@vmoens
Contributor

vmoens commented Mar 26, 2024

This used to be a design choice, made before we introduced NonTensorData.

Context

We initially thought it would be useful to let tensorclasses carry non-tensor data, but decided to exclude that data from the keys, since apply and any other op that iterates through the keys (reshape, gather, ...) would be meaningless on non-tensor data.

Then we introduced NonTensorData, a simple tensorclass that can only carry non-tensor data. If you call tensordict.keys(include_nested=True), a NonTensorData node appears as if it had no leaves, but in reality it is a leaf. Calling apply over NonTensorData is fine because apply does not access the data field (it is not part of the keys). But because we hack through __getitem__ and __setitem__ to access NonTensorData.data, the situation is less clear: one would expect the non-tensor data to be part of the keys, since the key is used to set or access the value.
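A toy model of the traversal described above, assuming a wrapper whose payload is stored as an attribute rather than as an entry. `ToyDict`, `ToyNonTensor` and `nested_keys` are hypothetical names, the `leaves_only` flag is an assumption of this sketch, and `array.array` stands in for `torch.Tensor`. It shows why the wrapper looks like an empty node: it is listed as a key, but nothing is found beneath it.

```python
from array import array
from typing import Any, Dict, Iterator, Tuple


class ToyDict:
    """Hypothetical nested container: values are 'tensors' or other ToyDicts."""

    def __init__(self, **entries: Any) -> None:
        self.entries: Dict[str, Any] = dict(entries)


class ToyNonTensor(ToyDict):
    """Hypothetical NonTensorData stand-in: the payload is not an entry."""

    def __init__(self, data: Any) -> None:
        super().__init__()  # no entries, so traversal sees an empty node
        self.data = data    # payload kept as an attribute, outside the keys


def nested_keys(node: ToyDict, leaves_only: bool = False,
                prefix: Tuple[str, ...] = ()) -> Iterator[Tuple[str, ...]]:
    """Yield key paths; container nodes are listed unless leaves_only is set."""
    for key, value in node.entries.items():
        path = prefix + (key,)
        if isinstance(value, ToyDict):
            if not leaves_only:
                yield path  # the node itself
            yield from nested_keys(value, leaves_only, path)
        else:
            yield path      # a tensor leaf


td = ToyDict(tensor=array("f", [1.0, 1.0, 1.0]), string=ToyNonTensor("string"))
print(list(nested_keys(td)))
# [('tensor',), ('string',)]  -> the wrapper shows up, with no leaves under it
print(list(nested_keys(td, leaves_only=True)))
# [('tensor',)]  -> the non-tensor value disappears entirely
```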

Solution

We introduced an is_leaf function in keys, items, values and apply to quickly check whether a node is a leaf. That prevents ops like reshape from accessing the non-tensor data as if it were a leaf, while still presenting non-tensor data to the user as leaves.
We should now do something similar for all tensorclasses: by default, keys will return all the data (tensor and non-tensor), but internally, operations will only be applied to tensor data.
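The is_leaf idea can be modeled as a traversal that takes a leaf predicate, so the same tree yields different leaves depending on which predicate is passed. This is a hedged sketch, not tensordict's implementation: `ToyDict`, `ToyNonTensor`, `leaf_items`, `internal_is_leaf` and `user_is_leaf` are hypothetical names, and `array.array` stands in for `torch.Tensor`.

```python
from array import array
from typing import Any, Callable, Dict, Iterator, Tuple


class ToyDict:
    """Hypothetical nested container; 'tensors' are array.array instances."""

    def __init__(self, **entries: Any) -> None:
        self.entries: Dict[str, Any] = dict(entries)


class ToyNonTensor(ToyDict):
    """Hypothetical NonTensorData stand-in: the payload is not an entry."""

    def __init__(self, data: Any) -> None:
        super().__init__()
        self.data = data


def internal_is_leaf(value: Any) -> bool:
    # Internal ops (reshape, gather, ...): only non-container values are leaves,
    # so a ToyNonTensor is treated as a node and its payload is never reached.
    return not isinstance(value, ToyDict)


def user_is_leaf(value: Any) -> bool:
    # User-facing keys/items: non-tensor wrappers are presented as leaves too.
    return internal_is_leaf(value) or isinstance(value, ToyNonTensor)


def leaf_items(node: ToyDict,
               is_leaf: Callable[[Any], bool] = internal_is_leaf,
               prefix: Tuple[str, ...] = ()) -> Iterator[Tuple[Tuple[str, ...], Any]]:
    """Yield (key path, value) for every value the is_leaf predicate accepts."""
    for key, value in node.entries.items():
        path = prefix + (key,)
        if is_leaf(value):
            # Unwrap non-tensor payloads so the user sees the raw value.
            yield path, value.data if isinstance(value, ToyNonTensor) else value
        else:
            yield from leaf_items(value, is_leaf, path)


td = ToyDict(tensor=array("f", [1.0, 1.0, 1.0]),
             nested=ToyDict(string=ToyNonTensor("string")))

print([path for path, _ in leaf_items(td)])
# [('tensor',)]
print([path for path, _ in leaf_items(td, is_leaf=user_is_leaf)])
# [('tensor',), ('nested', 'string')]
```

Swapping the predicate, rather than the data layout, is what lets a reshape-like op skip non-tensor payloads while keys() and items() still report them.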

@maximilianigl
Author

Sounds good to me!
If I understand correctly, the 'real' leaves are tensors, and non-tensors are not leaves because they're wrapped in NonTensorData, which is a dataclass subclass?
As a user, I'd find that slightly confusing, since to me non-tensor data does appear as leaves. But I guess is_leaf is only an internal implementation detail?
